simple_nli / app.py
bkjcb's picture
Fix annotations
8e7865e
#!/usr/bin/env python
"""
Gradio app for NLI zero-shot classification of survey free-text responses.
"""
import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline
from sklearn.metrics import precision_recall_fscore_support, precision_recall_curve
import tempfile
import os
import traceback
# Hypothesis template handed to the zero-shot pipeline as `hypothesis_template`;
# the "{}" placeholder receives the hypothesis being scored.
TEMPLATE = "Ce texte exprime {}."
# Model identifier passed to `pipeline(...)` below.
MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
print("Chargement du modèle NLI...")
import torch
# Device argument for the pipeline: 0 when CUDA is available, otherwise -1.
_device = 0 if torch.cuda.is_available() else -1
# Built once at import time and reused by run_classification.
classifier = pipeline("zero-shot-classification", model=MODEL_NAME, device=_device)
print("Modèle chargé.")
# Module-level dict shared by the callbacks in this file. Keys written here:
# "raw_df", "sample", "hypotheses", "threshold", "thresholds", "cross_table".
state = {}
def load_file(file):
    """Read an uploaded survey file into a DataFrame.

    Args:
        file: Either a filesystem path (``str``) or an object exposing the
            path through a ``name`` attribute.

    Returns:
        The DataFrame parsed with ``pd.read_csv`` when the path ends in
        ``.csv`` (case-insensitive), and with ``pd.read_excel`` otherwise.

    Raises:
        gr.Error: If ``file`` is ``None``.
    """
    if file is None:
        raise gr.Error("Aucun fichier uploadé.")
    path = file if isinstance(file, str) else file.name
    # Compare the lower-cased name so "DATA.CSV" is not sent to the Excel reader.
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)
def on_upload(file):
    """Load the uploaded file and refresh the column selectors.

    Args:
        file: Value of the file input; ``None`` when no file is present
            (e.g. the upload was cleared).

    Returns:
        A 3-tuple matching the wired outputs: an updated ``gr.CheckboxGroup``
        (text columns), an updated ``gr.Dropdown`` (grouping column, with the
        "(aucune)" sentinel first) and a status string.

    Raises:
        gr.Error: If the file cannot be read.
    """
    if file is None:
        # No file: forget the previous DataFrame so a later classification
        # cannot silently run on data that is no longer shown in the UI, and
        # reset the selectors instead of raising an error popup.
        state.pop("raw_df", None)
        return (
            gr.CheckboxGroup(choices=[], value=[]),
            gr.Dropdown(choices=["(aucune)"], value="(aucune)"),
            "",
        )
    try:
        df = load_file(file)
        cols = list(df.columns)
        state["raw_df"] = df
        print(f"Fichier chargé: {len(df)} lignes, colonnes: {cols}")
        return (
            gr.CheckboxGroup(choices=cols, value=[]),
            gr.Dropdown(choices=["(aucune)"] + cols, value="(aucune)"),
            f"{len(df)} lignes, {len(cols)} colonnes chargées.",
        )
    except gr.Error:
        # Already a user-facing message; re-raise unchanged, as
        # run_classification does.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur de chargement: {e}") from e
def prepare_texts(df, text_cols):
    """Stack the selected text columns into one long table of non-empty texts.

    Args:
        df: Source DataFrame.
        text_cols: Names of the columns holding free-text answers.

    Returns:
        A DataFrame with a ``text`` column (stripped strings, empty ones
        removed) and a ``source`` column naming the originating column,
        re-indexed from 0.
    """
    if len(text_cols) != 1:
        # Several columns: melt stacks them one after the other.
        long_df = df[text_cols].melt(var_name="source", value_name="text")
    else:
        only_col = text_cols[0]
        long_df = df[[only_col]].copy()
        long_df.columns = ["text"]
        long_df["source"] = only_col
    cleaned = long_df["text"].fillna("").astype(str).str.strip()
    long_df["text"] = cleaned
    has_text = long_df["text"] != ""
    return long_df[has_text].reset_index(drop=True)
