#!/usr/bin/env python
"""
Gradio app for NLI zero-shot classification of survey free-text responses.

Workflow implemented below:

1. Upload a csv/xlsx file and pick the text columns (and optionally a
   grouping column).
2. Enter one hypothesis per line; each non-empty text is scored against each
   hypothesis with a zero-shot NLI pipeline.
3. A cross table shows, per group, the percentage of texts whose score is at
   or above the threshold.
4. Texts can be annotated (0/1) to measure precision/recall/F1 and to derive
   a per-hypothesis threshold that maximises F1.
5. Everything can be exported to an .xlsx workbook.
"""

import os
import tempfile
import traceback

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_curve, precision_recall_fscore_support
from transformers import pipeline

TEMPLATE = "Ce texte exprime {}."
MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"

# The model is loaded once at import time; device 0 is used when CUDA is
# available, otherwise -1 (CPU).
print("Chargement du modèle NLI...")
_device = 0 if torch.cuda.is_available() else -1
classifier = pipeline("zero-shot-classification", model=MODEL_NAME, device=_device)
print("Modèle chargé.")

# Module-level working state. Keys written by the functions below:
#   "raw_df"      - the uploaded DataFrame
#   "sample"      - one row per non-empty text, plus score/pred/annot columns
#   "hypotheses"  - list of hypothesis strings of the last run
#   "threshold"   - global threshold of the last run
#   "thresholds"  - optional per-hypothesis overrides (set by calibration)
#   "cross_table" - last computed cross table
# NOTE(review): being a module global, this dict is shared by everything
# running in this process; concurrent users would overwrite each other.
state = {}


def load_file(file):
    """Read an uploaded csv or Excel file into a DataFrame.

    Args:
        file: Either a path string or an object exposing the path as ``.name``.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        gr.Error: If ``file`` is ``None``.
    """
    if file is None:
        raise gr.Error("Aucun fichier uploadé.")
    path = file if isinstance(file, str) else file.name
    # Case-insensitive so that "DATA.CSV" is not sent to the Excel reader.
    if path.lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def on_upload(file):
    """Load the uploaded file and refresh the column selectors.

    Stores the DataFrame in ``state["raw_df"]``.

    Returns:
        A tuple ``(text column checkbox group, grouping dropdown, status text)``.

    Raises:
        gr.Error: If the file is missing or cannot be read.
    """
    try:
        df = load_file(file)
        cols = list(df.columns)
        state["raw_df"] = df
        print(f"Fichier chargé: {len(df)} lignes, colonnes: {cols}")
        return (
            gr.CheckboxGroup(choices=cols, value=[]),
            gr.Dropdown(choices=["(aucune)"] + cols, value="(aucune)"),
            f"{len(df)} lignes, {len(cols)} colonnes chargées.",
        )
    except gr.Error:
        # Already a user-facing error: pass it through unchanged, as
        # run_classification does.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur de chargement: {e}")


def prepare_texts(df, text_cols, group_col=None):
    """Build the long-format table of non-empty texts.

    Args:
        df: Source DataFrame.
        text_cols: Names of the columns holding free text.
        group_col: Optional name of a grouping column. When given, its value
            is attached to every text *before* empty texts are dropped, so
            that each surviving row keeps the group of its original row.

    Returns:
        A DataFrame with columns ``text`` and ``source`` (the originating
        column name), plus ``group`` when ``group_col`` is given. Texts are
        stripped; empty/NaN texts are removed and the index is reset.
    """
    if len(text_cols) == 1:
        out = df[[text_cols[0]]].copy()
        out.columns = ["text"]
        out["source"] = text_cols[0]
        if group_col is not None:
            out["group"] = df[group_col].to_numpy()
    else:
        out = df[text_cols].melt(var_name="source", value_name="text")
        if group_col is not None:
            # melt stacks the selected columns one after another, each in
            # original row order, so the group values repeat once per column.
            out["group"] = np.tile(df[group_col].to_numpy(), len(text_cols))
    out["text"] = out["text"].fillna("").astype(str).str.strip()
    out = out[out["text"] != ""].reset_index(drop=True)
    return out


