# Author: Paar, F. (Ferdinand)
# Commit: "welcome bert" (dbbf91e)
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import BertTokenizer, BertModel, pipeline
import torch as t
import logging
# Configure root logging once at import time; use a module-level logger
# (named after this module) for all app log output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Configure CORS: wide-open ("*") is fine for development; in production,
# restrict allow_origins to the frontend's actual host(s).
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Mount the frontend's static assets under "/static". The directory is
# "frontend" relative to the process working directory — presumably the
# Docker image sets WORKDIR so that ./frontend exists; verify in the
# Dockerfile. (The site root "/" itself is served by read_index below,
# not by this mount.)
app.mount("/static", StaticFiles(directory="frontend", html=True), name="static")
# Load the BERT tokenizer and model once at startup. Both loads are inside
# the try so that a failure of either (e.g. no network/cache for the
# pretrained weights) is logged before the process aborts — previously a
# tokenizer failure was neither logged nor handled.
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # output_attentions=True makes the forward pass return the per-layer
    # attention tensors that /process exposes to the frontend.
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise
@app.post("/process")
async def process_text(text: str = Body(..., embed=True)):
    """
    Process the input text:
      - Tokenizes the text with BERT's tokenizer (truncated to 512 tokens)
      - Runs the BERT model to obtain attentions (bidirectional)
      - Returns the tokens and attention values rounded to 2 decimals

    Returns:
        dict with:
          "tokens": list[str] of wordpiece tokens (incl. [CLS]/[SEP])
          "attention": nested lists — per-layer attention weights with the
          batch dimension squeezed out, values rounded to 2 decimals.

    Raises:
        HTTPException: status 500 wrapping any processing failure.
    """
    try:
        # Lazy %-style args so formatting only happens if the level is enabled.
        logger.info("Received text: %s", text)
        # Tokenize input text, truncating to BERT's 512-token limit.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        # Inference only — no gradient bookkeeping needed.
        with t.no_grad():
            outputs = model(**inputs)
        attentions = outputs.attentions  # tuple of per-layer attention tensors
        decimals = 2
        # Stack the per-layer tensors directly instead of round-tripping
        # through Python lists (.tolist() -> t.tensor(...)), which the old
        # code did — that copy is slow and memory-hungry for long inputs.
        # stack -> (layers, 1, heads, seq, seq); squeeze drops the batch dim.
        attn_series = (
            t.round(t.stack(attentions).double().squeeze(), decimals=decimals)
            .detach().cpu().tolist()
        )
        return {
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attention": attn_series
        }
    except Exception as e:
        logger.error("Error processing text: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
# Text2text-generation pipeline used by the /generate endpoint;
# loaded once at import time (downloads/caches google/flan-t5-small).
pipe = pipeline("text2text-generation", model="google/flan-t5-small")
@app.get("/generate")
def generate(text: str):
    """
    Using the text2text-generation pipeline from `transformers`, generate text
    from the given input text. The model used is `google/flan-t5-small`.

    Args:
        text: input prompt, taken from the query string.

    Returns:
        dict with "output": the generated text.

    Raises:
        HTTPException: status 500 wrapping any generation failure
        (consistent with the /process endpoint's error handling).
    """
    try:
        # Run the pipeline; it returns a list of dicts with "generated_text".
        output = pipe(text)
        return {"output": output[0]["generated_text"]}
    except Exception as e:
        # Previously any pipeline failure surfaced as an unhandled exception;
        # wrap it in a 500 like /process does.
        logger.error("Error generating text: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def read_index():
    """Serve the frontend's entry page at the site root."""
    index_page = "frontend/index.html"
    return FileResponse(index_page)