File size: 14,268 Bytes
01abdec
 
 
 
 
 
 
 
 
 
 
 
8e7865e
01abdec
 
 
 
8e7865e
 
 
 
 
01abdec
8e7865e
01abdec
 
8e7865e
01abdec
 
8e7865e
01abdec
 
 
 
 
 
 
 
 
 
8e7865e
01abdec
8e7865e
 
 
01abdec
 
8e7865e
01abdec
 
 
8e7865e
01abdec
 
 
 
 
 
 
 
 
 
 
8e7865e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
01abdec
8e7865e
 
 
 
01abdec
8e7865e
 
 
01abdec
8e7865e
 
 
01abdec
8e7865e
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
 
01abdec
 
8e7865e
 
01abdec
 
8e7865e
 
01abdec
8e7865e
 
01abdec
 
 
8e7865e
 
01abdec
 
 
 
8e7865e
 
 
 
 
 
01abdec
 
8e7865e
 
 
 
01abdec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e7865e
 
 
 
01abdec
8e7865e
01abdec
 
8e7865e
01abdec
 
 
 
 
 
8e7865e
01abdec
 
 
 
 
8e7865e
 
01abdec
 
 
 
 
 
 
 
8e7865e
01abdec
 
 
 
 
8e7865e
 
 
 
 
01abdec
 
 
 
 
 
 
 
 
 
 
8e7865e
 
01abdec
 
 
 
 
 
8e7865e
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
 
01abdec
8e7865e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01abdec
 
8e7865e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
#!/usr/bin/env python
"""
Gradio app for NLI zero-shot classification of survey free-text responses.
"""

import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline
from sklearn.metrics import precision_recall_fscore_support, precision_recall_curve
import tempfile
import os
import traceback

# NLI hypothesis template: "{}" receives each hypothesis text when it is
# passed as ``hypothesis_template`` to the classifier.
TEMPLATE = "Ce texte exprime {}."
MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"

# The model is loaded once, at import time, and reused by every callback.
print("Chargement du modèle NLI...")
import torch

if torch.cuda.is_available():
    _device = 0   # first CUDA device
else:
    _device = -1  # CPU
classifier = pipeline(
    "zero-shot-classification",
    model=MODEL_NAME,
    device=_device,
)
print("Modèle chargé.")

# Store shared by all callbacks. Keys written below: raw_df, sample, hypotheses,
# threshold, thresholds, cross_table.
# NOTE(review): being module-level, this dict is shared by every browser
# session, so concurrent users overwrite each other's data; a per-session
# gr.State would isolate them.
state = {}


def load_file(file):
    """Read an uploaded survey file into a DataFrame.

    Args:
        file: the value of the Gradio File component: a path (``str`` or
            ``os.PathLike``) or an upload object exposing its path as ``.name``.

    Returns:
        pandas.DataFrame. Files ending in ``.csv`` (any letter case) are read
        with ``pd.read_csv``; everything else goes to ``pd.read_excel``.

    Raises:
        gr.Error: if no file was provided.
    """
    if file is None:
        raise gr.Error("Aucun fichier uploadé.")
    # A pathlib.Path also has a ``.name`` attribute, but it is only the base
    # name, so path-like objects must be converted before falling back to it.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name
    # Case-insensitive so that e.g. "DATA.CSV" is not handed to read_excel.
    if path.lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def on_upload(file):
    """Load the uploaded file and refresh the column pickers.

    Wired to ``file_input.change``, which also fires when the user clears the
    File widget (``file`` is then ``None``). In that case the cached DataFrame
    is dropped and the pickers are reset, instead of surfacing a loading error
    and leaving the previous file's data in ``state``.

    Returns:
        A 3-tuple of updates: the text-column CheckboxGroup, the grouping
        Dropdown, and a status message.

    Raises:
        gr.Error: if the file cannot be read.
    """
    if file is None:
        state.pop("raw_df", None)
        return (
            gr.CheckboxGroup(choices=[], value=[]),
            gr.Dropdown(choices=["(aucune)"], value="(aucune)"),
            "Aucun fichier chargé.",
        )
    try:
        df = load_file(file)
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur de chargement: {e}") from e

    cols = list(df.columns)
    state["raw_df"] = df
    print(f"Fichier chargé: {len(df)} lignes, colonnes: {cols}")
    return (
        gr.CheckboxGroup(choices=cols, value=[]),
        gr.Dropdown(choices=["(aucune)"] + cols, value="(aucune)"),
        f"{len(df)} lignes, {len(cols)} colonnes chargées.",
    )


