Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Gradio app for NLI zero-shot classification of survey free-text responses. | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import pipeline | |
| from sklearn.metrics import precision_recall_fscore_support, precision_recall_curve | |
| import tempfile | |
| import os | |
| import traceback | |
# Hypothesis template handed to the zero-shot pipeline as `hypothesis_template`;
# "{}" is filled with each candidate label (French: "This text expresses {}.").
TEMPLATE = "Ce texte exprime {}."
# Checkpoint identifier passed as `model=` when the pipeline is built below.
MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
print("Chargement du modèle NLI...")
import torch
# Device index given to the pipeline: 0 when CUDA is available, otherwise -1.
_device = 0 if torch.cuda.is_available() else -1
# Built once at import time; run_classification calls this shared instance.
classifier = pipeline("zero-shot-classification", model=MODEL_NAME, device=_device)
print("Modèle chargé.")
# Module-level store shared by every callback in this file. Keys written here:
# "raw_df", "sample", "hypotheses", "threshold", "thresholds", "cross_table".
state = {}
def load_file(file):
    """Read the uploaded survey file into a DataFrame.

    Args:
        file: Either a filesystem path (``str``) or an object exposing the
            path through a ``.name`` attribute.

    Returns:
        A ``pandas.DataFrame`` parsed with ``read_csv`` when the path ends in
        ``.csv`` (any letter case), otherwise with ``read_excel``.

    Raises:
        gr.Error: If ``file`` is ``None``.
    """
    if file is None:
        raise gr.Error("Aucun fichier uploadé.")
    path = file if isinstance(file, str) else file.name
    # Compare the extension case-insensitively so that "DATA.CSV" is not
    # handed to the Excel reader.
    if path.lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)
def on_upload(file):
    """Load a newly uploaded file and refresh the column pickers.

    The parsed table is stored in ``state["raw_df"]``. Any table stored by an
    earlier upload is discarded first, so a failed or cleared upload cannot
    leave stale data behind for a later classification run.

    Args:
        file: Value of the file component (path, file-like object or ``None``).

    Returns:
        A 3-tuple: a ``gr.CheckboxGroup`` update listing the columns (none
        selected), a ``gr.Dropdown`` update listing "(aucune)" plus the
        columns, and a status string with the row and column counts.

    Raises:
        gr.Error: If no file is given or the file cannot be parsed.
    """
    # Drop the previous table up front: if loading fails below, the UI shows
    # the new file, so keeping the old frame would classify the wrong data.
    state.pop("raw_df", None)
    try:
        df = load_file(file)
        cols = list(df.columns)
        state["raw_df"] = df
        print(f"Fichier chargé: {len(df)} lignes, colonnes: {cols}")
        return (
            gr.CheckboxGroup(choices=cols, value=[]),
            gr.Dropdown(choices=["(aucune)"] + cols, value="(aucune)"),
            f"{len(df)} lignes, {len(cols)} colonnes chargées.",
        )
    except gr.Error:
        # Already a user-facing message; re-raise it unwrapped, as
        # run_classification does.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur de chargement: {e}") from e
def prepare_texts(df, text_cols):
    """Reshape the selected text columns into one long table of non-empty texts.

    Args:
        df: Source table.
        text_cols: Names of the columns holding free text.

    Returns:
        A DataFrame with a fresh 0..n-1 index and the columns ``text`` (the
        stripped string) and ``source`` (the column the text came from). Rows
        whose text is missing or blank after stripping are removed. With a
        single text column the column order is ``text, source``; otherwise it
        is ``source, text``.
    """
    if len(text_cols) != 1:
        # Several columns: stack them one after another.
        long_form = df[text_cols].melt(var_name="source", value_name="text")
    else:
        only_col = text_cols[0]
        long_form = df[[only_col]].copy()
        long_form.columns = ["text"]
        long_form["source"] = only_col

    cleaned = long_form["text"].fillna("").astype(str).str.strip()
    long_form["text"] = cleaned
    non_empty = long_form["text"] != ""
    return long_form[non_empty].reset_index(drop=True)
