"""FastAPI service exposing BERT attention extraction and T5 text generation.

Endpoints:
    POST /process  -- tokenize text with BERT and return per-layer attentions.
    GET  /generate -- text2text generation via google/flan-t5-small.
    GET  /         -- serve the frontend entry page.
"""

from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import BertTokenizer, BertModel, pipeline
import torch as t
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Configure CORS: in production, restrict allowed origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (frontend) so the UI assets are served under /static.
# The relative path "frontend" works because the process's working directory
# (e.g. /app when running in Docker) contains the frontend folder.
app.mount("/static", StaticFiles(directory="frontend", html=True), name="static")

# Load tokenizer and BERT model once at startup; fail fast if either is
# unavailable (both loads are guarded, unlike the original where only the
# model load was inside the try block).
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
except Exception as e:
    logger.error("Model loading failed: %s", e)
    raise


@app.post("/process")
async def process_text(text: str = Body(..., embed=True)):
    """
    Process the input text:
      - Tokenizes the text using BERT's tokenizer
      - Runs the BERT model to obtain attentions (bidirectional)
      - Returns the tokens and attention values (rounded to 2 decimals)

    Raises:
        HTTPException: status 500 wrapping any processing failure.
    """
    try:
        logger.info("Received text: %s", text)
        # Tokenize input text, truncating to BERT's 512-token limit.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        # Run the model without gradient computation (inference mode).
        with t.no_grad():
            outputs = model(**inputs)
        # Tuple of per-layer attention tensors, each (1, heads, seq, seq).
        attentions = outputs.attentions
        decimals = 2
        # Stack the layer tensors directly instead of the original
        # tensor(list-of-tolist) round trip through Python lists; squeeze
        # drops the batch dimension, and rounding keeps the JSON payload
        # compact. No .detach() needed: no_grad() already produced
        # grad-free tensors.
        attn_series = (
            t.round(t.stack(attentions).to(t.double).squeeze(), decimals=decimals)
            .cpu()
            .tolist()
        )
        return {
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attention": attn_series,
        }
    except Exception as e:
        logger.error("Error processing text: %s", e)
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e


# Initialize the text generation pipeline once at startup (unchanged model).
pipe = pipeline("text2text-generation", model="google/flan-t5-small")


@app.get("/generate")
def generate(text: str):
    """
    Using the text2text-generation pipeline from `transformers`, generate
    text from the given input text. The model used is `google/flan-t5-small`.

    Note: declared as a plain (sync) def so FastAPI runs the blocking
    pipeline call in its threadpool rather than on the event loop.
    """
    # Use the pipeline to generate text from the given input text.
    output = pipe(text)
    # Return the generated text in a JSON response.
    return {"output": output[0]["generated_text"]}


@app.get("/")
async def read_index():
    """Serve the frontend entry point (single-page app index)."""
    return FileResponse("frontend/index.html")