def prepare_texts(df, text_cols):
    """Flatten the chosen text columns into one cleaned list of answers.

    Args:
        df: the uploaded survey table.
        text_cols: names of the columns holding free-text answers.

    Returns:
        A DataFrame indexed 0..n-1 with a "text" column (stripped, never
        blank) and a "source" column naming the column each answer came
        from. With several columns, answers are stacked column by column:
        all rows of the first column, then all rows of the second, and so on.
        Missing values and whitespace-only answers are dropped.
    """
    if len(text_cols) != 1:
        stacked = df[text_cols].melt(var_name="source", value_name="text")
    else:
        (column,) = text_cols
        stacked = df[[column]].copy()
        stacked.columns = ["text"]
        stacked["source"] = column
    stacked["text"] = stacked["text"].fillna("").astype(str).str.strip()
    keep = stacked["text"] != ""
    return stacked[keep].reset_index(drop=True)


# ---------------------------------------------------------------------------
# Annotation helpers
# ---------------------------------------------------------------------------
def _save_annotations_from_df(hyp, annot_df):
    """Copy the edited "Annotation" column back into ``state["sample"]``.

    The annotation table (built by ``get_annotation_view``) lists the sample
    sorted by descending score. Row *i* of *annot_df* is therefore matched to
    the *i*-th sample index in that same order. Cells equal to 0 or 1
    (written as int, float or text) are stored as 0/1. Anything else, such as
    a blank cell or a typo, resets the annotation to NaN.

    Args:
        hyp: hypothesis whose ``annot__{hyp}`` column is updated.
        annot_df: the edited table; must have an "Annotation" column.

    Returns:
        The number of non-missing annotations for *hyp* after the update.
        Returns 0 when there is nothing to save: no sample, no hypothesis,
        an empty table, or a hypothesis that is not part of the current
        sample (e.g. one left over from a previous classification run).
    """
    sample = state.get("sample")
    if sample is None or hyp is None:
        return 0
    if annot_df is None or len(annot_df) == 0:
        return 0
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    if score_col not in sample.columns or annot_col not in sample.columns:
        return 0

    # Must use the same ordering as get_annotation_view (descending score).
    sorted_idx = sample[score_col].sort_values(ascending=False).index
    # zip stops at the shorter side if the table and sample lengths differ.
    for idx, val in zip(sorted_idx, annot_df["Annotation"].values):
        val_str = str(val).strip() if pd.notna(val) else ""
        if val_str in ("0", "1", "0.0", "1.0"):
            sample.at[idx, annot_col] = int(float(val_str))
        else:
            sample.at[idx, annot_col] = np.nan

    state["sample"] = sample
    return int(sample[annot_col].notna().sum())


def get_annotation_view(hyp):
    """Build the editable annotation table for one hypothesis.

    Args:
        hyp: the hypothesis whose score/pred/annot columns are shown.

    Returns:
        A DataFrame with columns Texte, Score, Prédiction, Annotation, sorted
        by descending score and re-indexed 0..n-1. Annotated rows show an
        integer label and unannotated rows show an empty string. When there
        is no sample yet or *hyp* is empty, an empty frame with the same
        headers is returned.
    """
    headers = ["Texte", "Score", "Prédiction", "Annotation"]
    sample = state.get("sample")
    if not hyp or sample is None:
        return pd.DataFrame(columns=headers)

    wanted = ["text"] + [f"{kind}__{hyp}" for kind in ("score", "pred", "annot")]
    table = sample[wanted].copy()
    table.columns = headers
    table = table.sort_values("Score", ascending=False).reset_index(drop=True)

    def _as_cell(value):
        # Integer label for annotated rows; a blank cell reads better than NaN.
        return int(value) if pd.notna(value) else ""

    table["Annotation"] = table["Annotation"].apply(_as_cell)
    return table