| # --------------------------------------------------------------------------- | |
| # Annotation helpers | |
| # --------------------------------------------------------------------------- | |
def _save_annotations_from_df(hyp, annot_df):
    """Write annotations from the displayed dataframe back to state.

    Rows of ``annot_df`` are matched by position against the rows of
    ``state["sample"]`` ordered by descending ``score__<hyp>``. A cell whose
    stripped text is "0", "1", "0.0" or "1.0" is stored as the integer 0 or 1
    in ``annot__<hyp>``; any other cell clears the annotation (NaN). Sample
    rows beyond the end of ``annot_df`` are left unchanged.

    Args:
        hyp: Hypothesis whose annotations are being saved.
        annot_df: Table as shown in the UI; must have an "Annotation" column.

    Returns:
        Number of non-missing annotations stored for ``hyp`` afterwards, or 0
        when there is nothing to save (no sample, no hypothesis, no columns
        for this hypothesis in the sample, or an empty/unsuitable table).
    """
    sample = state.get("sample")
    if sample is None or hyp is None:
        return 0
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # The sample only has columns for the hypotheses of the latest run; a
    # hypothesis remembered from an earlier run must be ignored, not crash.
    if annot_col not in sample.columns or score_col not in sample.columns:
        return 0
    if annot_df is None or len(annot_df) == 0:
        return 0
    if "Annotation" not in annot_df.columns:
        return 0
    sorted_idx = sample[score_col].sort_values(ascending=False).index
    annotations = annot_df["Annotation"].values
    # zip stops at the shorter sequence, i.e. at the end of the displayed table.
    for idx, val in zip(sorted_idx, annotations):
        val_str = str(val).strip() if pd.notna(val) else ""
        if val_str in ("0", "1", "0.0", "1.0"):
            sample.at[idx, annot_col] = int(float(val_str))
        else:
            sample.at[idx, annot_col] = np.nan
    state["sample"] = sample
    return int(sample[annot_col].notna().sum())
def get_annotation_view(hyp):
    """Build the table shown in the annotation grid for one hypothesis.

    Args:
        hyp: Hypothesis to display; falsy values yield an empty table.

    Returns:
        A DataFrame with the columns "Texte", "Score", "Prédiction" and
        "Annotation", ordered by descending score with a fresh index. Missing
        annotations are shown as empty strings, the others as integers. The
        table is empty when no sample is stored or ``hyp`` is falsy.
    """
    display_cols = ["Texte", "Score", "Prédiction", "Annotation"]
    sample = state.get("sample")
    if sample is None or not hyp:
        return pd.DataFrame(columns=display_cols)

    source_cols = ["text"] + [f"{kind}__{hyp}" for kind in ("score", "pred", "annot")]
    table = sample[source_cols].copy()
    table.columns = display_cols
    table = table.sort_values("Score", ascending=False).reset_index(drop=True)

    def _as_cell(value):
        # Blank cell for "not annotated yet", plain integer otherwise.
        return "" if pd.isna(value) else int(value)

    table["Annotation"] = table["Annotation"].apply(_as_cell)
    return table
| # --------------------------------------------------------------------------- | |
| # Classification | |
| # --------------------------------------------------------------------------- | |
def _aligned_groups(df, text_cols, group_col):
    """Return the ``group_col`` label of every row that ``prepare_texts`` keeps, in order."""
    # prepare_texts stacks the selected columns one after another and then
    # drops rows whose stripped text is empty; rebuild that keep-mask here so
    # each surviving text gets the label of the row it came from.
    stacked = pd.concat([df[col] for col in text_cols], ignore_index=True)
    kept = (stacked.fillna("").astype(str).str.strip() != "").to_numpy()
    labels = np.tile(df[group_col].to_numpy(), len(text_cols))
    return labels[kept]