# ---------------------------------------------------------------------------
# Annotation helpers
# ---------------------------------------------------------------------------

def _save_annotations_from_df(hyp, annot_df):
    """Write annotations from the displayed dataframe back to state.

    Rows of ``annot_df`` are matched to rows of ``state["sample"]`` by
    position in descending-score order, i.e. the order produced by
    ``get_annotation_view``. Only "0"/"1" (also "0.0"/"1.0") are stored; any
    other value clears the annotation.

    Returns:
        The number of annotated rows for ``hyp`` (0 when nothing can be saved).
    """
    sample = state.get("sample")
    if sample is None or hyp is None:
        return 0
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # The hypothesis may come from a previous run whose columns are gone.
    if score_col not in sample.columns:
        return 0
    if annot_df is None or len(annot_df) == 0 or "Annotation" not in annot_df:
        return 0

    # Stable sort, identical to get_annotation_view, so tied scores map back
    # to the same rows.
    sorted_idx = sample[score_col].sort_values(ascending=False, kind="stable").index
    annotations = annot_df["Annotation"].values

    # zip stops at the shorter sequence, tolerating a shorter table.
    for idx, val in zip(sorted_idx, annotations):
        val_str = str(val).strip() if pd.notna(val) else ""
        if val_str in ("0", "1", "0.0", "1.0"):
            sample.at[idx, annot_col] = int(float(val_str))
        else:
            sample.at[idx, annot_col] = np.nan

    state["sample"] = sample
    return int(sample[annot_col].notna().sum())


def get_annotation_view(hyp):
    """Return the annotation table for one hypothesis.

    Returns:
        A DataFrame with columns Texte / Score / Prédiction / Annotation,
        sorted by descending score. It is empty when there is no sample, no
        hypothesis, or no columns for that hypothesis.
    """
    headers = ["Texte", "Score", "Prédiction", "Annotation"]
    sample = state.get("sample")
    if sample is None or not hyp or f"score__{hyp}" not in sample.columns:
        return pd.DataFrame(columns=headers)

    view = sample[["text", f"score__{hyp}", f"pred__{hyp}", f"annot__{hyp}"]].copy()
    view.columns = headers
    view = view.sort_values("Score", ascending=False, kind="stable").reset_index(drop=True)
    # Replace NaN with empty string for cleaner display
    view["Annotation"] = view["Annotation"].apply(
        lambda x: int(x) if pd.notna(x) else ""
    )
    return view


# ---------------------------------------------------------------------------
# Classification
# ---------------------------------------------------------------------------

