Update app.py
app.py CHANGED
@@ -52,15 +52,17 @@ import torch
 import transformers
 
 from transformers.utils.import_utils import is_flash_attn_2_available
+print("is_flash_attn_2_available: ", is_flash_attn_2_available())
 
 chat_model_name = "sapienzanlp/Minerva-7B-instruct-v1.0"
 cls_model_name = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
 
 model = transformers.pipeline(
     model=chat_model_name,
-    model_kwargs={"dtype": torch.bfloat16},
+    model_kwargs={"dtype": torch.bfloat16} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
     device=device,
 )
+model.tokenizer.padding_side = "left"
 
 classifier = transformers.pipeline(
     model=cls_model_name,
@@ -78,7 +80,7 @@ def generate(submission: list[dict[str, str]], team_id: str) -> list[dict[str, s
     prompts = [s["prompt"] for s in submission]
 
     messages = [[{"role": "user", "content": prompt}] for prompt in prompts]
-    outputs = model(messages, do_sample=False, temperature=None, max_new_tokens=512, repetition_penalty=1.1, batch_size=25)
+    outputs = model(messages, do_sample=False, temperature=None, max_new_tokens=512, repetition_penalty=1.1, batch_size=25, padding_side="left")
     responses = [output[0]["generated_text"][-1]["content"] for output in outputs]
 
     predictions = classifier([{"text": p, "text_pair": r} for p, r in zip(prompts, responses)], return_all_scores=True, batch_size=25)
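The two changes in this commit follow a common pattern for batched generation with decoder-only models: opt into FlashAttention-2 only when the kernel is actually installed, and pad prompts on the left so that generation continues from the end of each prompt rather than from pad tokens. Below is a minimal, self-contained sketch of the same pattern. The small SmolLM2 model and the two toy prompts are stand-ins, not part of this Space; as in the commit, `dtype` is passed inside `model_kwargs` (recent transformers releases accept this spelling, older ones used `torch_dtype`).

import torch
import transformers
from transformers.utils.import_utils import is_flash_attn_2_available

# Merge the base kwargs with the FlashAttention-2 kwargs only when the
# kernel is importable; `|` is the dict-union operator (Python 3.9+).
model_kwargs = {"dtype": torch.bfloat16} | (
    {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}
)

pipe = transformers.pipeline(
    "text-generation",
    model="HuggingFaceTB/SmolLM2-135M-Instruct",  # stand-in for the 7B chat model
    model_kwargs=model_kwargs,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Decoder-only models must be left-padded for batched generation: with
# right padding, new tokens would be appended after the pad tokens of
# the shorter prompts instead of after the prompts themselves.
pipe.tokenizer.padding_side = "left"

messages = [[{"role": "user", "content": p}] for p in ["Hi!", "What is 2 + 2?"]]
outputs = pipe(messages, do_sample=False, max_new_tokens=32, batch_size=2)
print([o[0]["generated_text"][-1]["content"] for o in outputs])

Keeping the flash-attention kwargs out of `model_kwargs` entirely when the kernel is missing means the same app.py runs both on CPU-only Spaces and on GPU hardware, which is presumably why the commit also adds the `print` of `is_flash_attn_2_available()` at startup.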