MindLabUnimib committed on
Commit
d28e427
·
1 Parent(s): 9af0597

feat: handle multiple prompts

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -15,29 +15,27 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  classifier = pipeline("text-classification", model="saiteki-kai/QA-DeBERTa-v3-large")
16
 
17
  @spaces.GPU(duration=60)
18
- def generate(message):
19
- messages = [
20
- {"role": "user", "content": message}
21
- ]
22
- text = tokenizer.apply_chat_template(
23
  messages,
24
  tokenize=False,
25
  add_generation_prompt=True
26
  )
27
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
28
  generated_ids = model.generate(
29
  **model_inputs,
30
  do_sample=False,
31
  temperature=0,
32
  repetition_penalty=1.0,
33
- max_new_tokens=512,
34
  )
35
  generated_ids = [
36
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
37
  ]
38
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
39
 
40
- return response, classifier(text + "[SEP]" + response)
41
 
42
 
43
  demo = gr.Interface(fn=generate, inputs=gr.Text(), outputs=gr.Text())
 
15
  classifier = pipeline("text-classification", model="saiteki-kai/QA-DeBERTa-v3-large")
16
 
17
@spaces.GPU(duration=60)
def generate(prompts):
    """Generate a greedy response for each prompt and classify every pair.

    Args:
        prompts: iterable of user prompt strings (one single-turn chat each).

    Returns:
        tuple: (list of response strings, classifier predictions for each
        "prompt[SEP]response" pair).
    """
    # One single-turn conversation per prompt.
    messages = [[{"role": "user", "content": message}] for message in prompts]

    # Batched chat-template rendering: one prompt string per conversation.
    texts = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # padding=True is required for batched tensors of unequal length.
    # NOTE: max_new_tokens is a *generation* argument, not a tokenizer
    # argument — passing it here (as the commit did) leaves generation
    # capped at the model's default; it belongs in model.generate below.
    model_inputs = tokenizer(texts, padding=True, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        do_sample=False,          # greedy decoding; temperature is unused
        repetition_penalty=1.0,
        max_new_tokens=512,
    )
    # Drop the prompt tokens, keeping only the newly generated suffix.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    return responses, classifier([text + "[SEP]" + response for text, response in zip(texts, responses)])
39
 
40
 
41
# NOTE(review): generate() returns two values (responses, classifier output)
# and now expects a *list* of prompts, but this Interface declares a single
# Text input and a single Text output — verify the input parsing and the
# outputs component list match the function's signature.
demo = gr.Interface(fn=generate, inputs=gr.Text(), outputs=gr.Text())