# ---------------------------------------------------------------------------
# Annotation helpers
# ---------------------------------------------------------------------------
def _save_annotations_from_df(hyp, annot_df):
    """Write annotations from the displayed dataframe back to state.

    Rows of ``annot_df`` are matched to ``state["sample"]`` by position: the
    i-th displayed row is taken to be the i-th sample row when sorted by
    descending score for ``hyp``.

    Args:
        hyp: Hypothesis whose annotations are being saved.
        annot_df: Displayed table; its ``Annotation`` column is read.

    Returns:
        Number of non-missing annotations stored for ``hyp`` after the save,
        or 0 when there is nothing to save (no sample, no hypothesis, empty
        table, or a hypothesis/table that does not match the current sample).
    """
    sample = state.get("sample")
    if sample is None or hyp is None:
        return 0
    if annot_df is None or len(annot_df) == 0:
        return 0
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # The hypothesis may be stale (e.g. it belongs to an earlier run); in that
    # case the sample has no matching columns and there is nothing to write.
    if annot_col not in sample.columns or score_col not in sample.columns:
        return 0
    if "Annotation" not in annot_df.columns:
        return 0
    sorted_idx = sample[score_col].sort_values(ascending=False).index
    annotations = annot_df["Annotation"].values
    # zip stops at the shorter sequence, so extra rows on either side are ignored.
    for idx, val in zip(sorted_idx, annotations):
        val_str = str(val).strip() if pd.notna(val) else ""
        if val_str in ("0", "1", "0.0", "1.0"):
            sample.at[idx, annot_col] = int(float(val_str))
        else:
            # Anything that is not a clean 0/1 clears the annotation.
            sample.at[idx, annot_col] = np.nan
    state["sample"] = sample
    return int(sample[annot_col].notna().sum())
def get_annotation_view(hyp):
    """Build the table shown in the annotation grid for one hypothesis.

    Args:
        hyp: Hypothesis to display.

    Returns:
        A DataFrame with columns Texte / Score / Prédiction / Annotation,
        sorted by descending score. It is empty when no sample is stored or
        ``hyp`` is falsy. Missing annotations are shown as empty strings.
    """
    display_cols = ["Texte", "Score", "Prédiction", "Annotation"]
    sample = state.get("sample")
    if sample is None or not hyp:
        return pd.DataFrame(columns=display_cols)

    def _as_cell(value):
        # Blank cell for a missing annotation, plain int otherwise.
        return "" if pd.isna(value) else int(value)

    source_cols = ["text", f"score__{hyp}", f"pred__{hyp}", f"annot__{hyp}"]
    view = sample[source_cols].copy()
    view.columns = display_cols
    view = view.sort_values("Score", ascending=False)
    view = view.reset_index(drop=True)
    view["Annotation"] = view["Annotation"].apply(_as_cell)
    return view
# ---------------------------------------------------------------------------
# Classification
# ---------------------------------------------------------------------------
def _aligned_groups(df, text_cols, group_col):
    """Return the group label of every row kept by ``prepare_texts``, in order.

    ``prepare_texts`` stacks the text columns one after the other and drops
    rows whose stripped text is empty; the same filter is applied here to the
    group column so labels stay attached to the right text.
    """
    group_values = df[group_col].to_numpy()
    kept = []
    for col in text_cols:
        non_empty = df[col].fillna("").astype(str).str.strip() != ""
        kept.append(group_values[non_empty.to_numpy()])
    return np.concatenate(kept)