# ---------------------------------------------------------------------------
# Classification
# ---------------------------------------------------------------------------
def _aligned_group_labels(df, text_cols, group_col):
    """Return the group label of every text kept by ``prepare_texts``.

    ``prepare_texts`` stacks the selected columns in order (all rows of the
    first column, then all rows of the second, ...) and drops blank texts.
    This function applies the same blank-text filter, column by column, so
    each text keeps the group of the row it came from. Slicing the raw group
    column by position would shift every label after the first blank answer.
    """
    group_values = df[group_col].to_numpy()
    labels = []
    for col in text_cols:
        keep = (df[col].fillna("").astype(str).str.strip() != "").to_numpy()
        labels.extend(group_values[keep].tolist())
    return labels


def run_classification(file, text_cols, group_col, hypotheses_text, threshold,
                       progress=gr.Progress()):
    """Score every selected text against every hypothesis with the NLI model.

    The texts come from the DataFrame cached by ``on_upload`` in
    ``state["raw_df"]``; the *file* argument is accepted for the event
    signature but not read. For each hypothesis the sample gets three
    columns: ``score__{hyp}`` (model score), ``pred__{hyp}``
    (score >= *threshold*) and ``annot__{hyp}`` (empty, for manual labels).
    Per-hypothesis thresholds calibrated in an earlier run are discarded,
    because they refer to the previous sample.

    Args:
        file: upload value (unused, see above).
        text_cols: names of the free-text columns.
        group_col: column to break the cross-table down by, or "(aucune)".
        hypotheses_text: one hypothesis per line; blank lines and duplicates
            are ignored.
        threshold: global decision threshold.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A tuple of updates: cross-table, hypothesis selector, annotation view,
        annotation section (made visible), and an empty performance report.

    Raises:
        gr.Error: on invalid input or any failure during classification.
    """
    print(f"run_classification called: text_cols={text_cols}, threshold={threshold}")
    try:
        if not text_cols:
            raise gr.Error("Sélectionne au moins une colonne de texte.")
        # dict.fromkeys de-duplicates while preserving the user's order; a
        # repeated line would otherwise produce clashing score__/pred__ columns
        # and duplicate rows/choices in the UI.
        hypotheses_raw = list(dict.fromkeys(
            h.strip() for h in (hypotheses_text or "").split("\n") if h.strip()
        ))
        if not hypotheses_raw:
            raise gr.Error("Aucune hypothèse saisie.")

        df = state.get("raw_df")
        if df is None:
            raise gr.Error("Recharge le fichier.")

        sample = prepare_texts(df, text_cols)
        if sample.empty:
            raise gr.Error("Aucun texte non vide dans les colonnes sélectionnées.")

        if group_col and group_col != "(aucune)":
            sample["group"] = _aligned_group_labels(df, text_cols, group_col)
        else:
            sample["group"] = "Tous"

        texts = sample["text"].tolist()
        n_hyp = len(hypotheses_raw)
        print(f"Classifying {len(texts)} texts x {n_hyp} hypotheses")

        score_cols = {}
        for i, hyp in enumerate(hypotheses_raw):
            progress(i / n_hyp, desc=f"Hypothèse {i+1}/{n_hyp}: {hyp[:40]}...")
            print(f"  -> {hyp}")
            results = classifier(
                texts,
                candidate_labels=[hyp],
                hypothesis_template=TEMPLATE,
                multi_label=False,
                batch_size=32,
            )
            # NOTE(review): some transformers versions may unwrap a
            # one-element batch into a bare dict; normalise defensively.
            if isinstance(results, dict):
                results = [results]
            score_cols[hyp] = np.array([r["scores"][0] for r in results])

        progress(1.0, desc="Terminé.")

        for hyp, scores in score_cols.items():
            sample[f"score__{hyp}"] = scores
            sample[f"pred__{hyp}"] = (scores >= threshold).astype(int)
            sample[f"annot__{hyp}"] = np.nan

        state["sample"] = sample
        state["hypotheses"] = hypotheses_raw
        state["threshold"] = threshold
        # Thresholds calibrated on a previous sample would otherwise override
        # the new slider value in build_cross_table and in later calibrations.
        state.pop("thresholds", None)

        cross = build_cross_table(sample, hypotheses_raw, threshold)
        state["cross_table"] = cross
        annot_df = get_annotation_view(hypotheses_raw[0])

        print(f"Done.\n{cross}")

        return (
            gr.Dataframe(value=cross, visible=True),
            gr.Radio(choices=hypotheses_raw, value=hypotheses_raw[0]),
            annot_df,
            gr.Column(visible=True),
            "",
        )
    except gr.Error:
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Erreur: {e}") from e