def run_classification(file, text_cols, group_col, hypotheses_text, threshold,
                       progress=gr.Progress()):
    """Score every text against every hypothesis and publish the results.

    The table stored by ``on_upload`` in ``state["raw_df"]`` is reshaped with
    ``prepare_texts``; each hypothesis is then scored with the shared
    ``classifier``. For each hypothesis the sample gets ``score__<hyp>``,
    ``pred__<hyp>`` (1 when score >= ``threshold``) and an empty
    ``annot__<hyp>`` column. ``state`` receives "sample", "hypotheses",
    "threshold" and "cross_table"; per-hypothesis thresholds left over from an
    earlier run are discarded.

    Args:
        file: Value of the file component (unused; the table comes from state).
        text_cols: Names of the selected text columns.
        group_col: Name of the grouping column, or "(aucune)"/empty for none.
        hypotheses_text: Hypotheses, one per line; blanks and repeats are ignored.
        threshold: Score at or above which a text counts as positive.
        progress: Gradio progress tracker.

    Returns:
        A 5-tuple matching the UI outputs: cross-table update, hypothesis
        selector update, annotation table for the first hypothesis, an update
        making the annotation section visible, and an empty report string.

    Raises:
        gr.Error: On missing inputs or any failure during classification.
    """
    print(f"run_classification called: text_cols={text_cols}, threshold={threshold}")
    try:
        if not text_cols:
            raise gr.Error("Sélectionne au moins une colonne de texte.")
        stripped = (h.strip() for h in (hypotheses_text or "").strip().split("\n"))
        # dict.fromkeys drops repeated hypotheses while keeping their order;
        # a repeat would only be scored twice and overwrite the same columns.
        hypotheses_raw = list(dict.fromkeys(h for h in stripped if h))
        if not hypotheses_raw:
            raise gr.Error("Aucune hypothèse saisie.")
        df = state.get("raw_df")
        if df is None:
            raise gr.Error("Recharge le fichier.")
        sample = prepare_texts(df, text_cols)
        if sample.empty:
            raise gr.Error("Aucun texte non vide dans les colonnes sélectionnées.")
        if group_col and group_col != "(aucune)":
            # Labels must follow the rows that survived the empty-text filter,
            # not simply the first len(sample) rows of the group column.
            sample["group"] = _aligned_groups(df, text_cols, group_col)
        else:
            sample["group"] = "Tous"
        texts = list(sample["text"])
        n_hyp = len(hypotheses_raw)
        print(f"Classifying {len(texts)} texts x {n_hyp} hypotheses")
        score_cols = {}
        for i, hyp in enumerate(hypotheses_raw):
            progress(i / n_hyp, desc=f"Hypothèse {i+1}/{n_hyp}: {hyp[:40]}...")
            print(f"  -> {hyp}")
            results = classifier(
                texts,
                candidate_labels=[hyp],
                hypothesis_template=TEMPLATE,
                multi_label=False,
                batch_size=32,
            )
            scores = np.array([r["scores"][0] for r in results])
            score_cols[hyp] = scores
        progress(1.0, desc="Terminé.")
        for hyp in hypotheses_raw:
            sample[f"score__{hyp}"] = score_cols[hyp]
            sample[f"pred__{hyp}"] = (score_cols[hyp] >= threshold).astype(int)
            sample[f"annot__{hyp}"] = np.nan
        state["sample"] = sample
        state["hypotheses"] = hypotheses_raw
        state["threshold"] = threshold
        # Calibrated thresholds belong to the previous run; keeping them would
        # make the cross table ignore the threshold chosen for this run.
        state.pop("thresholds", None)
        cross = build_cross_table(sample, hypotheses_raw, threshold)
        state["cross_table"] = cross
        annot_df = get_annotation_view(hypotheses_raw[0])
        print(f"Done.\n{cross}")
        return (
            gr.Dataframe(value=cross, visible=True),
            gr.Radio(choices=hypotheses_raw, value=hypotheses_raw[0]),
            annot_df,
            gr.Column(visible=True),
            "",
        )
    except gr.Error:
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur: {e}") from e
def build_cross_table(sample, hypotheses, threshold=None):
    """Build cross-table using per-hypothesis thresholds if available.

    For each hypothesis the table gives the percentage (one decimal) of texts
    whose score is at or above the threshold, per group and overall.

    The threshold for a hypothesis is, in order of precedence: its entry in
    ``state["thresholds"]``, the ``threshold`` argument, then
    ``state["threshold"]`` (0.5 when absent).

    Args:
        sample: Table with a ``group`` column and one ``score__<hyp>`` column
            per hypothesis.
        hypotheses: Hypotheses to report, one output row each.
        threshold: Fallback threshold for hypotheses without their own entry.

    Returns:
        A DataFrame with the columns "Hypothèse", one per group label, and
        "Tous" (all rows). Rows with a missing group label count only towards
        "Tous".
    """
    labels = sample["group"].dropna().unique()
    try:
        groups = sorted(labels)
    except TypeError:
        # Labels of mixed types (e.g. numbers and strings) cannot be compared
        # directly; fall back to ordering by their text form.
        groups = sorted(labels, key=str)
    rows = []
    default_thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    for hyp in hypotheses:
        sc = f"score__{hyp}"
        t = per_thr.get(hyp, threshold if threshold is not None else default_thr)
        row = {"Hypothèse": hyp}
        for g in groups:
            mask = sample["group"] == g
            row[g] = round((sample.loc[mask, sc] >= t).mean() * 100, 1)
        row["Tous"] = round((sample[sc] >= t).mean() * 100, 1)
        rows.append(row)
    return pd.DataFrame(rows)