def run_classification(file, text_cols, group_col, hypotheses_text, threshold,
                       progress=gr.Progress()):
    """Score every non-empty text against every hypothesis and publish results.

    Args:
        file: Uploaded file value (unused; the DataFrame is read from
            ``state["raw_df"]``).
        text_cols: Names of the columns containing free text.
        group_col: Name of the grouping column, or "(aucune)" / falsy for none.
        hypotheses_text: Hypotheses, one per line.
        threshold: Score at or above which a text counts as positive.
        progress: Gradio progress tracker.

    Returns:
        A 5-tuple matching the wired outputs: cross table, hypothesis radio,
        annotation table for the first hypothesis, the (now visible)
        annotation section, and an empty performance report.

    Raises:
        gr.Error: On invalid input or any failure during classification.
    """
    print(f"run_classification called: text_cols={text_cols}, threshold={threshold}")
    try:
        if not text_cols:
            raise gr.Error("Sélectionne au moins une colonne de texte.")
        stripped = (h.strip() for h in (hypotheses_text or "").split("\n"))
        # dict.fromkeys removes duplicate hypotheses while keeping input order,
        # so the same hypothesis is not scored (and listed) twice.
        hypotheses_raw = list(dict.fromkeys(h for h in stripped if h))
        if not hypotheses_raw:
            raise gr.Error("Aucune hypothèse saisie.")
        df = state.get("raw_df")
        if df is None:
            raise gr.Error("Recharge le fichier.")
        sample = prepare_texts(df, text_cols)
        if sample.empty:
            raise gr.Error("Aucun texte non vide dans les colonnes sélectionnées.")
        if group_col and group_col != "(aucune)":
            # prepare_texts drops empty answers, so group labels must go
            # through the same filter; slicing the first len(sample) values
            # would attach labels to the wrong texts.
            groups = _aligned_groups(df, text_cols, group_col)
            if len(groups) != len(sample):
                raise ValueError("group labels could not be aligned with the prepared texts")
            sample["group"] = groups
        else:
            sample["group"] = "Tous"
        texts = list(sample["text"])
        n_hyp = len(hypotheses_raw)
        print(f"Classifying {len(texts)} texts x {n_hyp} hypotheses")
        score_cols = {}
        for i, hyp in enumerate(hypotheses_raw):
            progress(i / n_hyp, desc=f"Hypothèse {i+1}/{n_hyp}: {hyp[:40]}...")
            print(f"  -> {hyp}")
            # One hypothesis per call: each score is that hypothesis' own
            # result rather than a ranking across hypotheses.
            results = classifier(
                texts,
                candidate_labels=[hyp],
                hypothesis_template=TEMPLATE,
                multi_label=False,
                batch_size=32,
            )
            if isinstance(results, dict):
                # Guard in case a single-item input comes back as a bare dict.
                results = [results]
            score_cols[hyp] = np.array([r["scores"][0] for r in results])
        progress(1.0, desc="Terminé.")
        for hyp in hypotheses_raw:
            sample[f"score__{hyp}"] = score_cols[hyp]
            sample[f"pred__{hyp}"] = (score_cols[hyp] >= threshold).astype(int)
            sample[f"annot__{hyp}"] = np.nan
        state["sample"] = sample
        state["hypotheses"] = hypotheses_raw
        state["threshold"] = threshold
        # Calibrated thresholds belong to the previous run's annotations; drop
        # them so the cross table uses the threshold chosen for this run.
        state.pop("thresholds", None)
        cross = build_cross_table(sample, hypotheses_raw, threshold)
        state["cross_table"] = cross
        annot_df = get_annotation_view(hypotheses_raw[0])
        print(f"Done.\n{cross}")
        return (
            gr.Dataframe(value=cross, visible=True),
            gr.Radio(choices=hypotheses_raw, value=hypotheses_raw[0]),
            annot_df,
            gr.Column(visible=True),
            "",
        )
    except gr.Error:
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur: {e}") from e
def build_cross_table(sample, hypotheses, threshold=None):
    """Build cross-table using per-hypothesis thresholds if available.

    Args:
        sample: DataFrame with a ``group`` column and one ``score__{hyp}``
            column per hypothesis.
        hypotheses: Hypotheses to report, one output row each.
        threshold: Threshold used for hypotheses without an entry in
            ``state["thresholds"]``; when ``None``, ``state["threshold"]``
            (default 0.5) is used.

    Returns:
        A DataFrame with an "Hypothèse" column, one column per group and a
        "Tous" column, holding the percentage (1 decimal) of texts whose score
        is at or above the threshold.
    """
    # Missing labels can never match `==` and cannot be ordered against
    # strings, so they are left out of the per-group columns (those rows still
    # count in "Tous").
    labels = list(sample["group"].dropna().unique())
    try:
        groups = sorted(labels)
    except TypeError:
        # Mixed label types (e.g. numbers and strings): order by text form.
        groups = sorted(labels, key=str)
    default_thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    fallback = threshold if threshold is not None else default_thr
    rows = []
    for hyp in hypotheses:
        hits = sample[f"score__{hyp}"] >= per_thr.get(hyp, fallback)
        row = {"Hypothèse": hyp}
        for g in groups:
            row[g] = round(hits[sample["group"] == g].mean() * 100, 1)
        row["Tous"] = round(hits.mean() * 100, 1)
        rows.append(row)
    return pd.DataFrame(rows)
def on_hypothesis_change(new_hyp, prev_hyp, annot_df):
    """Persist the edits made for the previous hypothesis, then show the new one.

    Args:
        new_hyp: Hypothesis just selected.
        prev_hyp: Hypothesis that was selected before, if any.
        annot_df: Annotation table as currently displayed.

    Returns:
        The annotation view for ``new_hyp`` and ``new_hyp`` itself (stored as
        the next "previous" hypothesis).
    """
    should_save = bool(prev_hyp) and annot_df is not None and len(annot_df) > 0
    if should_save:
        saved = _save_annotations_from_df(prev_hyp, annot_df)
        print(f"Auto-saved {saved} annotations for '{prev_hyp}'")
    refreshed = get_annotation_view(new_hyp)
    return refreshed, new_hyp
