SreyanG-NVIDIA committed on
Commit
3e73e16
·
verified ·
1 Parent(s): 678e362

Add repetition_penalty=1.2 to the model.generate call in infer

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -448,7 +448,7 @@ def infer(audio_path, youtube_url, prompt_text):
448
  return_dict=True,
449
  ).to(model.device)
450
 
451
- gen_ids = model.generate(**batch, max_new_tokens=4096)
452
  inp_len = batch["input_ids"].shape[1]
453
  new_tokens = gen_ids[:, inp_len:]
454
  texts = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
448
  return_dict=True,
449
  ).to(model.device)
450
 
451
+ gen_ids = model.generate(**batch, max_new_tokens=4096, repetition_penalty=1.2)
452
  inp_len = batch["input_ids"].shape[1]
453
  new_tokens = gen_ids[:, inp_len:]
454
  texts = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)