AndaiMD committed on
Commit
c2ebdd7
·
1 Parent(s): 5dfbe24
Files changed (1) hide show
  1. app/main.py +21 -6
app/main.py CHANGED
@@ -1,5 +1,4 @@
1
-
2
- from fastapi import FastAPI, Request, Form
3
  from fastapi.responses import JSONResponse
4
  from app.model_loader import load_model
5
  import torch
@@ -11,8 +10,24 @@ model, tokenizer = load_model()
11
  async def predict(request: Request):
12
  data = await request.json()
13
  input_text = data.get("input", "")
14
- inputs = tokenizer(input_text, return_tensors="pt")
 
 
 
 
15
  with torch.no_grad():
16
- outputs = model.generate(**inputs, max_new_tokens=50)
17
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
18
- return JSONResponse(content={"output": response})
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
 
2
  from fastapi.responses import JSONResponse
3
  from app.model_loader import load_model
4
  import torch
 
10
async def predict(request: Request):
    """Generate a short continuation of the posted text.

    Expects a JSON body like ``{"input": "some prompt"}`` and returns
    ``{"output": "<generated continuation>"}``. Sampling is enabled, so
    repeated calls with the same prompt may return different text.
    """
    data = await request.json()
    input_text = data.get("input", "")

    # Tokenize and move the tensors to the model's device so generate()
    # does not fail on a CPU/GPU placement mismatch.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Generate up to 15 new tokens after the prompt.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=15,
            do_sample=True,    # Optional: adds randomness
            temperature=0.8,   # Optional: more natural output
            pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
        )

    # Decode ONLY the newly generated tokens by slicing the output ids at
    # the prompt's token length. Slicing the decoded *string* by
    # len(input_text) (the previous approach) is unreliable: the tokenizer
    # round-trip does not guarantee the decoded prefix matches the raw
    # input character-for-character (whitespace normalization, BPE artifacts).
    prompt_token_count = inputs["input_ids"].shape[1]
    continuation = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()

    return JSONResponse(content={"output": continuation})