Spaces:
Sleeping
Sleeping
Merwan6
committed on
Commit
·
11204e4
1
Parent(s):
eaf4ff4
modif
Browse files- .DS_Store +0 -0
- scripts/inference.py +25 -16
- scripts/metric.py +2 -3
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
scripts/inference.py
CHANGED
|
@@ -31,36 +31,45 @@ def zero_shot_inference(text):
|
|
| 31 |
|
| 32 |
def few_shot_inference(text):
|
| 33 |
"""
|
| 34 |
-
|
| 35 |
|
| 36 |
Args:
|
| 37 |
text (str): Texte à classifier.
|
| 38 |
|
| 39 |
Returns:
|
| 40 |
tuple:
|
| 41 |
-
- str: Label prédit.
|
| 42 |
-
- dict:
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
examples = [
|
| 48 |
("The president met the UN delegation to discuss global peace.", "World"),
|
| 49 |
("The football team won their match last night.", "Sports"),
|
| 50 |
("The company reported a big profit this quarter.", "Business"),
|
| 51 |
("New research in AI shows promising results.", "Sci/Tech")
|
| 52 |
]
|
| 53 |
-
|
| 54 |
-
#Construction du prompt avec des exemples
|
| 55 |
-
prompt = ""
|
| 56 |
-
for example_text, example_label in examples:
|
| 57 |
-
prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
|
| 58 |
-
prompt += f"Text: {text}\nLabel:"
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
return prediction, scores
|
| 65 |
|
| 66 |
|
|
|
|
| 31 |
|
| 32 |
def few_shot_inference(text):
    """
    Few-shot classification with FLAN-T5: generates only the label (World, Sports, etc.).

    Args:
        text (str): Text to classify.

    Returns:
        tuple:
            - str: Predicted label (cleaned and validated), or "Unknown" if the
              generated text matches no known category.
            - dict: Pseudo probability distribution over the candidate labels
              (1.0 for the predicted label, 0.0 for the others).
    """
    # NOTE(review): the pipeline is rebuilt on every call, which reloads the
    # model each time — consider hoisting it to module level if this function
    # is called in a loop (e.g. by metric.py).
    model_name = "google/flan-t5-small"
    classifier = pipeline("text2text-generation", model=model_name, max_new_tokens=10)

    # Known categories — used both for the prompt and to validate the output.
    candidate_labels = ["World", "Sports", "Business", "Sci/Tech"]

    examples = [
        ("The president met the UN delegation to discuss global peace.", "World"),
        ("The football team won their match last night.", "Sports"),
        ("The company reported a big profit this quarter.", "Business"),
        ("New research in AI shows promising results.", "Sci/Tech")
    ]

    # Few-shot prompt
    prompt = "Classify the following text into one of the following categories: World, Sports, Business, Sci/Tech.\n\n"
    for ex_text, ex_label in examples:
        prompt += f"Text: {ex_text}\nCategory: {ex_label}\n\n"
    prompt += f"Text: {text}\nCategory:"

    # Generation
    output = classifier(prompt)[0]["generated_text"].strip()

    # Clean up the label. Guard against an empty generation: output.split()[0]
    # would raise IndexError on an empty/whitespace-only string.
    tokens = output.split()
    output_clean = tokens[0].rstrip(".").capitalize() if tokens else ""  # e.g. "sci/tech." -> "Sci/tech"

    # Map back to a known category (case-insensitive substring match);
    # fall back to "Unknown" when nothing matches.
    prediction = next((label for label in candidate_labels if label.lower() in output_clean.lower() and output_clean), "Unknown")

    # Fake distribution (1.0 for the predicted class, 0.0 for the others)
    scores = {label: 1.0 if label == prediction else 0.0 for label in candidate_labels}

    return prediction, scores
|
| 74 |
|
| 75 |
|
scripts/metric.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import time
|
| 2 |
-
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
from tqdm import tqdm # ✅ Ajout ici
|
| 5 |
from datasets import load_dataset
|
|
@@ -20,7 +19,7 @@ models_to_evaluate = {
|
|
| 20 |
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
|
| 21 |
|
| 22 |
# Charger un sous-ensemble du jeu de test AG News
|
| 23 |
-
dataset = load_dataset("ag_news", split="test[:
|
| 24 |
|
| 25 |
def evaluate_model(name, inference_func):
|
| 26 |
print(f"\n🔍 Évaluation du modèle : {name}")
|
|
@@ -56,7 +55,7 @@ def evaluate_model(name, inference_func):
|
|
| 56 |
f1 = f1_score(true_labels, pred_labels, average='weighted')
|
| 57 |
prec = precision_score(true_labels, pred_labels, average='weighted')
|
| 58 |
rec = recall_score(true_labels, pred_labels, average='weighted')
|
| 59 |
-
loss = log_loss(true_labels, all_probs)
|
| 60 |
|
| 61 |
print(f"✅ Résultats {name} :")
|
| 62 |
print(f"- Accuracy : {acc:.4f}")
|
|
|
|
| 1 |
import time
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
from tqdm import tqdm # ✅ Ajout ici
|
| 4 |
from datasets import load_dataset
|
|
|
|
| 19 |
label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
|
| 20 |
|
| 21 |
# Charger un sous-ensemble du jeu de test AG News
|
| 22 |
+
dataset = load_dataset("ag_news", split="test[:3]")
|
| 23 |
|
| 24 |
def evaluate_model(name, inference_func):
|
| 25 |
print(f"\n🔍 Évaluation du modèle : {name}")
|
|
|
|
| 55 |
f1 = f1_score(true_labels, pred_labels, average='weighted')
|
| 56 |
prec = precision_score(true_labels, pred_labels, average='weighted')
|
| 57 |
rec = recall_score(true_labels, pred_labels, average='weighted')
|
| 58 |
+
loss = log_loss(true_labels, all_probs, labels=[0, 1, 2, 3])
|
| 59 |
|
| 60 |
print(f"✅ Résultats {name} :")
|
| 61 |
print(f"- Accuracy : {acc:.4f}")
|