File size: 3,061 Bytes
dbbf91e
fe79d9c
 
f7049dd
dbbf91e
fe79d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbbf91e
fe79d9c
f7049dd
fe79d9c
dbbf91e
 
fe79d9c
dbbf91e
fe79d9c
 
 
 
 
321d5fe
fe79d9c
 
dbbf91e
 
fe79d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
dbbf91e
 
 
 
fe79d9c
 
 
 
 
 
 
 
808de23
dbbf91e
 
808de23
 
 
 
 
dbbf91e
808de23
 
 
 
f7049dd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import BertTokenizer, BertModel, pipeline
import torch as t
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application instance; all routes below attach to this object.
app = FastAPI()

# Configure CORS: In production, you might restrict allowed origins
# NOTE(review): "*" origins + all methods/headers is wide open — fine for a
# demo, but lock down `allow_origins` before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (frontend) so that visiting "/" serves index.html
# The directory path "../frontend" works because when running in Docker,
# our working directory is set to /app, and the frontend folder is at /app/frontend.
# (The mount point is "/static"; the bare "/" route below serves index.html directly.)
app.mount("/static", StaticFiles(directory="frontend", html=True), name="static")

# Load tokenizer and BERT model (attentions enabled so /process can return them).
# Both loads can fail (network/cache issues); wrap both in one handler so either
# failure is logged before the exception aborts startup. Previously only the
# model load was guarded — a tokenizer failure died silently.
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise

@app.post("/process")
async def process_text(text: str = Body(..., embed=True)):
    """
    Process the input text:
      - Tokenizes the text using BERT's tokenizer (truncated to 512 tokens)
      - Runs the BERT model in inference mode to obtain attentions (bidirectional)
      - Returns the tokens and attention values (rounded to 2 decimals)

    Returns:
        dict with "tokens" (list of wordpiece strings) and "attention"
        (nested list: layer x head x seq x seq after squeezing the batch dim).

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        logger.info(f"Received text: {text}")
        # Tokenize input text (truncating if needed to BERT's 512-token limit)
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        # Run the model without gradient computation (inference mode)
        with t.no_grad():
            outputs = model(**inputs)
        attentions = outputs.attentions  # Tuple of attention tensors for each layer

        decimals = 2
        # Stack the per-layer attention tensors directly instead of the old
        # tolist() -> t.tensor() round trip through Python lists, which copied
        # every value twice. Shape (layers, 1, heads, seq, seq) -> squeeze()
        # drops the singleton batch dim, as before.
        attn_series = (
            t.round(t.stack(attentions).squeeze().double(), decimals=decimals)
            .cpu()
            .tolist()
        )

        return {
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attention": attn_series
        }
    except Exception as e:
        logger.error(f"Error processing text: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Initialize the text generation pipeline (unchanged)
# NOTE: this downloads/loads google/flan-t5-small at import time, so first
# startup is slow; the pipeline object is reused by every /generate request.
pipe = pipeline("text2text-generation", model="google/flan-t5-small")

@app.get("/generate")
def generate(text: str):
    """
    Generate text from the given input text with the `transformers`
    text2text-generation pipeline. The model used is `google/flan-t5-small`.
    """
    # Run the pipeline and take the first (and only) candidate it returns
    results = pipe(text)
    generated = results[0]["generated_text"]
    # Wrap the generated string in a JSON response body
    return {"output": generated}

@app.get("/")
async def read_index():
    """Serve the frontend entry page at the site root."""
    # index.html lives under ./frontend relative to the working directory
    index_path = "frontend/index.html"
    return FileResponse(index_path)