from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from app.model_loader import load_model
import re
import torch

app = FastAPI()
model, tokenizer = load_model()


@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    raw_abstract = data.get("input", "")

    # Take the last sentence of the abstract (fall back to the whole text
    # when there is only one sentence).
    sentences = re.split(r"(?<=[.!?]) +", raw_abstract.strip())
    abstract_tail = sentences[-1] if len(sentences) > 1 else raw_abstract

    # Construct the prompt.
    prompt = (
        f"This neuroscience abstract ends as follows:\n"
        f"\"{abstract_tail}\"\n\n"
        f"Complete the next sentence logically:"
    )

    # Tokenize and generate.
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is fragile, because detokenization does not always
    # round-trip the prompt text exactly.
    prompt_len = inputs["input_ids"].shape[-1]
    continuation = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    return JSONResponse(content={"output": continuation})
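
# For reference, a minimal sketch of what app.model_loader.load_model could
# look like -- this is an assumption, since the real loader is not shown
# here. The endpoint above only requires a Hugging Face-style causal LM and
# tokenizer pair; "gpt2" below is a placeholder model name:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     def load_model(name: str = "gpt2"):
#         tokenizer = AutoTokenizer.from_pretrained(name)
#         model = AutoModelForCausalLM.from_pretrained(name)
#         model.eval()  # inference mode; generation runs under torch.no_grad()
#         return model, tokenizer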
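
# One way to smoke-test the endpoint in-process, without starting a server:
# a sketch that assumes FastAPI's TestClient (and its httpx dependency) is
# available. The sample abstract is purely illustrative.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    resp = client.post(
        "/predict",
        json={
            "input": (
                "We recorded hippocampal place cells during sleep. "
                "Replay events were more frequent after learning."
            )
        },
    )
    print(resp.json())  # e.g. {"output": "..."}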