def build_cross_table(sample, hypotheses, threshold=None):
    """Return the % of texts affirming each hypothesis, per group and overall.

    The table has one row per hypothesis and one column per group (sorted),
    plus a "Tous" column for the whole sample. A text counts as affirmative
    when its score is >= the hypothesis threshold. That threshold is taken
    in this order: ``state["thresholds"][hyp]`` (set by calibration), then
    *threshold*, then ``state["threshold"]``, then 0.5.

    Rows with a missing (NaN) group get no column of their own but still
    count in "Tous". Groups of mixed types (e.g. numbers and text) are
    ordered by their string form instead of raising TypeError.

    Args:
        sample: DataFrame with a "group" column and ``score__{hyp}`` columns.
        hypotheses: hypotheses to report, in row order.
        threshold: fallback threshold when no calibrated one exists.

    Returns:
        pandas.DataFrame with percentages rounded to one decimal.
    """
    groups = list(sample["group"].dropna().unique())
    try:
        groups.sort()
    except TypeError:
        # e.g. a numeric code column with a few text entries
        groups.sort(key=str)

    default_thr = state.get("threshold", 0.5)
    per_thr = state.get("thresholds", {})
    fallback = threshold if threshold is not None else default_thr
    rows = []
    for hyp in hypotheses:
        scores = sample[f"score__{hyp}"]
        t = per_thr.get(hyp, fallback)
        row = {"Hypothèse": hyp}
        for g in groups:
            mask = sample["group"] == g
            row[g] = round((scores[mask] >= t).mean() * 100, 1)
        row["Tous"] = round((scores >= t).mean() * 100, 1)
        rows.append(row)
    return pd.DataFrame(rows)


def on_hypothesis_change(new_hyp, prev_hyp, annot_df):
    """Save the annotations of the previously shown hypothesis, then show *new_hyp*.

    This handler is bound to the selector's ``.change`` event, which also
    fires when ``run_classification`` resets the selector. In that case
    *prev_hyp* may come from an earlier run and be missing from
    ``state["hypotheses"]``. Saving is skipped then, because the current
    sample has no column for it.

    Args:
        new_hyp: the hypothesis just selected.
        prev_hyp: the hypothesis that was displayed before (tracked in a
            gr.State).
        annot_df: the annotation table as currently edited.

    Returns:
        (annotation view for *new_hyp*, *new_hyp*), the latter becoming the
        next "previous" hypothesis.
    """
    current = state.get("hypotheses") or []
    if prev_hyp and prev_hyp in current and annot_df is not None and len(annot_df) > 0:
        n = _save_annotations_from_df(prev_hyp, annot_df)
        print(f"Auto-saved {n} annotations for '{prev_hyp}'")
    return get_annotation_view(new_hyp), new_hyp


def compute_performance(hyp, annot_df):
    """Save the visible annotations for *hyp*, then evaluate its threshold.

    Compares the model scores of the annotated texts with their 0/1 labels.
    It reports precision, recall and F1 at the threshold currently used for
    *hyp*, and at the threshold on the precision-recall curve that
    maximises F1.

    Args:
        hyp: the hypothesis to evaluate.
        annot_df: the annotation table as currently edited (saved first).

    Returns:
        (markdown_report, best_threshold). best_threshold is None when the
        metrics cannot be computed: no data, unknown hypothesis, fewer than
        5 annotations, or only one class annotated. In those cases the
        report says why.
    """
    if hyp and annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)

    sample = state.get("sample")
    annot_col = f"annot__{hyp}"
    score_col = f"score__{hyp}"
    # Without this check an empty/unknown hypothesis surfaces as a KeyError.
    if (sample is None or not hyp
            or annot_col not in sample.columns or score_col not in sample.columns):
        return "Pas de données.", None

    mask = sample[annot_col].notna()
    n = int(mask.sum())
    if n < 5:
        return f"Seulement {n} annotations. Minimum 5 pour calculer.", None
    y_true = sample.loc[mask, annot_col].astype(int).values
    scores = sample.loc[mask, score_col].values
    if len(set(y_true)) < 2:
        return "Il faut au moins un exemple positif et un négatif.", None

    per_thr = state.get("thresholds", {})
    current_thr = per_thr.get(hyp, state.get("threshold", 0.5))
    y_pred = (scores >= current_thr).astype(int)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0)

    # precision/recall carry one extra trailing point (P=1, R=0) with no
    # matching threshold, hence the bounds check on best_idx.
    precisions, recalls, thr_pr = precision_recall_curve(y_true, scores)
    f1s = 2 * precisions * recalls / (precisions + recalls + 1e-10)
    best_idx = int(np.argmax(f1s))
    best_thr = float(thr_pr[best_idx]) if best_idx < len(thr_pr) else float(current_thr)
    best_f1 = f1s[best_idx]
    best_p = precisions[best_idx]
    best_r = recalls[best_idx]

    report = (
        f"### '{hyp}' — {n} annotations\n\n"
        f"**Seuil actuel ({current_thr:.2f})** : P={p:.2f}  R={r:.2f}  F1={f1:.2f}\n\n"
        f"**Seuil optimal ({best_thr:.2f})** : P={best_p:.2f}  R={best_r:.2f}  F1={best_f1:.2f}"
    )
    # Return the exact threshold, not a rounded one. Candidate thresholds are
    # actual score values and the rule is ``score >= threshold``, so rounding
    # (e.g. 0.87654 -> 0.877) can push the boundary text below the cut. The
    # applied metrics would then differ from the ones reported above.
    return report, best_thr


