Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import outlines
|
|
| 7 |
import pandas as pd
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
-
from outlines import
|
| 11 |
from peft import PeftConfig, PeftModel
|
| 12 |
from pydantic import BaseModel, ConfigDict
|
| 13 |
from transformers import (
|
|
@@ -137,9 +137,9 @@ def label_single_response_with_model(model_id, story, question, criteria, respon
|
|
| 137 |
|
| 138 |
else:
|
| 139 |
model = get_outlines_model(model_id, DEVICE_MAP, QUANTIZATION_BITS)
|
| 140 |
-
generator =
|
| 141 |
result = generator(prompt)
|
| 142 |
-
return result
|
| 143 |
|
| 144 |
|
| 145 |
@spaces.GPU
|
|
@@ -161,9 +161,9 @@ def label_multi_responses_with_model(model_id, story, question, criteria, respon
|
|
| 161 |
scores = [str(cls) for cls in torch.argmax(logits, dim=1).tolist()]
|
| 162 |
else:
|
| 163 |
model = get_outlines_model(model_id, DEVICE_MAP, QUANTIZATION_BITS)
|
| 164 |
-
generator =
|
| 165 |
results = [generator(p) for p in prompts]
|
| 166 |
-
scores = [r
|
| 167 |
|
| 168 |
df["score"] = scores
|
| 169 |
return df
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
+
from outlines import Generator
|
| 11 |
from peft import PeftConfig, PeftModel
|
| 12 |
from pydantic import BaseModel, ConfigDict
|
| 13 |
from transformers import (
|
|
|
|
| 137 |
|
| 138 |
else:
|
| 139 |
model = get_outlines_model(model_id, DEVICE_MAP, QUANTIZATION_BITS)
|
| 140 |
+
generator = Generator(model, ResponseModel) # pass schema
|
| 141 |
result = generator(prompt)
|
| 142 |
+
return result.score
|
| 143 |
|
| 144 |
|
| 145 |
@spaces.GPU
|
|
|
|
| 161 |
scores = [str(cls) for cls in torch.argmax(logits, dim=1).tolist()]
|
| 162 |
else:
|
| 163 |
model = get_outlines_model(model_id, DEVICE_MAP, QUANTIZATION_BITS)
|
| 164 |
+
generator = Generator(model, ResponseModel)
|
| 165 |
results = [generator(p) for p in prompts]
|
| 166 |
+
scores = [r.score for r in results]
|
| 167 |
|
| 168 |
df["score"] = scores
|
| 169 |
return df
|