# NOTE: "Spaces: Running" is HuggingFace Spaces page-status residue from
# scraping; kept as a comment so this file remains valid Python.
| from fastapi import FastAPI, HTTPException, Body | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from transformers import BertTokenizer, BertModel, pipeline | |
| import torch as t | |
| import logging | |
# Module-level logging: INFO so incoming texts and errors are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Configure CORS: allow every origin/method/header.
# NOTE(review): "*" is fine for local development, but a production
# deployment should restrict allow_origins to the real frontend host(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (frontend) under /static so the bundled UI is served.
# The relative directory "frontend" resolves against the process working
# directory (e.g. /app inside Docker, giving /app/frontend).
app.mount("/static", StaticFiles(directory="frontend", html=True), name="static")
# Load the tokenizer and BERT model once at startup.
# output_attentions=True makes the model return the per-layer attention
# maps that process_text exposes to the frontend.
# Fix: the tokenizer load is now inside the try block too, so a tokenizer
# download/parse failure is logged the same way as a model failure
# (previously it would propagate silently past the logger).
try:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise
async def process_text(text: str = Body(..., embed=True)):
    """
    Tokenize the input text with BERT's tokenizer, run the model to obtain
    its bidirectional attention maps, and return both.

    NOTE(review): no route decorator is visible on this handler — it
    presumably should be registered with @app.post(...); confirm against
    the original source.

    Args:
        text: Raw input text, read from the JSON request body.

    Returns:
        dict with:
          "tokens":    list of wordpiece tokens (including special tokens),
          "attention": nested list shaped (layers, heads, seq, seq),
                       values rounded to 2 decimals.

    Raises:
        HTTPException: status 500 wrapping any processing error.
    """
    try:
        logger.info(f"Received text: {text}")
        # Tokenize, truncating to BERT's 512-token positional limit.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        # Inference only: no gradient bookkeeping needed.
        with t.no_grad():
            outputs = model(**inputs)
        attentions = outputs.attentions  # tuple: one (1, heads, seq, seq) tensor per layer
        decimals = 2
        # Fix: stack the per-layer tensors directly instead of round-tripping
        # every value through Python lists and back into a tensor.
        # squeeze(1) (not bare squeeze()) drops only the batch dimension, so
        # no other size-1 axis can be collapsed accidentally.
        attn_series = t.round(
            t.stack(attentions).squeeze(1).to(t.double),
            decimals=decimals,
        ).cpu().tolist()
        return {
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attention": attn_series
        }
    except Exception as e:
        logger.error(f"Error processing text: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Initialize the text2text-generation pipeline once at import time so the
# google/flan-t5-small model is downloaded/loaded before the first request.
pipe = pipeline("text2text-generation", model="google/flan-t5-small")
def generate(text: str):
    """
    Run the module-level text2text-generation pipeline
    (google/flan-t5-small) on the given text.

    Args:
        text: Prompt passed straight to the pipeline.

    Returns:
        dict: {"output": <generated text>} suitable for a JSON response.
    """
    results = pipe(text)
    first_result = results[0]
    return {"output": first_result["generated_text"]}
async def read_index():
    """Serve the frontend entry point (frontend/index.html).

    NOTE(review): no route decorator is visible here — presumably this
    should be registered with @app.get("/"); confirm against the original.
    """
    index_path = "frontend/index.html"
    return FileResponse(index_path)