MindLabUnimib committed
Commit c695fab · verified · 1 Parent(s): ead4659

Update app.py

Files changed (1): app.py +4 -2
app.py CHANGED
@@ -52,15 +52,17 @@ import torch
 import transformers
 
 from transformers.utils.import_utils import is_flash_attn_2_available
+print("is_flash_attn_2_available: ", is_flash_attn_2_available())
 
 chat_model_name = "sapienzanlp/Minerva-7B-instruct-v1.0"
 cls_model_name = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
 
 model = transformers.pipeline(
     model=chat_model_name,
-    model_kwargs={"dtype": torch.bfloat16, "padding_side": "left"} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
+    model_kwargs={"dtype": torch.bfloat16} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
     device=device,
 )
+model.tokenizer.padding_side = "left"
 
 classifier = transformers.pipeline(
     model=cls_model_name,
@@ -78,7 +80,7 @@ def generate(submission: list[dict[str, str]], team_id: str) -> list[dict[str, str]]:
     prompts = [s["prompt"] for s in submission]
 
     messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
-    outputs = model(messages, do_sample=False, temperature=None, max_new_tokens=512, repetition_penalty=1.1, batch_size=25)
+    outputs = model(messages, do_sample=False, temperature=None, max_new_tokens=512, repetition_penalty=1.1, batch_size=25, padding_side="left")
     responses = [output[0]["generated_text"][-1]["content"] for output in outputs]
 
     predictions = classifier([{"text": p, "text_pair": r} for p, r in zip(prompts, responses)], return_all_scores=True, batch_size=25)
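For context, a minimal standalone sketch of the setup this commit arrives at. It is a sketch under stated assumptions, not the Space's actual file: `device=0` stands in for the `device` variable the Space defines elsewhere, and the prompt is hypothetical. It shows the two techniques the commit combines: opting into FlashAttention-2 only when the kernel is importable, and left padding, which is a tokenizer attribute rather than a `model_kwargs` entry (pipeline `model_kwargs` are forwarded to the model's `from_pretrained`, not to the tokenizer).

import torch
import transformers
from transformers.utils.import_utils import is_flash_attn_2_available

# Opt into FlashAttention-2 only when the kernel is actually installed.
attn_kwargs = {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}

pipe = transformers.pipeline(
    "text-generation",
    model="sapienzanlp/Minerva-7B-instruct-v1.0",
    model_kwargs={"dtype": torch.bfloat16} | attn_kwargs,  # forwarded to from_pretrained only
    device=0,  # assumption: the Space defines `device` elsewhere
)

# padding_side belongs to the tokenizer, not to model_kwargs; decoder-only
# models must be left-padded so batched generation continues from each
# prompt's last token rather than from pad tokens.
pipe.tokenizer.padding_side = "left"

prompts = ["Qual è la capitale d'Italia?"]  # hypothetical input
messages = [[{"role": "user", "content": p}] for p in prompts]
outputs = pipe(messages, do_sample=False, max_new_tokens=512, batch_size=25)
print(outputs[0][0]["generated_text"][-1]["content"])

The commit also passes padding_side="left" on the pipeline call itself; setting it both there and on the tokenizer is redundant but harmless, and either alone should avoid the right-padding warning transformers emits for decoder-only models.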