# brainbench/app/main.py — prediction endpoint (commit f04940c, "predict", AndaiMD)
import re

import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from app.model_loader import load_model
# Single FastAPI application instance for this service.
app = FastAPI()
# Load the model and tokenizer once at import time so all requests share
# them, instead of paying the load cost per request.
model, tokenizer = load_model()
@app.post("/predict")
async def predict(request: Request):
    """Generate a short continuation for the tail of a neuroscience abstract.

    Expects a JSON body of the form ``{"input": "<abstract text>"}`` and
    returns ``{"output": "<generated continuation>"}``. Sampling is
    non-deterministic (``do_sample=True``).
    """
    data = await request.json()
    raw_abstract = data.get("input", "")

    # Guard: nothing to continue — avoid invoking the model on an empty prompt.
    if not raw_abstract.strip():
        return JSONResponse(content={"output": ""})

    # Use the final sentence of the abstract as the prompt tail; fall back to
    # the whole text when sentence splitting finds only one sentence.
    sentences = re.split(r'(?<=[.!?]) +', raw_abstract.strip())
    abstract_tail = sentences[-1] if len(sentences) > 1 else raw_abstract

    prompt = (
        f"This neuroscience abstract ends as follows:\n"
        f"\"{abstract_tail}\"\n\n"
        f"Complete the next sentence logically:"
    )

    # Tokenize and generate; no_grad avoids building an autograd graph
    # for pure inference.
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    # Decode ONLY the newly generated token ids. The previous approach of
    # slicing the full decoded string by len(prompt) is unreliable:
    # decode(encode(prompt)) is not guaranteed to reproduce the prompt
    # byte-for-byte (whitespace/special-token normalization), which could
    # truncate or corrupt the returned continuation.
    prompt_token_count = inputs["input_ids"].shape[1]
    continuation = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()
    return JSONResponse(content={"output": continuation})