# Author: BenTouss · last commit 68d7edf "Update app.py" (verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
APP_NAME = "EuroChef"

# Repo id (or local path) for both the tokenizer and the fine-tuned classifier.
# Kept in a single constant so the two can never be loaded from different repos.
MODEL_ID = "BenTouss/mdeberta-eurochef"

# Prefer the GPU when one is visible to torch; predict() moves inputs here too.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.to(device)
# Inference only: eval() switches layers such as dropout to inference behaviour.
model.eval()
def predict(text: str, threshold: float = 0.6, top_k: int = 8, only_above: bool = True):
    """Score a customer-support message against every model label.

    Each label is scored independently with a sigmoid over the model's
    logits (not a softmax), then filtered, ranked and truncated.

    Args:
        text: Raw message. ``None`` or whitespace-only input returns a
            placeholder immediately, without calling the tokenizer or model.
        threshold: Minimum score a label needs when ``only_above`` is true.
        top_k: Maximum number of rows to return. It is converted with
            ``int()`` (a Gradio slider may send a float) and clamped to >= 1.
        only_above: If true, drop labels scoring below ``threshold``;
            otherwise rank all labels.

    Returns:
        ``(summary_markdown, rows)`` where ``rows`` is a list of
        ``[label, score]`` pairs sorted by descending score, with scores
        rounded to 3 decimals.
    """
    text = (text or "").strip()
    if not text:
        return "_Paste a message on the left to start._", []

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid rather than softmax: every label gets its own independent score.
    # No .detach() needed: no autograd graph is built under no_grad().
    probs = torch.sigmoid(logits)[0].cpu()

    id2label = model.config.id2label
    scores = [float(prob) for prob in probs]
    items = [
        (id2label[idx], score)
        for idx, score in enumerate(scores)
        if not only_above or score >= threshold
    ]
    items.sort(key=lambda item: item[1], reverse=True)
    items = items[: max(1, int(top_k))]

    rows = [[label, round(score, 3)] for label, score in items]
    if not rows:
        return f"_No label (threshold `{threshold:.2f}`). Try lowering it._", rows

    best_label, best_score = items[0]
    # Two trailing spaces before "\n" make a Markdown hard line break; a single
    # space would be a soft break and both parts would render on one line.
    summary = (
        f"**Top label:** `{best_label}` • **score:** `{best_score:.3f}`  \n"
        f"**Results:** {len(rows)} • **threshold:** `{threshold:.2f}`"
    )
    return summary, rows
# Extra stylesheet for the UI. Its selectors target the elem_id values
# "title", "subtitle" and "pred_table" and the "footer" CSS class assigned to
# components in this file; the string is passed as demo.launch(css=CSS).
# Everything inside the triple quotes is runtime CSS, including the /* */ note.
CSS = """
#title { margin-bottom: 0.25rem; }
#subtitle { margin-top: 0; opacity: 0.8; }
.footer { opacity: 0.7; font-size: 0.85rem; text-align: center; margin-top: 0.75rem; }
/* Force a nicer dataframe area without using height= */
#pred_table { min-height: 320px; }
"""
# Build the UI. Gradio places components according to the nesting of the
# `with` blocks and their creation order: header, then a two-column row
# (input and controls on the left, results on the right), then a footer.
with gr.Blocks() as demo:
    gr.Markdown(f"# 🍳 {APP_NAME}", elem_id="title")
    gr.Markdown("Customer support message → labels + scores.", elem_id="subtitle")
    with gr.Row():
        # Left column: message input, inference controls, and examples.
        with gr.Column(scale=6):
            text = gr.Textbox(
                label="Customer support message",
                placeholder="Ex: Bonjour, je n’arrive pas à lancer les vidéos…",
                lines=10,
            )
            # These controls are forwarded to predict() as threshold, top_k
            # and only_above (see run.click below); their initial values
            # match predict()'s defaults.
            with gr.Row():
                threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.01, label="Threshold")
                top_k = gr.Slider(1, 20, value=8, step=1, label="Top-K")
                only_above = gr.Checkbox(value=True, label="Only ≥ threshold")
            with gr.Row():
                run = gr.Button("Analyze", variant="primary")
                clear = gr.ClearButton(value="Clear")
            # No fn is passed, so (per Gradio defaults) picking an example only
            # fills the textbox; the user still clicks "Analyze".
            gr.Examples(
                examples=[
                    # FR
                    "Bonjour,\nJe n’arrive pas à lancer les vidéos depuis hier soir : écran noir et chargement infini. "
                    "Je suis Premium (paiement OK) mais certaines recettes restent verrouillées. Pouvez-vous vérifier mon compte ?\nMerci !",
                    # EN
                    "Hi,\nSince yesterday evening I can't play any videos: the screen stays black and keeps buffering. "
                    "I'm a Premium subscriber (payment went through), but some recipes are still locked. "
                    "Could you please check my account?\nThanks!",
                    # DE
                    "Hallo,\nseit gestern Abend kann ich keine Videos mehr abspielen: Der Bildschirm bleibt schwarz und es lädt endlos. "
                    "Ich habe ein Premium-Abo (Zahlung ist erfolgt), aber einige Rezepte sind weiterhin gesperrt. "
                    "Können Sie bitte mein Konto überprüfen?\nVielen Dank!"
                ],
                inputs=[text],
                label="Examples (FR / EN / DE)",
            )
        # Right column: prediction outputs.
        with gr.Column(scale=6):
            summary = gr.Markdown(label="Summary")
            table = gr.Dataframe(
                headers=["label", "score"],
                datatype=["str", "number"],
                label="Predictions",
                wrap=True,
                interactive=False,
                elem_id="pred_table",
            )
    gr.Markdown(f"<div class='footer'>Made with ❤️ by Ben • {APP_NAME}</div>")
    # predict() returns (summary_markdown, rows), mapped onto these outputs in order.
    run.click(fn=predict, inputs=[text, threshold, top_k, only_above], outputs=[summary, table])
    # Clear resets the message and both outputs; the sliders and checkbox keep their values.
    clear.add([text, summary, table])
# NOTE(review): theme/css are passed to launch(). Newer Gradio releases (6.x)
# accept them there; older ones take them on gr.Blocks(...) instead. Confirm
# against the pinned gradio version.
demo.launch(theme=gr.themes.Soft(), css=CSS)