def compute_performance(hyp, annot_df):
    """Save current annotations first, then compute metrics.

    Args:
        hyp: Hypothesis to evaluate.
        annot_df: Annotation table as currently displayed.

    Returns:
        ``(report, best_thr)``. ``report`` is a Markdown string with
        precision/recall/F1 at the current threshold and at the F1-maximising
        threshold. ``best_thr`` is that threshold as an exact score value, or
        ``None`` when metrics cannot be computed (no data, no hypothesis,
        fewer than 5 annotations, or only one class annotated).
    """
    if hyp and annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)
    sample = state.get("sample")
    if sample is None:
        return "Pas de données.", None
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # No selection, or a hypothesis that is not part of the current sample.
    if not hyp or annot_col not in sample.columns or score_col not in sample.columns:
        return "Sélectionne d'abord une hypothèse.", None
    mask = sample[annot_col].notna()
    n = int(mask.sum())
    if n < 5:
        return f"Seulement {n} annotations. Minimum 5 pour calculer.", None
    y_true = sample.loc[mask, annot_col].astype(int).values
    scores = sample.loc[mask, score_col].values
    if len(set(y_true)) < 2:
        return "Il faut au moins un exemple positif et un négatif.", None
    thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    current_thr = per_thr.get(hyp, thr)
    y_pred = (scores >= current_thr).astype(int)
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    precisions, recalls, thr_pr = precision_recall_curve(y_true, scores)
    # The final (precision=1, recall=0) point has no associated threshold;
    # drop it so every candidate index maps to an entry of thr_pr.
    precisions, recalls = precisions[:-1], recalls[:-1]
    f1s = 2 * precisions * recalls / (precisions + recalls + 1e-10)
    best_idx = int(np.argmax(f1s))
    # Keep the exact score value: rounding it up would push the boundary
    # example below the threshold and change the metrics reported here.
    best_thr = float(thr_pr[best_idx])
    best_f1 = f1s[best_idx]
    best_p = precisions[best_idx]
    best_r = recalls[best_idx]
    report = (
        f"### '{hyp}' — {n} annotations\n\n"
        f"**Seuil actuel ({current_thr:.2f})** : P={p:.2f} R={r:.2f} F1={f1:.2f}\n\n"
        f"**Seuil optimal ({best_thr:.2f})** : P={best_p:.2f} R={best_r:.2f} F1={best_f1:.2f}"
    )
    return report, best_thr
def apply_optimal_threshold(hyp, annot_df):
    """Adopt the F1-optimal threshold for ``hyp`` and rebuild the cross table.

    Args:
        hyp: Hypothesis being calibrated.
        annot_df: Annotation table as currently displayed.

    Returns:
        ``(report, cross_table)``. When no optimal threshold can be computed,
        the report explains why and the stored cross table is returned
        unchanged.
    """
    report, best_thr = compute_performance(hyp, annot_df)
    if best_thr is None:
        return report, state.get("cross_table", pd.DataFrame())
    sample = state["sample"]
    hypotheses = state["hypotheses"]
    if "thresholds" not in state:
        # First calibration: start every hypothesis at the global threshold.
        state["thresholds"] = {h: state.get("threshold", 0.5) for h in hypotheses}
    state["thresholds"][hyp] = best_thr
    # Keep the stored predictions in line with the threshold now in force so
    # the exported data agrees with the cross table.
    sample[f"pred__{hyp}"] = (sample[f"score__{hyp}"] >= best_thr).astype(int)
    cross = build_cross_table(sample, hypotheses)
    state["cross_table"] = cross
    report += f"\n\nTableau croisé recalculé avec seuil {best_thr:.3f} pour cette hypothèse."
    return report, cross
