from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from app.model_loader import load_model
import re
import torch

app = FastAPI()
model, tokenizer = load_model()

@app.post("/predict")  # route decorator was missing; the path is assumed from the handler name
async def predict(request: Request):
    data = await request.json()
    raw_abstract = data.get("input", "")

    # Take the last sentence of the abstract, or the whole text if it is a single sentence
    sentences = re.split(r'(?<=[.!?]) +', raw_abstract.strip())
    abstract_tail = sentences[-1] if len(sentences) > 1 else raw_abstract

    # Construct the prompt
    prompt = (
        f"This neuroscience abstract ends as follows:\n"
        f"\"{abstract_tail}\"\n\n"
        f"Complete the next sentence logically:"
    )

    # Tokenize and generate a short continuation
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    # Decode, then strip the echoed prompt so only the new text is returned
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    continuation = response[len(prompt):].strip()
    return JSONResponse(content={"output": continuation})
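
The app.model_loader.load_model helper imported above is not shown in this section. A minimal sketch of what it might contain, assuming the model is a Hugging Face causal language model loaded via transformers (the checkpoint name and dtype below are placeholders, not from the original):

# app/model_loader.py -- hypothetical sketch; the actual Space may load differently
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAME = "gpt2"  # placeholder; the real Space's checkpoint is not shown

def load_model():
    """Load the model and tokenizer once at startup."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,  # switch to float16 if running on a GPU
    )
    model.eval()  # inference only; disables dropout
    return model, tokenizer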
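
With the route decorator in place, the endpoint can be exercised as below. The URL and port are assumptions for a local run (e.g. uvicorn main:app --port 8000); only the "input" key in the payload comes from the handler itself.

# Hypothetical client call against a local server
import requests

payload = {
    "input": "We recorded hippocampal place cells in freely moving rats. "
             "Remapping occurred after the environment was changed."
}
resp = requests.post("http://localhost:8000/predict", json=payload)
print(resp.json()["output"])  # the model's short continuation of the abstract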