# Hugging Face upload metadata (kept as comments so the file stays valid Python):
# frammartina's picture
# Create app.py
# ec548e9 verified
import os
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Local filename for the fine-tuned classifier weights and the public
# Google Drive location to fetch them from on first start-up.
MODEL_PATH = "LSTM__0.9170.pt.pt"
MODEL_URL = "https://drive.google.com/uc?id=133F-sRp_mCGOo73t1ieSnbk5fSxPFENT"

# Download the checkpoint once; later restarts reuse the cached file.
if not os.path.exists(MODEL_PATH):
    import gdown  # imported lazily: only needed when the file is missing

    print("Scaricamento dei pesi dal Google Drive...")
    gdown.download(MODEL_URL, MODEL_PATH, quiet=False)

# BioBERT backbone with a 2-way sequence-classification head.
# NOTE(review): label semantics (0 = "No", 1 = "Yes") are assumed from the
# /chat handler below — confirm against the training code.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1", use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1",
    num_labels=2
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weights_only=True: the checkpoint comes from an external URL and torch.load
# is pickle-based, so restrict deserialization to plain tensors/containers.
# Safe here because load_state_dict only needs a tensor state dict.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.to(device)
model.eval()  # disable dropout/batch-norm updates for inference

app = FastAPI()
class Query(BaseModel):
    """Request payload for the /chat endpoint.

    The three fields are concatenated (question + context + long_answer)
    and fed to the yes/no sequence classifier.
    """

    # Free-text question to be answered.
    question: str
    # Supporting passage the question refers to.
    context: str
    # Long-form answer text used as additional evidence.
    long_answer: str
@app.post("/chat")
def get_response(query: Query):
    """Classify a (question, context, long_answer) triple as Yes/No.

    Concatenates the three text fields, runs the sequence classifier,
    and returns {"answer": "Yes"} when the predicted class index is 1,
    {"answer": "No"} otherwise.
    """
    text = " ".join((query.question, query.context, query.long_answer))
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,  # BERT-family positional-embedding limit
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Fix: run under no_grad — inference only. The original built an autograd
    # graph on every request, wasting memory and latency for no benefit.
    with torch.no_grad():
        outputs = model(**inputs)
    pred = torch.argmax(outputs.logits, dim=-1).item()
    return {"answer": "Yes" if pred == 1 else "No"}