Update tasks/text.py
Browse files - tasks/text.py +16 -11
tasks/text.py
CHANGED
|
@@ -12,6 +12,16 @@ from safetensors.torch import load_file
|
|
| 12 |
from .utils.evaluation import TextEvaluationRequest
|
| 13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
router = APIRouter()
|
| 16 |
|
| 17 |
DESCRIPTION = "GTE Architecture"
|
|
@@ -70,16 +80,6 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 70 |
true_labels = test_dataset["label"]
|
| 71 |
texts = test_dataset["quote"]
|
| 72 |
|
| 73 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 74 |
-
|
| 75 |
-
model_repo = "elucidator8918/frugal-ai-text"
|
| 76 |
-
model = AutoBertClassifier(num_labels=8)
|
| 77 |
-
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
|
| 78 |
-
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 79 |
-
|
| 80 |
-
model = model.to(device)
|
| 81 |
-
model.eval()
|
| 82 |
-
|
| 83 |
# Start tracking emissions
|
| 84 |
tracker.start()
|
| 85 |
tracker.start_task("inference")
|
|
@@ -94,6 +94,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 94 |
truncation=True,
|
| 95 |
padding=True,
|
| 96 |
return_tensors="pt",
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
with torch.no_grad():
|
|
@@ -101,7 +102,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 101 |
text_attention_mask = text_encoding["attention_mask"].to(device)
|
| 102 |
outputs = model(text_input_ids, text_attention_mask)
|
| 103 |
predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
|
| 104 |
-
|
| 105 |
#--------------------------------------------------------------------------------------------
|
| 106 |
# YOUR MODEL INFERENCE STOPS HERE
|
| 107 |
#--------------------------------------------------------------------------------------------
|
|
@@ -111,6 +112,8 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 111 |
|
| 112 |
# Calculate accuracy
|
| 113 |
accuracy = accuracy_score(true_labels, predictions)
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# Prepare results dictionary
|
| 116 |
results = {
|
|
@@ -129,5 +132,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 129 |
"test_seed": request.test_seed
|
| 130 |
}
|
| 131 |
}
|
|
|
|
|
|
|
| 132 |
|
| 133 |
return results
|
|
|
|
| 12 |
from .utils.evaluation import TextEvaluationRequest
|
| 13 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
| 14 |
|
| 15 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 16 |
+
|
| 17 |
+
model_repo = "elucidator8918/frugal-ai-text"
|
| 18 |
+
model = AutoBertClassifier(num_labels=8)
|
| 19 |
+
model.load_state_dict(load_file(hf_hub_download(repo_id=model_repo, filename="model.safetensors")))
|
| 20 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 21 |
+
|
| 22 |
+
model = model.to(device)
|
| 23 |
+
model.eval()
|
| 24 |
+
|
| 25 |
router = APIRouter()
|
| 26 |
|
| 27 |
DESCRIPTION = "GTE Architecture"
|
|
|
|
| 80 |
true_labels = test_dataset["label"]
|
| 81 |
texts = test_dataset["quote"]
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# Start tracking emissions
|
| 84 |
tracker.start()
|
| 85 |
tracker.start_task("inference")
|
|
|
|
| 94 |
truncation=True,
|
| 95 |
padding=True,
|
| 96 |
return_tensors="pt",
|
| 97 |
+
max_length=256
|
| 98 |
)
|
| 99 |
|
| 100 |
with torch.no_grad():
|
|
|
|
| 102 |
text_attention_mask = text_encoding["attention_mask"].to(device)
|
| 103 |
outputs = model(text_input_ids, text_attention_mask)
|
| 104 |
predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()
|
| 105 |
+
|
| 106 |
#--------------------------------------------------------------------------------------------
|
| 107 |
# YOUR MODEL INFERENCE STOPS HERE
|
| 108 |
#--------------------------------------------------------------------------------------------
|
|
|
|
| 112 |
|
| 113 |
# Calculate accuracy
|
| 114 |
accuracy = accuracy_score(true_labels, predictions)
|
| 115 |
+
|
| 116 |
+
print(f"Accuracy = {accuracy}")
|
| 117 |
|
| 118 |
# Prepare results dictionary
|
| 119 |
results = {
|
|
|
|
| 132 |
"test_seed": request.test_seed
|
| 133 |
}
|
| 134 |
}
|
| 135 |
+
|
| 136 |
+
print(results)
|
| 137 |
|
| 138 |
return results
|