def run_classification(file, text_cols, group_col, hypotheses_text, threshold,
                       progress=gr.Progress()):
    """Score every non-empty text against every hypothesis.

    Args:
        file: Unused here (the DataFrame is read from ``state["raw_df"]``);
            kept because the UI passes it.
        text_cols: Selected text column names.
        group_col: Grouping column name, or "(aucune)"/empty for no grouping.
        hypotheses_text: Hypotheses, one per line.
        threshold: Score at or above which a text counts as positive.
        progress: Gradio progress tracker.

    Returns:
        Updates for (cross table, hypothesis radio, annotation table,
        annotation section visibility, performance report text).

    Raises:
        gr.Error: On invalid input or any failure during classification.
    """
    print(f"run_classification called: text_cols={text_cols}, threshold={threshold}")
    try:
        if not text_cols:
            raise gr.Error("Sélectionne au moins une colonne de texte.")

        stripped = [h.strip() for h in (hypotheses_text or "").splitlines()]
        # Drop blanks and duplicates (order preserved): duplicates would map
        # to the same score__/pred__/annot__ column names.
        hypotheses_raw = list(dict.fromkeys(h for h in stripped if h))
        if not hypotheses_raw:
            raise gr.Error("Aucune hypothèse saisie.")

        df = state.get("raw_df")
        if df is None:
            raise gr.Error("Recharge le fichier.")

        use_group = bool(group_col) and group_col != "(aucune)"
        # The group is attached inside prepare_texts, before empty texts are
        # filtered out; assigning it afterwards by position would shift the
        # groups whenever a row is dropped.
        sample = prepare_texts(df, text_cols, group_col if use_group else None)
        if not use_group:
            sample["group"] = "Tous"
        if sample.empty:
            raise gr.Error("Aucun texte non vide dans les colonnes sélectionnées.")

        texts = list(sample["text"])
        n_hyp = len(hypotheses_raw)
        print(f"Classifying {len(texts)} texts x {n_hyp} hypotheses")

        score_cols = {}
        for i, hyp in enumerate(hypotheses_raw):
            progress(i / n_hyp, desc=f"Hypothèse {i+1}/{n_hyp}: {hyp[:40]}...")
            print(f"  -> {hyp}")
            results = classifier(
                texts,
                candidate_labels=[hyp],
                hypothesis_template=TEMPLATE,
                multi_label=False,
                batch_size=32,
            )
            # Defensive: normalise a bare dict result to a one-element list.
            if isinstance(results, dict):
                results = [results]
            score_cols[hyp] = np.array([r["scores"][0] for r in results])
        progress(1.0, desc="Terminé.")

        for hyp in hypotheses_raw:
            sample[f"score__{hyp}"] = score_cols[hyp]
            sample[f"pred__{hyp}"] = (score_cols[hyp] >= threshold).astype(int)
            sample[f"annot__{hyp}"] = np.nan

        state["sample"] = sample
        state["hypotheses"] = hypotheses_raw
        state["threshold"] = threshold
        # Calibrated thresholds belong to the previous run; leaving them in
        # place would override the threshold chosen for this run.
        state.pop("thresholds", None)

        cross = build_cross_table(sample, hypotheses_raw, threshold)
        state["cross_table"] = cross
        annot_df = get_annotation_view(hypotheses_raw[0])

        print(f"Done.\n{cross}")
        return (
            gr.Dataframe(value=cross, visible=True),
            gr.Radio(choices=hypotheses_raw, value=hypotheses_raw[0]),
            annot_df,
            gr.Column(visible=True),
            "",
        )
    except gr.Error:
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur: {e}")


def build_cross_table(sample, hypotheses, threshold=None):
    """Build cross-table using per-hypothesis thresholds if available.

    For each hypothesis the threshold is, in order of precedence:
    ``state["thresholds"][hyp]``, the ``threshold`` argument,
    ``state["threshold"]``, then 0.5.

    Returns:
        A DataFrame with one row per hypothesis, one column per group and a
        "Tous" column, holding the percentage (1 decimal) of texts whose
        score is at or above the threshold.
    """
    # NaN groups cannot be matched with ==, so they get no column of their
    # own (their rows still count in "Tous").
    groups = list(sample["group"].dropna().unique())
    try:
        groups.sort()
    except TypeError:
        # Mixed value types (e.g. int and str) are not mutually orderable.
        groups.sort(key=str)

    default_thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    rows = []
    for hyp in hypotheses:
        sc = f"score__{hyp}"
        t = per_thr.get(hyp, threshold if threshold is not None else default_thr)
        row = {"Hypothèse": hyp}
        for g in groups:
            mask = sample["group"] == g
            row[g] = round((sample.loc[mask, sc] >= t).mean() * 100, 1)
        row["Tous"] = round((sample[sc] >= t).mean() * 100, 1)
        rows.append(row)
    return pd.DataFrame(rows)


def on_hypothesis_change(new_hyp, prev_hyp, annot_df):
    """Auto-save annotations for previous hypothesis, then load new one.

    Returns:
        ``(annotation view for new_hyp, new_hyp)``; the second value becomes
        the "previous hypothesis" for the next change.
    """
    if prev_hyp and annot_df is not None and len(annot_df) > 0:
        n = _save_annotations_from_df(prev_hyp, annot_df)
        print(f"Auto-saved {n} annotations for '{prev_hyp}'")
    return get_annotation_view(new_hyp), new_hyp


