fatmata committed on
Commit
6400ad7
·
verified ·
1 Parent(s): bb941ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
7
+ from mtranslate import translate
8
+ from langdetect import detect
9
+ from duckduckgo_search import DDGS
10
+ import re
11
+
12
+ # =============================
13
+ # Nettoyage texte
14
+ # =============================
15
def clean_response(text):
    """Clean a raw model generation down to at most two sentences.

    Steps: truncate at the first closing tag (</Bot>, </User>, ...),
    strip remaining HTML-like tags, drop leading punctuation and
    speaker labels, remove parenthesised/bracketed asides and ASCII
    emoticons, collapse whitespace, then keep the first two sentences.
    """
    # Truncate at the first closing tag BEFORE stripping tags: the
    # original ran the tag-stripping sub first, which removed every
    # closing tag and made this split a guaranteed no-op, so leaked
    # continuation text after </Bot> survived into the reply.
    text = re.split(r'</(Bot|name|opinion|User|[a-zA-Z]*)>', text)[0]
    # Remove any remaining HTML-like tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Drop punctuation left over at the start of the reply.
    text = re.sub(r'^\s*[,.:;-]*', '', text)
    # Drop a leading speaker label such as "Bot:" or "Assistant:".
    text = re.sub(r'^\s*(Psyche|Therapist|Bot|Assistant|AI):?\s*', '', text)
    # Remove parenthesised and bracketed asides.
    text = re.sub(r'\([^)]*\)', '', text)
    text = re.sub(r'\[.*?\]', '', text)
    # Remove simple ASCII emoticons like :) ;D =( .
    text = re.sub(r'[:;=8][-~]?[)D(\\/*|]', '', text)
    # Collapse runs of whitespace.
    text = re.sub(r'\s{2,}', ' ', text).strip()
    # Keep at most the first two sentences.
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return " ".join(sentences[:2]).strip()
26
+
27
# =============================
# Load models
# =============================
# Causal LM fine-tuned for the psychology chatbot (Hugging Face hub id).
MODEL_PATH = "fatmata/gpt-psybot"
# use_fast=False: force the slow (Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

# BERT sequence classifier used for emotion detection.
BERT_MODEL_NAME = "fatmata/bert_model"
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_NAME)

# Small BERT classifier routing a message to web search vs. generation.
CLASSIFIER_PATH = "fatmata/mini_bert"
model_c = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_PATH)
tokenizer_c = AutoTokenizer.from_pretrained(CLASSIFIER_PATH)
41
+
42
# =============================
# Emotion analysis
# =============================
# Rule-based VADER analyzer; its compound score lies in [-1, 1].
analyzer = SentimentIntensityAnalyzer()
# Label order must match the output head of the fine-tuned BERT model
# (index i of the softmax corresponds to GOEMOTIONS_LABELS[i]).
GOEMOTIONS_LABELS = ["admiration","anger","approval","autre","curiosity",
                     "disapproval","gratitude","joy","love","neutral","sadness"]
# Emotions that trigger the "non acceptable" short-circuit response.
UNACCEPTABLE_EMOTIONS = {"anger"}
49
+
50
def detect_language(text):
    """Detect the language of *text*, restricted to fr/en/ar.

    Returns "en" both for any other detected language and when
    detection fails (e.g. empty or too-short input).
    """
    try:
        detected_lang = detect(text)
    # Narrowed from a bare `except:` which also swallowed
    # SystemExit/KeyboardInterrupt.
    except Exception:
        return "en"
    return detected_lang if detected_lang in ["fr", "en", "ar"] else "en"
56
+
57
def search_duckduckgo(query, max_results=3):
    """Run a DuckDuckGo text search and return a list of snippet strings.

    Falls back to ["Pas trouvé."] when no usable result comes back, and
    to a single error string if the search itself raises (best-effort).
    """
    try:
        hits = list(DDGS().text(query, max_results=max_results))
        snippets = [hit["body"] for hit in hits if "body" in hit]
        return snippets if snippets else ["Pas trouvé."]
    except Exception as exc:
        return [f"Erreur recherche : {str(exc)}"]