def apply_optimal_threshold(hyp, annot_df):
    """Calibrate *hyp* on its annotations and rebuild the cross-table with it.

    Runs ``compute_performance``. When that yields an F1-optimal threshold,
    the threshold is stored in ``state["thresholds"][hyp]`` and the
    cross-table is recomputed and cached. Hypotheses without a calibrated
    value start from the global threshold.

    Returns:
        (markdown_report, cross_table). If calibration is not possible, the
        report explains why and the last cached cross-table (or an empty
        frame) is returned.
    """
    report, best_thr = compute_performance(hyp, annot_df)
    if best_thr is not None:
        sample = state["sample"]
        hypotheses = state["hypotheses"]
        default_thr = state.get("threshold", 0.5)
        calibrated = state.setdefault("thresholds", {h: default_thr for h in hypotheses})
        calibrated[hyp] = best_thr

        cross = build_cross_table(sample, hypotheses)
        state["cross_table"] = cross
        note = f"\n\nTableau croisé recalculé avec seuil {best_thr:.3f} pour cette hypothèse."
        return report + note, cross
    return report, state.get("cross_table", pd.DataFrame())


def export_xlsx(hyp, annot_df):
    """Write the sample, cross-table and thresholds to an .xlsx workbook.

    The annotations currently shown for *hyp* are saved first, so they are
    included. Each call writes into its own fresh temporary directory. A
    single fixed path could be overwritten by another export while it is
    still being written or served. The "Seuils" sheet lists the threshold in
    effect for every hypothesis: the calibrated value if any, otherwise the
    global threshold (the same fallback the cross-table uses when no
    explicit threshold is passed).

    Returns:
        Path of the written workbook (file name ``classification_nli.xlsx``).

    Raises:
        gr.Error: if no classification has been run yet.
    """
    # Persist what is on screen before exporting.
    if hyp and annot_df is not None and len(annot_df) > 0:
        _save_annotations_from_df(hyp, annot_df)

    sample = state.get("sample")
    if sample is None:
        raise gr.Error("Rien à exporter.")

    default_thr = state.get("threshold", 0.5)
    calibrated = state.get("thresholds", {})
    thresholds = {h: calibrated.get(h, default_thr) for h in state.get("hypotheses", [])}

    path = os.path.join(tempfile.mkdtemp(prefix="nli_export_"), "classification_nli.xlsx")
    with pd.ExcelWriter(path, engine="openpyxl") as writer:
        sample.to_excel(writer, sheet_name="Données", index=False)
        cross = state.get("cross_table")
        if cross is not None:
            cross.to_excel(writer, sheet_name="Tableau croisé", index=False)
        if thresholds:
            pd.DataFrame(
                [{"Hypothèse": k, "Seuil": v} for k, v in thresholds.items()]
            ).to_excel(writer, sheet_name="Seuils", index=False)
    return path


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Build the Gradio UI. The inputs (file, columns, hypotheses, threshold) are
# at the top, followed by the cross-table and then an annotation/calibration
# section. Both the cross-table and that section start hidden; the outputs of
# run_classification make them visible.
with gr.Blocks(title="Classification NLI") as app:
    # Hidden state to track which hypothesis was previously selected
    # (used by on_hypothesis_change to know whose annotations to save).
    prev_hyp_state = gr.State(value=None)

    gr.Markdown("# Classification NLI zero-shot de textes libres")
    gr.Markdown(
        "Upload un fichier (xlsx/csv), choisis les colonnes de texte, "
        "saisis des hypothèses (une par ligne), lance la classification. "
        "Annote ensuite pour calibrer les seuils."
    )

    with gr.Row():
        # Left column: data source and column selection.
        with gr.Column(scale=1):
            file_input = gr.File(label="Fichier (xlsx ou csv)", file_types=[".xlsx", ".csv"])
            upload_status = gr.Textbox(label="Statut", interactive=False)
            # Choices are filled in by on_upload once a file is loaded.
            text_cols = gr.CheckboxGroup(label="Colonnes de texte", choices=[])
            # "(aucune)" is the sentinel run_classification treats as "no grouping".
            group_col = gr.Dropdown(label="Colonne de regroupement (optionnel)",
                                    choices=["(aucune)"], value="(aucune)")
        # Right column: hypotheses and the global decision threshold.
        with gr.Column(scale=1):
            hypotheses_input = gr.Textbox(
                label="Hypothèses (une par ligne)", lines=10,
                placeholder="un manque de soutien politique\nun manque de moyens humains\n...",
            )
            threshold_input = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
                                        label="Seuil de classification")
            run_btn = gr.Button("Lancer la classification", variant="primary")

    cross_table_output = gr.Dataframe(label="Tableau croisé (% affirmatifs)", visible=False)

    # Annotation & calibration section, revealed by run_classification.
    with gr.Column(visible=False) as annot_section:
        gr.Markdown("## Annotation & Calibration")
        gr.Markdown(
            "Clique sur une hypothèse, puis annote les textes : "
            "mets **1** (affirme) ou **0** (non) dans la colonne Annotation. "
            "Les annotations sont sauvegardées automatiquement quand tu changes d'hypothèse."
        )
        hyp_selector = gr.Radio(label="Hypothèse", choices=[], interactive=True)
        # Editable table; rows are matched back to the sample by position in
        # descending-score order (see _save_annotations_from_df).
        annot_table = gr.Dataframe(
            label="Textes classifiés (triés par score décroissant)",
            interactive=True,
        )
        with gr.Row():
            perf_btn = gr.Button("Calculer performance & seuil optimal", variant="secondary")
            apply_btn = gr.Button("Appliquer seuil optimal & recalculer", variant="secondary")
        perf_output = gr.Markdown("")
        with gr.Row():
            export_btn = gr.Button("Exporter tout (.xlsx)", variant="primary")
            export_file = gr.File(label="Fichier exporté")

    # --- Events ---
    # Loading a file caches the DataFrame in `state` and fills the column pickers.
    file_input.change(on_upload, inputs=[file_input],
                      outputs=[text_cols, group_col, upload_status])

    # file_input is passed, but run_classification reads the DataFrame that
    # on_upload cached in state["raw_df"].
    run_btn.click(
        run_classification,
        inputs=[file_input, text_cols, group_col, hypotheses_input, threshold_input],
        outputs=[cross_table_output, hyp_selector, annot_table, annot_section, perf_output],
    )

    # When hypothesis changes: auto-save previous annotations, load new view
    hyp_selector.change(
        on_hypothesis_change,
        inputs=[hyp_selector, prev_hyp_state, annot_table],
        outputs=[annot_table, prev_hyp_state],
    )

    # Report only: the optimal threshold returned by compute_performance is
    # discarded here; apply_btn is what stores it.
    perf_btn.click(
        lambda hyp, df: compute_performance(hyp, df)[0],
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output],
    )
    apply_btn.click(
        apply_optimal_threshold,
        inputs=[hyp_selector, annot_table],
        outputs=[perf_output, cross_table_output],
    )
    # The handlers for perf_btn, apply_btn and export_btn each save the
    # on-screen annotations first, so they receive the current table.
    export_btn.click(
        export_xlsx,
        inputs=[hyp_selector, annot_table],
        outputs=[export_file],
    )

if __name__ == "__main__":
    # ssr_mode=False turns off Gradio's server-side rendering option.
    app.launch(ssr_mode=False)