def on_hypothesis_change(new_hyp, prev_hyp, annot_df):
    """Persist the table of the hypothesis being left, then show the new one.

    Args:
        new_hyp: Hypothesis just selected.
        prev_hyp: Hypothesis that was selected before (may be ``None``).
        annot_df: Annotation table currently displayed.

    Returns:
        A 2-tuple: the annotation view for ``new_hyp`` and ``new_hyp`` itself,
        the latter to be stored as the next "previous" hypothesis.
    """
    should_save = bool(prev_hyp) and annot_df is not None and len(annot_df) > 0
    if should_save:
        n = _save_annotations_from_df(prev_hyp, annot_df)
        print(f"Auto-saved {n} annotations for '{prev_hyp}'")
    refreshed_view = get_annotation_view(new_hyp)
    return refreshed_view, new_hyp
def compute_performance(hyp, annot_df):
    """Save current annotations first, then compute metrics.

    Precision, recall and F1 are computed on the annotated rows for the
    threshold currently in force (the per-hypothesis one from
    ``state["thresholds"]`` if present, else ``state["threshold"]``, else 0.5)
    and for the threshold that maximises F1 along the precision-recall curve.

    Args:
        hyp: Hypothesis to evaluate.
        annot_df: Annotation table currently displayed; saved before scoring.

    Returns:
        A 2-tuple ``(report, best_threshold)``. ``report`` is a Markdown
        string. ``best_threshold`` is a float, or ``None`` when metrics cannot
        be computed (no data, unknown hypothesis, fewer than 5 annotations, or
        only one class annotated).
    """
    sample = state.get("sample")
    if sample is None:
        return "Pas de données.", None
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # No hypothesis selected, or one that the current sample has no columns for.
    if not hyp or annot_col not in sample.columns or score_col not in sample.columns:
        return "Pas de données.", None
    if annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)
        # Read the sample back in case saving replaced the stored object.
        sample = state.get("sample")
    mask = sample[annot_col].notna()
    n = mask.sum()
    if n < 5:
        return f"Seulement {n} annotations. Minimum 5 pour calculer.", None
    y_true = sample.loc[mask, annot_col].astype(int).values
    scores = sample.loc[mask, score_col].values
    if len(set(y_true)) < 2:
        return "Il faut au moins un exemple positif et un négatif.", None
    thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    current_thr = per_thr.get(hyp, thr)
    y_pred = (scores >= current_thr).astype(int)
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    precisions, recalls, thr_pr = precision_recall_curve(y_true, scores)
    # The small constant avoids 0/0 where precision and recall are both zero.
    f1s = 2 * precisions * recalls / (precisions + recalls + 1e-10)
    best_idx = np.argmax(f1s)
    # The curve has one more point than it has thresholds; if that final
    # point wins there is no threshold to report, so keep the current one.
    best_thr = thr_pr[best_idx] if best_idx < len(thr_pr) else current_thr
    best_f1 = f1s[best_idx]
    best_p = precisions[best_idx]
    best_r = recalls[best_idx]
    report = (
        f"### '{hyp}' — {n} annotations\n\n"
        f"**Seuil actuel ({current_thr:.2f})** : P={p:.2f} R={r:.2f} F1={f1:.2f}\n\n"
        f"**Seuil optimal ({best_thr:.2f})** : P={best_p:.2f} R={best_r:.2f} F1={best_f1:.2f}"
    )
    # Return the threshold unrounded: rounding could lift it above the score
    # it was taken from, so applying it would no longer reproduce the metrics
    # reported above.
    return report, float(best_thr)