63
+
64
def generate_response(user_input):
    """Generate a chatbot reply with the fine-tuned GPT model.

    Wraps the user text in a "User: ...\\nBot:" prompt, samples up to
    150 new tokens, strips the prompt echo, and cleans the result.
    """
    encoded = tokenizer(f"User: {user_input}\nBot:", return_tensors="pt")
    # Sampling configuration: nucleus + top-k sampling with a mild
    # repetition penalty; EOS doubles as the padding token.
    sampling_kwargs = dict(
        max_new_tokens=150,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
    )
    generated = model.generate(input_ids=encoded["input_ids"], **sampling_kwargs)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # Everything after the last "Bot:" is the model's reply.
    reply = decoded.split("Bot:")[-1].strip()
    return clean_response(reply)
80
+
81
def classify_emotion(text):
    """Score *text* with VADER and the fine-tuned BERT emotion model.

    Returns (compound, is_unacceptable, top_emotion): the VADER
    compound score scaled to [-100, 100], a flag set when the dominant
    emotion is in UNACCEPTABLE_EMOTIONS, and the emotion label itself.
    """
    compound_pct = analyzer.polarity_scores(text)['compound'] * 100
    encoded = bert_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=256
    )
    with torch.no_grad():
        scores = bert_model(**encoded).logits
    distribution = F.softmax(scores, dim=-1).squeeze().cpu().numpy()
    label = GOEMOTIONS_LABELS[distribution.argmax()]
    return compound_pct, label in UNACCEPTABLE_EMOTIONS, label
91
+
92
def predict_category(text):
    """Route *text* with the mini-BERT binary classifier.

    Class 1 means the message needs a web search ("recherche");
    anything else is handled by the chatbot ("gpt").
    """
    encoded = tokenizer_c(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        logits = model_c(**encoded).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return "recherche" if predicted_class == 1 else "gpt"
98
+
99
# =============================
# Main pipeline
# =============================
def classify_and_respond(text):
    """Full pipeline: translate, route, and answer a user message.

    1. Detect the user's language and translate the message to English.
    2. Route it: web search ("recherche") vs. chatbot handling.
    3. For chatbot messages, short-circuit when a strong unacceptable
       emotion is detected; otherwise generate a GPT reply.

    Every branch returns a dict with keys: response, response_type,
    emotions, and steps (a human-readable trace of the processing).
    """
    trace = []
    user_lang = detect_language(text)
    english_text = translate(text, "en")

    # Step 1: category prediction.
    route = predict_category(english_text)
    trace.append("Catégorie détectée : " + route)

    if route == "recherche":
        snippets = search_duckduckgo(english_text)
        answer = "\n".join(translate(snippet, user_lang) for snippet in snippets)
        trace.append("Résultats DuckDuckGo récupérés")
        return {
            "response": answer,
            "response_type": "recherche",
            "emotions": None,
            "steps": trace,
        }

    # Step 2: emotion gate — strong unacceptable emotions get a fixed reply.
    compound, is_unacceptable, emotion = classify_emotion(english_text)
    trace.append(f"Émotion détectée : {emotion} (score={compound:.2f})")

    if is_unacceptable and abs(compound) > 50:
        answer = translate("Je ressens beaucoup de tension dans votre message.", user_lang)
        trace.append("Réponse émotion inacceptable envoyée")
        return {
            "response": answer,
            "response_type": "non acceptable",
            "emotions": emotion,
            "steps": trace,
        }

    # Step 3: GPT generation, translated back to the user's language.
    answer = translate(generate_response(english_text), user_lang)
    trace.append("Réponse GPT générée et traduite")
    return {
        "response": answer,
        "response_type": "gpt",
        "emotions": emotion,
        "steps": trace,
    }
147
+
148
# =============================
# FastAPI app
# =============================
app = FastAPI()

class RequestBody(BaseModel):
    # Raw user message; language is auto-detected downstream.
    text: str

@app.post("/predict")
async def predict_api(body: RequestBody):
    """POST /predict — run the full pipeline on the request text.

    Returns the dict produced by classify_and_respond:
    {response, response_type, emotions, steps}.
    """
    return classify_and_respond(body.text)