Commit ·
60fbaa9
1
Parent(s): f311c70
entrophy method
Browse files- handler.py +30 -9
handler.py
CHANGED
|
@@ -18,17 +18,38 @@ class EndpointHandler:
|
|
| 18 |
return self.predict(inputs)
|
| 19 |
|
| 20 |
def predict(self, text):
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
with torch.no_grad():
|
| 26 |
outputs = self.model(**encoded_input)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def get_pipeline():
|
| 34 |
return EndpointHandler
|
|
|
|
| 18 |
return self.predict(inputs)
|
| 19 |
|
| 20 |
def predict(self, text):
|
| 21 |
+
# Tokenize and encode the input
|
| 22 |
+
encoded_input = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
|
| 23 |
+
|
| 24 |
+
# Get model prediction
|
| 25 |
with torch.no_grad():
|
| 26 |
outputs = self.model(**encoded_input)
|
| 27 |
+
logits = outputs.logits
|
| 28 |
+
|
| 29 |
+
# Get probabilities
|
| 30 |
+
probabilities = F.softmax(logits, dim=-1).squeeze().numpy()
|
| 31 |
+
|
| 32 |
+
# Get predicted class and confidence
|
| 33 |
+
predicted_class_idx = np.argmax(probabilities)
|
| 34 |
+
predicted_label = self.labels[predicted_class_idx]
|
| 35 |
+
confidence = probabilities[predicted_class_idx]
|
| 36 |
+
|
| 37 |
+
# Additional analysis
|
| 38 |
+
entropy = -np.sum(probabilities * np.log(probabilities + 1e-9))
|
| 39 |
+
max_prob_ratio = np.max(probabilities) / np.sort(probabilities)[-2]
|
| 40 |
+
|
| 41 |
+
# Adjust confidence based on entropy and probability ratio
|
| 42 |
+
adjusted_confidence = confidence * (1 - entropy/np.log(len(probabilities))) * max_prob_ratio
|
| 43 |
+
|
| 44 |
+
# Lower the confidence for very short inputs
|
| 45 |
+
if len(text.split()) < 4:
|
| 46 |
+
adjusted_confidence *= 0.5
|
| 47 |
+
|
| 48 |
+
return {
|
| 49 |
+
"label": predicted_label,
|
| 50 |
+
"score": float(adjusted_confidence),
|
| 51 |
+
"raw_scores": {label: float(prob) for label, prob in zip(self.labels.values(), probabilities)}
|
| 52 |
+
}
|
| 53 |
|
| 54 |
def get_pipeline():
|
| 55 |
return EndpointHandler
|