def compute_performance(hyp, annot_df):
    """Save current annotations first, then compute metrics.

    Precision/recall/F1 are computed on annotated rows only, at the current
    threshold for ``hyp`` and at the threshold that maximises F1 on the
    precision-recall curve.

    Returns:
        ``(markdown report, optimal threshold)``. The threshold is ``None``
        when metrics cannot be computed (no data, fewer than 5 annotations,
        or only one class annotated).
    """
    if hyp and annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)

    sample = state.get("sample")
    if sample is None:
        return "Pas de données.", None

    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # No hypothesis selected, or one left over from a previous run.
    if not hyp or annot_col not in sample.columns or score_col not in sample.columns:
        return "Pas de données.", None

    mask = sample[annot_col].notna()
    n = mask.sum()
    if n < 5:
        return f"Seulement {n} annotations. Minimum 5 pour calculer.", None

    y_true = sample.loc[mask, annot_col].astype(int).values
    scores = sample.loc[mask, score_col].values
    if len(set(y_true)) < 2:
        return "Il faut au moins un exemple positif et un négatif.", None

    thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    current_thr = per_thr.get(hyp, thr)

    y_pred = (scores >= current_thr).astype(int)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )

    precisions, recalls, thr_pr = precision_recall_curve(y_true, scores)
    # The epsilon avoids 0/0 where precision and recall are both zero.
    f1s = 2 * precisions * recalls / (precisions + recalls + 1e-10)
    best_idx = np.argmax(f1s)
    # The curve has one more (precision, recall) point than thresholds; the
    # last point has no threshold, so fall back to the current one.
    best_thr = thr_pr[best_idx] if best_idx < len(thr_pr) else current_thr
    best_f1 = f1s[best_idx]
    best_p = precisions[best_idx]
    best_r = recalls[best_idx]

    report = (
        f"### '{hyp}' — {n} annotations\n\n"
        f"**Seuil actuel ({current_thr:.2f})** : P={p:.2f} R={r:.2f} F1={f1:.2f}\n\n"
        f"**Seuil optimal ({best_thr:.2f})** : P={best_p:.2f} R={best_r:.2f} F1={best_f1:.2f}"
    )
    # Return the exact value: rounding could lift the threshold above the
    # score it was taken from, and `score >= threshold` would then exclude
    # that text, so the applied threshold would not match the report.
    return report, float(best_thr)


def apply_optimal_threshold(hyp, annot_df):
    """Apply the F1-optimal threshold for ``hyp`` and rebuild the cross table.

    Returns:
        ``(report, cross table)``. When no optimal threshold can be computed
        the report explains why and the stored cross table is returned.
    """
    report, best_thr = compute_performance(hyp, annot_df)
    if best_thr is None:
        return report, state.get("cross_table", pd.DataFrame())

    sample = state["sample"]
    hypotheses = state["hypotheses"]

    if "thresholds" not in state:
        state["thresholds"] = {h: state.get("threshold", 0.5) for h in hypotheses}
    state["thresholds"][hyp] = best_thr

    # Keep the stored predictions in line with the new threshold so that the
    # annotation view and the export agree with the cross table.
    sample[f"pred__{hyp}"] = (sample[f"score__{hyp}"] >= best_thr).astype(int)
    state["sample"] = sample

    cross = build_cross_table(sample, hypotheses)
    state["cross_table"] = cross

    report += f"\n\nTableau croisé recalculé avec seuil {best_thr:.3f} pour cette hypothèse."
    return report, cross


