arasaltan commited on
Commit
5103e4b
Β·
verified Β·
1 Parent(s): 4112b33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -24,19 +24,20 @@ model = PeftModel.from_pretrained(model, LORA_PATH)
24
  model.eval()
25
 
26
 
27
- def chat(prompt, max_new_tokens=256, temperature=0.7):
28
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
29
 
30
  with torch.no_grad():
31
  output = model.generate(
32
  **inputs,
33
- max_new_tokens=max_new_tokens,
34
- temperature=temperature,
35
- do_sample=True,
36
  eos_token_id=tokenizer.eos_token_id
37
  )
38
 
39
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
40
 
41
 
42
  # Gradio UI
 
24
  model.eval()
25
 
26
 
27
+ def chat(prompt):
28
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
29
 
30
  with torch.no_grad():
31
  output = model.generate(
32
  **inputs,
33
+ max_new_tokens=256,
34
+ do_sample=False,
 
35
  eos_token_id=tokenizer.eos_token_id
36
  )
37
 
38
+ generated = output[0][inputs["input_ids"].shape[-1]:]
39
+ return tokenizer.decode(generated, skip_special_tokens=True)
40
+
41
 
42
 
43
  # Gradio UI