def apply_optimal_threshold(hyp, annot_df):
    """Adopt the F1-optimal threshold for ``hyp`` and rebuild the cross table.

    The threshold comes from ``compute_performance``. When it is available it
    is recorded in ``state["thresholds"]``, the stored ``pred__<hyp>`` column
    is recomputed with it, and ``state["cross_table"]`` is rebuilt.

    Args:
        hyp: Hypothesis being calibrated.
        annot_df: Annotation table currently displayed.

    Returns:
        A 2-tuple ``(report, cross_table)``. When no threshold could be
        computed, the report explains why and the previously stored cross
        table (or an empty frame) is returned unchanged.
    """
    report, best_thr = compute_performance(hyp, annot_df)
    if best_thr is None:
        return report, state.get("cross_table", pd.DataFrame())
    sample = state["sample"]
    hypotheses = state["hypotheses"]
    # First calibration: start every hypothesis at the global threshold.
    thresholds = state.setdefault(
        "thresholds", {h: state.get("threshold", 0.5) for h in hypotheses}
    )
    thresholds[hyp] = best_thr
    # Keep the stored 0/1 predictions in step with the threshold just applied,
    # otherwise the export contradicts the recalculated cross table.
    sample[f"pred__{hyp}"] = (sample[f"score__{hyp}"] >= best_thr).astype(int)
    cross = build_cross_table(sample, hypotheses)
    state["cross_table"] = cross
    report += f"\n\nTableau croisé recalculé avec seuil {best_thr:.3f} pour cette hypothèse."
    return report, cross