def export_xlsx(hyp, annot_df):
    """Write the classified sample, cross table and thresholds to one workbook.

    Args:
        hyp: Hypothesis currently selected; its pending edits are saved first.
        annot_df: Annotation table as currently displayed.

    Returns:
        Path of the written ``.xlsx`` file.

    Raises:
        gr.Error: If no sample has been stored yet.
    """
    # Flush on-screen edits so the export includes them.
    has_edits = annot_df is not None and len(annot_df) > 0
    if hyp and has_edits:
        _save_annotations_from_df(hyp, annot_df)
    sample = state.get("sample")
    if sample is None:
        raise gr.Error("Rien à exporter.")
    out_path = os.path.join(tempfile.gettempdir(), "classification_nli.xlsx")
    cross = state.get("cross_table")
    thresholds = state.get("thresholds", {})
    with pd.ExcelWriter(out_path, engine="openpyxl") as writer:
        sample.to_excel(writer, sheet_name="Données", index=False)
        if cross is not None:
            cross.to_excel(writer, sheet_name="Tableau croisé", index=False)
        if thresholds:
            threshold_rows = [
                {"Hypothèse": name, "Seuil": value}
                for name, value in thresholds.items()
            ]
            pd.DataFrame(threshold_rows).to_excel(
                writer, sheet_name="Seuils", index=False)
    return out_path
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# NOTE(review): the source arrived with indentation stripped; the nesting below
# is a reconstruction of the layout — confirm against the original file.
with gr.Blocks(title="Classification NLI") as app:
    # Hidden state to track which hypothesis was previously selected
    prev_hyp_state = gr.State(value=None)
    gr.Markdown("# Classification NLI zero-shot de textes libres")
    gr.Markdown(
        "Upload un fichier (xlsx/csv), choisis les colonnes de texte, "
        "saisis des hypothèses (une par ligne), lance la classification. "
        "Annote ensuite pour calibrer les seuils."
    )
    with gr.Row():
        # Left column: file upload and column selection.
        with gr.Column(scale=1):
            file_input = gr.File(label="Fichier (xlsx ou csv)", file_types=[".xlsx", ".csv"])
            upload_status = gr.Textbox(label="Statut", interactive=False)
            text_cols = gr.CheckboxGroup(label="Colonnes de texte", choices=[])
            group_col = gr.Dropdown(label="Colonne de regroupement (optionnel)",
                                    choices=["(aucune)"], value="(aucune)")
        # Right column: hypotheses and classification threshold.
        with gr.Column(scale=1):
            hypotheses_input = gr.Textbox(
                label="Hypothèses (une par ligne)", lines=10,
                placeholder="un manque de soutien politique\nun manque de moyens humains\n...",
            )
            threshold_input = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
                                        label="Seuil de classification")
    run_btn = gr.Button("Lancer la classification", variant="primary")
    # Created hidden; run_classification returns it with visible=True.
    cross_table_output = gr.Dataframe(label="Tableau croisé (% affirmatifs)", visible=False)
    # Created hidden; run_classification returns it with visible=True.
    with gr.Column(visible=False) as annot_section:
        gr.Markdown("## Annotation & Calibration")
        gr.Markdown(
            "Clique sur une hypothèse, puis annote les textes : "
            "mets **1** (affirme) ou **0** (non) dans la colonne Annotation. "
            "Les annotations sont sauvegardées automatiquement quand tu changes d'hypothèse."
        )
        hyp_selector = gr.Radio(label="Hypothèse", choices=[], interactive=True)
        annot_table = gr.Dataframe(
            label="Textes classifiés (triés par score décroissant)",
            interactive=True,
        )
        with gr.Row():
            perf_btn = gr.Button("Calculer performance & seuil optimal", variant="secondary")
            apply_btn = gr.Button("Appliquer seuil optimal & recalculer", variant="secondary")
        perf_output = gr.Markdown("")
        with gr.Row():
            export_btn = gr.Button("Exporter tout (.xlsx)", variant="primary")
            export_file = gr.File(label="Fichier exporté")
    # --- Events ---
    # A change of the file input refreshes the column selectors and the status text.
    file_input.change(on_upload, inputs=[file_input],
                      outputs=[text_cols, group_col, upload_status])
    # The five outputs map one-to-one onto the tuple returned by run_classification.
    run_btn.click(
        run_classification,
        inputs=[file_input, text_cols, group_col, hypotheses_input, threshold_input],
        outputs=[cross_table_output, hyp_selector, annot_table, annot_section, perf_output],
    )
    # When hypothesis changes: auto-save previous annotations, load new view
    hyp_selector.change(
        on_hypothesis_change,
        inputs=[hyp_selector, prev_hyp_state, annot_table],
        outputs=[annot_table, prev_hyp_state],
    )
    # Only the report (element 0) is displayed; the suggested threshold is dropped here.
    perf_btn.click(
        lambda hyp, df: compute_performance(hyp, df)[0],
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output],
    )
    # Updates both the report and the cross table.
    apply_btn.click(
        apply_optimal_threshold,
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output, cross_table_output],
    )
    # The displayed table is passed in so export_xlsx can save pending edits first.
    export_btn.click(
        export_xlsx,
        inputs=[hyp_selector, annot_table],
        outputs=[export_file],
    )
# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    app.launch(ssr_mode=False)