File size: 1,225 Bytes
c2ebdd7
2b8db5e
 
 
 
 
 
 
 
261a286
 
f04940c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b8db5e
c2ebdd7
 
d1e903b
 
f04940c
d1e903b
f04940c
c2ebdd7
 
f04940c
 
 
c2ebdd7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import re

import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from app.model_loader import load_model

# Create the ASGI application and load the model + tokenizer once at import
# time, so every request reuses the same in-memory instances instead of
# reloading weights per call. `load_model` is project-local; presumably it
# returns a Hugging Face-style (model, tokenizer) pair — verify in app/model_loader.
app = FastAPI()
model, tokenizer = load_model()

@app.post("/predict")
async def predict(request: Request):
    """Generate a short continuation for a neuroscience abstract.

    Expects a JSON body of the form ``{"input": "<abstract text>"}`` and
    returns ``{"output": "<generated continuation>"}``.

    The prompt is built from the abstract's final sentence, and the model
    generates at most 20 new tokens with sampling (temperature 0.7,
    top-k 50, top-p 0.95).
    """
    data = await request.json()
    raw_abstract = data.get("input", "")

    # Robustness: with no usable input there is nothing to continue;
    # return an empty output instead of prompting the model with noise.
    if not raw_abstract.strip():
        return JSONResponse(content={"output": ""})

    # Take the last sentence of the abstract as the completion context.
    # With a single sentence, fall back to the whole abstract.
    sentences = re.split(r'(?<=[.!?]) +', raw_abstract.strip())
    abstract_tail = sentences[-1] if len(sentences) > 1 else raw_abstract

    # Construct the prompt
    prompt = (
        f"This neuroscience abstract ends as follows:\n"
        f"\"{abstract_tail}\"\n\n"
        f"Complete the next sentence logically:"
    )

    # Tokenize and generate (no gradients needed for inference).
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95
        )

    # BUG FIX: the original decoded the full sequence and sliced the string
    # by len(prompt). Tokenizer round-trips do not guarantee the decoded
    # text starts with the byte-exact prompt (whitespace normalization,
    # special-token handling), so that trim could drop or leak prompt
    # fragments. Decode only the newly generated token ids instead.
    prompt_token_count = inputs["input_ids"].shape[-1]
    continuation = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()

    return JSONResponse(content={"output": continuation})