Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Commit ·
298d825
1
Parent(s): d4bf693
minor update for new models; postprocessing for md format
Browse files- requirements.txt +3 -1
- src/backend/model_operations.py +16 -9
- src/backend/util.py +14 -0
requirements.txt
CHANGED
|
@@ -21,4 +21,6 @@ anthropic
|
|
| 21 |
openai
|
| 22 |
cohere
|
| 23 |
mistralai
|
| 24 |
-
peft
|
|
|
|
|
|
|
|
|
| 21 |
openai
|
| 22 |
cohere
|
| 23 |
mistralai
|
| 24 |
+
peft
|
| 25 |
+
mdit_plain
|
| 26 |
+
markdown_it
|
src/backend/model_operations.py
CHANGED
|
@@ -160,7 +160,7 @@ class SummaryGenerator:
|
|
| 160 |
using_replicate_api = False
|
| 161 |
replicate_api_models = ['snowflake', 'llama-3.1-405b']
|
| 162 |
using_pipeline = False
|
| 163 |
-
pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5']
|
| 164 |
|
| 165 |
for replicate_api_model in replicate_api_models:
|
| 166 |
if replicate_api_model in self.model_id.lower():
|
|
@@ -325,7 +325,20 @@ class SummaryGenerator:
|
|
| 325 |
result = message.content[0].text
|
| 326 |
print(result)
|
| 327 |
return result
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
elif 'mistral-large' in self.model_id.lower():
|
| 330 |
api_key = os.environ["MISTRAL_API_KEY"]
|
| 331 |
client = Mistral(api_key=api_key)
|
|
@@ -554,14 +567,8 @@ class EvaluationModel:
|
|
| 554 |
for doc, summary in source_summary_pairs:
|
| 555 |
if util.is_summary_valid(summary):
|
| 556 |
try:
|
| 557 |
-
summary =
|
| 558 |
score = self.predict([(doc, summary)])[0]
|
| 559 |
-
# print(score)
|
| 560 |
-
# if score < 0.5:
|
| 561 |
-
# print(doc)
|
| 562 |
-
# print('-'*10)
|
| 563 |
-
# print(summary)
|
| 564 |
-
# print('='*20)
|
| 565 |
hem_scores.append(score)
|
| 566 |
sources.append(doc)
|
| 567 |
summaries.append(summary)
|
|
|
|
| 160 |
using_replicate_api = False
|
| 161 |
replicate_api_models = ['snowflake', 'llama-3.1-405b']
|
| 162 |
using_pipeline = False
|
| 163 |
+
pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5', 'mistral-nemo']
|
| 164 |
|
| 165 |
for replicate_api_model in replicate_api_models:
|
| 166 |
if replicate_api_model in self.model_id.lower():
|
|
|
|
| 325 |
result = message.content[0].text
|
| 326 |
print(result)
|
| 327 |
return result
|
| 328 |
+
|
| 329 |
+
elif 'command-r' in self.model_id.lower():
|
| 330 |
+
co = cohere.Client(os.getenv('COHERE_API_TOKEN'))
|
| 331 |
+
response = co.chat(
|
| 332 |
+
chat_history=[
|
| 333 |
+
{"role": "SYSTEM", "message": system_prompt},
|
| 334 |
+
],
|
| 335 |
+
message=user_prompt,
|
| 336 |
+
)
|
| 337 |
+
result = response.text
|
| 338 |
+
print(result)
|
| 339 |
+
return result
|
| 340 |
+
|
| 341 |
+
|
| 342 |
elif 'mistral-large' in self.model_id.lower():
|
| 343 |
api_key = os.environ["MISTRAL_API_KEY"]
|
| 344 |
client = Mistral(api_key=api_key)
|
|
|
|
| 567 |
for doc, summary in source_summary_pairs:
|
| 568 |
if util.is_summary_valid(summary):
|
| 569 |
try:
|
| 570 |
+
summary = util.normalize_summary(summary)
|
| 571 |
score = self.predict([(doc, summary)])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
hem_scores.append(score)
|
| 573 |
sources.append(doc)
|
| 574 |
summaries.append(summary)
|
src/backend/util.py
CHANGED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def is_summary_valid(summary: str) -> bool:
|
| 2 |
"""
|
| 3 |
Checks if the summary is valid.
|
|
@@ -76,3 +81,12 @@ def format_results(model_name: str, revision: str, precision: str,
|
|
| 76 |
}
|
| 77 |
|
| 78 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from markdown_it import MarkdownIt
|
| 3 |
+
from mdit_plain.renderer import RendererPlain
|
| 4 |
+
|
| 5 |
+
|
| 6 |
def is_summary_valid(summary: str) -> bool:
|
| 7 |
"""
|
| 8 |
Checks if the summary is valid.
|
|
|
|
| 81 |
}
|
| 82 |
|
| 83 |
return results
|
| 84 |
+
|
| 85 |
+
parser = MarkdownIt(renderer_cls=RendererPlain)
|
| 86 |
+
|
| 87 |
+
def normalize_summary(summary: str) -> str:
|
| 88 |
+
summary = summary.replace('<bos>','').replace('<eos>','')
|
| 89 |
+
summary = parser.render(summary)
|
| 90 |
+
summary = summary.replace('*','')
|
| 91 |
+
summary = re.sub('\s{2,}', ' ', summary)
|
| 92 |
+
return summary
|