StefanG2002 commited on
Commit
2a0a9e8
·
verified ·
1 Parent(s): ca017d0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -1
main.py CHANGED
@@ -26,7 +26,26 @@ class Item(BaseModel):
26
  top_p: float = 0.15
27
  repetition_penalty: float = 1.0
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  @app.post("/generate/")
31
  async def generate_text(item: Item):
32
- return {"response": item.prompt}
 
26
  top_p: float = 0.15
27
  repetition_penalty: float = 1.0
28
 
29
+ def generate(item: Item):
30
+ temperature = float(item.temperature)
31
+ if temperature < 1e-2:
32
+ temperature = 1e-2
33
+ top_p = float(item.top_p)
34
+
35
+ device = "cpu"
36
+
37
+ prompt = item.prompt
38
+
39
+ encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
40
+
41
+ model_inputs = encodeds.to(device)
42
+
43
+
44
+ generated_ids = model.generate(**model_inputs, item.max_new_tokens, do_sample=True, pad_token_id=tokenizer.eos_token_id, temperature=temperature)
45
+ decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
46
+ return decoded
47
+
48
 
49
  @app.post("/generate/")
50
  async def generate_text(item: Item):
51
+ return {"response": generate(item)}