AndaiMD committed
Commit d1e903b · 1 Parent(s): c2ebdd7
Files changed (1)
  1. app/main.py +15 -11
app/main.py CHANGED
@@ -10,24 +10,28 @@ model, tokenizer = load_model()
 async def predict(request: Request):
     data = await request.json()
     input_text = data.get("input", "")
-
-    # Tokenize and move to model device
-    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-
-    # Generate next 15 tokens
+
+    # Extract last 5 words
+    last_5_words = " ".join(input_text.strip().split()[-5:])
+
+    # Tokenize and generate continuation
+    inputs = tokenizer(last_5_words, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=15,
-            do_sample=True,  # Optional: adds randomness
-            temperature=0.8,  # Optional: more natural output
+            max_new_tokens=20,
+            do_sample=True,
+            temperature=0.8,
+            top_k=50,
+            top_p=0.95,
             pad_token_id=tokenizer.eos_token_id
         )
 
-    # Decode only new part of generation
+    # Decode generated text
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Extract the continuation only (optional but useful)
-    continuation = generated_text[len(input_text):].strip()
+    # Remove the prompt portion to isolate generated words
+    continuation = generated_text[len(last_5_words):].strip()
 
     return JSONResponse(content={"output": continuation})
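
For context, below is a minimal, self-contained sketch of how the updated handler could sit in a complete app/main.py. The FastAPI route wiring, the distilgpt2 checkpoint, and the load_model() helper shown here are assumptions for illustration only (the repo's real loader is defined elsewhere in the file); only the body of predict() mirrors the diff above.

# Sketch of the updated endpoint; load_model() and the /predict route are assumptions.
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()


def load_model(name: str = "distilgpt2"):
    # Hypothetical loader standing in for the repo's own load_model().
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    model.eval()
    return model, tokenizer


model, tokenizer = load_model()


@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    input_text = data.get("input", "")

    # Keep only the last 5 words of the input as the prompt, as in the commit.
    last_5_words = " ".join(input_text.strip().split()[-5:])

    # Tokenize the shortened prompt and move it to the model's device.
    inputs = tokenizer(last_5_words, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode prompt + generation, then strip the prompt prefix so only the
    # newly generated words are returned (character-based slice, as in the diff).
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    continuation = generated_text[len(last_5_words):].strip()

    return JSONResponse(content={"output": continuation})

With this in place, POSTing a JSON body such as {"input": "the quick brown fox jumps over the lazy dog"} to /predict would prompt the model with only "jumps over the lazy dog" and return up to 20 sampled continuation tokens in the "output" field.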