def export_xlsx(hyp, annot_df):
    """Export sample, cross table and thresholds to an .xlsx workbook.

    Annotations currently shown for ``hyp`` are saved first.

    Returns:
        Path of the written workbook.

    Raises:
        gr.Error: If nothing has been classified yet.
    """
    # Auto-save before export
    if hyp and annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)

    sample = state.get("sample")
    if sample is None:
        raise gr.Error("Rien à exporter.")

    # A fresh directory per export keeps the file name stable while avoiding
    # a single shared path that successive or concurrent exports overwrite.
    export_dir = tempfile.mkdtemp(prefix="classification_nli_")
    path = os.path.join(export_dir, "classification_nli.xlsx")

    with pd.ExcelWriter(path, engine="openpyxl") as writer:
        sample.to_excel(writer, sheet_name="Données", index=False)
        cross = state.get("cross_table")
        if cross is not None:
            cross.to_excel(writer, sheet_name="Tableau croisé", index=False)
        thresholds = state.get("thresholds", {})
        if thresholds:
            pd.DataFrame(
                [{"Hypothèse": k, "Seuil": v} for k, v in thresholds.items()]
            ).to_excel(writer, sheet_name="Seuils", index=False)
    return path


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="Classification NLI") as app:
    # Hidden state to track which hypothesis was previously selected
    prev_hyp_state = gr.State(value=None)

    gr.Markdown("# Classification NLI zero-shot de textes libres")
    gr.Markdown(
        "Upload un fichier (xlsx/csv), choisis les colonnes de texte, "
        "saisis des hypothèses (une par ligne), lance la classification. "
        "Annote ensuite pour calibrer les seuils."
    )

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Fichier (xlsx ou csv)", file_types=[".xlsx", ".csv"])
            upload_status = gr.Textbox(label="Statut", interactive=False)
            text_cols = gr.CheckboxGroup(label="Colonnes de texte", choices=[])
            group_col = gr.Dropdown(
                label="Colonne de regroupement (optionnel)",
                choices=["(aucune)"],
                value="(aucune)",
            )
        with gr.Column(scale=1):
            hypotheses_input = gr.Textbox(
                label="Hypothèses (une par ligne)",
                lines=10,
                placeholder="un manque de soutien politique\nun manque de moyens humains\n...",
            )
            threshold_input = gr.Slider(
                0.1, 0.95, value=0.5, step=0.05, label="Seuil de classification"
            )
            run_btn = gr.Button("Lancer la classification", variant="primary")

    cross_table_output = gr.Dataframe(label="Tableau croisé (% affirmatifs)", visible=False)

    with gr.Column(visible=False) as annot_section:
        gr.Markdown("## Annotation & Calibration")
        gr.Markdown(
            "Clique sur une hypothèse, puis annote les textes : "
            "mets **1** (affirme) ou **0** (non) dans la colonne Annotation. "
            "Les annotations sont sauvegardées automatiquement quand tu changes d'hypothèse."
        )
        hyp_selector = gr.Radio(label="Hypothèse", choices=[], interactive=True)
        annot_table = gr.Dataframe(
            label="Textes classifiés (triés par score décroissant)",
            interactive=True,
        )
        with gr.Row():
            perf_btn = gr.Button("Calculer performance & seuil optimal", variant="secondary")
            apply_btn = gr.Button("Appliquer seuil optimal & recalculer", variant="secondary")
        perf_output = gr.Markdown("")
        with gr.Row():
            export_btn = gr.Button("Exporter tout (.xlsx)", variant="primary")
            export_file = gr.File(label="Fichier exporté")

    # --- Events ---
    file_input.change(
        on_upload,
        inputs=[file_input],
        outputs=[text_cols, group_col, upload_status],
    )

    run_btn.click(
        run_classification,
        inputs=[file_input, text_cols, group_col, hypotheses_input, threshold_input],
        outputs=[cross_table_output, hyp_selector, annot_table, annot_section, perf_output],
    )

    # When hypothesis changes: auto-save previous annotations, load new view
    hyp_selector.change(
        on_hypothesis_change,
        inputs=[hyp_selector, prev_hyp_state, annot_table],
        outputs=[annot_table, prev_hyp_state],
    )

    # Only the report is shown here; the threshold itself is discarded.
    perf_btn.click(
        lambda hyp, df: compute_performance(hyp, df)[0],
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output],
    )

    apply_btn.click(
        apply_optimal_threshold,
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output, cross_table_output],
    )

    export_btn.click(
        export_xlsx,
        inputs=[hyp_selector, annot_table],
        outputs=[export_file],
    )

if __name__ == "__main__":
    app.launch(ssr_mode=False)