Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import gradio as gr
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
|
| 4 |
import torch
|
|
@@ -88,29 +87,60 @@ def predict(text):
|
|
| 88 |
with torch.no_grad():
|
| 89 |
outputs = model_c(**inputs)
|
| 90 |
logits = outputs.logits
|
| 91 |
-
return "recherche" if torch.argmax(logits, dim=-1).item() == 1 else "
|
| 92 |
|
|
|
|
| 93 |
def classify_and_respond(text):
|
|
|
|
| 94 |
original_lang = detect_language(text)
|
| 95 |
text_en = translate(text, "en")
|
| 96 |
|
|
|
|
| 97 |
category = predict(text_en)
|
|
|
|
|
|
|
| 98 |
if category == "recherche":
|
| 99 |
response = search_duckduckgo(text_en)
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
compound, is_unacceptable, emotion = classify_emotion(text_en)
|
| 103 |
-
|
| 104 |
-
return translate("Je ressens beaucoup de tension dans votre message.", original_lang)
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
gpt_response = generate_response(text_en)
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# === Interface Gradio ===
|
| 110 |
iface = gr.Interface(
|
| 111 |
fn=classify_and_respond,
|
| 112 |
inputs=gr.Textbox(lines=2, placeholder="Écris ton message..."),
|
| 113 |
-
outputs="
|
| 114 |
title="PsyBot",
|
| 115 |
description="Chatbot psychologue multilingue basé sur GPT + BERT + MiniBERT"
|
| 116 |
)
|
|
@@ -120,4 +150,4 @@ iface.launch(
|
|
| 120 |
server_port=7860, # port par défaut HF Spaces
|
| 121 |
share=False, # pas besoin de lien ngrok
|
| 122 |
show_api=True # expose /run/predict
|
| 123 |
-
)
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
|
| 3 |
import torch
|
|
|
|
| 87 |
with torch.no_grad():
|
| 88 |
outputs = model_c(**inputs)
|
| 89 |
logits = outputs.logits
|
| 90 |
+
return "recherche" if torch.argmax(logits, dim=-1).item() == 1 else "gpt"
|
| 91 |
|
| 92 |
+
# === Fonction principale ===
|
| 93 |
def classify_and_respond(text):
|
| 94 |
+
steps = []
|
| 95 |
original_lang = detect_language(text)
|
| 96 |
text_en = translate(text, "en")
|
| 97 |
|
| 98 |
+
# Étape 1 : prédiction catégorie
|
| 99 |
category = predict(text_en)
|
| 100 |
+
steps.append("Catégorie détectée : " + category)
|
| 101 |
+
|
| 102 |
if category == "recherche":
|
| 103 |
response = search_duckduckgo(text_en)
|
| 104 |
+
final_response = "\n".join([translate(r, original_lang) for r in response])
|
| 105 |
+
steps.append("Résultats DuckDuckGo récupérés")
|
| 106 |
+
return {
|
| 107 |
+
"response": final_response,
|
| 108 |
+
"response_type": "recherche",
|
| 109 |
+
"emotions": None,
|
| 110 |
+
"steps": steps
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
# Étape 2 : analyse émotion
|
| 114 |
compound, is_unacceptable, emotion = classify_emotion(text_en)
|
| 115 |
+
steps.append(f"Émotion détectée : {emotion} (score={compound:.2f})")
|
|
|
|
| 116 |
|
| 117 |
+
if is_unacceptable and abs(compound) > 50:
|
| 118 |
+
final_response = translate("Je ressens beaucoup de tension dans votre message.", original_lang)
|
| 119 |
+
steps.append("Réponse émotion inacceptable envoyée")
|
| 120 |
+
return {
|
| 121 |
+
"response": final_response,
|
| 122 |
+
"response_type": "non acceptable",
|
| 123 |
+
"emotions": emotion,
|
| 124 |
+
"steps": steps
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
# Étape 3 : génération GPT
|
| 128 |
gpt_response = generate_response(text_en)
|
| 129 |
+
final_response = translate(gpt_response, original_lang)
|
| 130 |
+
steps.append("Réponse GPT générée et traduite")
|
| 131 |
+
|
| 132 |
+
return {
|
| 133 |
+
"response": final_response,
|
| 134 |
+
"response_type": "gpt",
|
| 135 |
+
"emotions": emotion,
|
| 136 |
+
"steps": steps
|
| 137 |
+
}
|
| 138 |
|
| 139 |
# === Interface Gradio ===
|
| 140 |
iface = gr.Interface(
|
| 141 |
fn=classify_and_respond,
|
| 142 |
inputs=gr.Textbox(lines=2, placeholder="Écris ton message..."),
|
| 143 |
+
outputs="json",
|
| 144 |
title="PsyBot",
|
| 145 |
description="Chatbot psychologue multilingue basé sur GPT + BERT + MiniBERT"
|
| 146 |
)
|
|
|
|
| 150 |
server_port=7860, # port par défaut HF Spaces
|
| 151 |
share=False, # pas besoin de lien ngrok
|
| 152 |
show_api=True # expose /run/predict
|
| 153 |
+
)
|