def export_xlsx(hyp, annot_df):
    """Write the current results to an Excel workbook and return its path.

    The annotations shown in the UI are saved first. The workbook holds the
    sheet "Données" (the full sample), "Tableau croisé" when a cross table is
    stored, and "Seuils" when per-hypothesis thresholds exist.

    Args:
        hyp: Hypothesis currently selected (its table is saved before export).
        annot_df: Annotation table currently displayed.

    Returns:
        Path of the written ``classification_nli.xlsx`` file.

    Raises:
        gr.Error: If no classification result is stored.
    """
    # Auto-save before export
    if hyp and annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)
    sample = state.get("sample")
    if sample is None:
        raise gr.Error("Rien à exporter.")
    # A fresh directory per export keeps the download name stable while
    # preventing concurrent exports from overwriting each other's file.
    export_dir = tempfile.mkdtemp(prefix="classification_nli_")
    path = os.path.join(export_dir, "classification_nli.xlsx")
    with pd.ExcelWriter(path, engine="openpyxl") as writer:
        sample.to_excel(writer, sheet_name="Données", index=False)
        cross = state.get("cross_table")
        if cross is not None:
            cross.to_excel(writer, sheet_name="Tableau croisé", index=False)
        thresholds = state.get("thresholds", {})
        if thresholds:
            pd.DataFrame([{"Hypothèse": k, "Seuil": v} for k, v in thresholds.items()]).to_excel(
                writer, sheet_name="Seuils", index=False)
    return path
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# Declarative layout and event wiring for the whole app.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from a
# copy of this file that had lost its indentation — confirm against the
# deployed layout.
with gr.Blocks(title="Classification NLI") as app:
    # Hidden state to track which hypothesis was previously selected
    prev_hyp_state = gr.State(value=None)
    gr.Markdown("# Classification NLI zero-shot de textes libres")
    gr.Markdown(
        "Upload un fichier (xlsx/csv), choisis les colonnes de texte, "
        "saisis des hypothèses (une par ligne), lance la classification. "
        "Annote ensuite pour calibrer les seuils."
    )
    with gr.Row():
        # Left column: file upload and column selection.
        with gr.Column(scale=1):
            file_input = gr.File(label="Fichier (xlsx ou csv)", file_types=[".xlsx", ".csv"])
            upload_status = gr.Textbox(label="Statut", interactive=False)
            text_cols = gr.CheckboxGroup(label="Colonnes de texte", choices=[])
            group_col = gr.Dropdown(label="Colonne de regroupement (optionnel)",
                                    choices=["(aucune)"], value="(aucune)")
        # Right column: hypotheses and the global classification threshold.
        with gr.Column(scale=1):
            hypotheses_input = gr.Textbox(
                label="Hypothèses (une par ligne)", lines=10,
                placeholder="un manque de soutien politique\nun manque de moyens humains\n...",
            )
            threshold_input = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
                                        label="Seuil de classification")
    run_btn = gr.Button("Lancer la classification", variant="primary")
    # Hidden until run_classification returns an update with visible=True.
    cross_table_output = gr.Dataframe(label="Tableau croisé (% affirmatifs)", visible=False)
    # Annotation and calibration section; run_classification makes it visible.
    with gr.Column(visible=False) as annot_section:
        gr.Markdown("## Annotation & Calibration")
        gr.Markdown(
            "Clique sur une hypothèse, puis annote les textes : "
            "mets **1** (affirme) ou **0** (non) dans la colonne Annotation. "
            "Les annotations sont sauvegardées automatiquement quand tu changes d'hypothèse."
        )
        hyp_selector = gr.Radio(label="Hypothèse", choices=[], interactive=True)
        annot_table = gr.Dataframe(
            label="Textes classifiés (triés par score décroissant)",
            interactive=True,
        )
        with gr.Row():
            perf_btn = gr.Button("Calculer performance & seuil optimal", variant="secondary")
            apply_btn = gr.Button("Appliquer seuil optimal & recalculer", variant="secondary")
        perf_output = gr.Markdown("")
        with gr.Row():
            export_btn = gr.Button("Exporter tout (.xlsx)", variant="primary")
            export_file = gr.File(label="Fichier exporté")
    # --- Events ---
    # A new (or cleared) file refreshes the column pickers and the status box.
    file_input.change(on_upload, inputs=[file_input],
                      outputs=[text_cols, group_col, upload_status])
    # NOTE(review): prev_hyp_state is not among these outputs, so it keeps the
    # hypothesis selected before a re-run — confirm that the hyp_selector.change
    # handler cannot then save the fresh table under that old hypothesis.
    run_btn.click(
        run_classification,
        inputs=[file_input, text_cols, group_col, hypotheses_input, threshold_input],
        outputs=[cross_table_output, hyp_selector, annot_table, annot_section, perf_output],
    )
    # When hypothesis changes: auto-save previous annotations, load new view
    hyp_selector.change(
        on_hypothesis_change,
        inputs=[hyp_selector, prev_hyp_state, annot_table],
        outputs=[annot_table, prev_hyp_state],
    )
    # Only the report (first element of the returned pair) is displayed here;
    # the computed threshold is discarded.
    perf_btn.click(
        lambda hyp, df: compute_performance(hyp, df)[0],
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output],
    )
    apply_btn.click(
        apply_optimal_threshold,
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output, cross_table_output],
    )
    export_btn.click(
        export_xlsx,
        inputs=[hyp_selector, annot_table],
        outputs=[export_file],
    )
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    app.launch(ssr_mode=False)