# app.py — "gpt-psybot" Hugging Face Space API (revision a6d773a)
import os
# =============================
# Fix cache permissions for Hugging Face Spaces: redirect every HF cache
# directory to /tmp/hf (a writable location in the Space container).
# These MUST be set before `transformers` is imported below, otherwise the
# library picks up the default (read-only) cache paths.
# =============================
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"  # legacy alias of HF_HOME, kept for older transformers versions
os.environ["HF_HUB_CACHE"] = "/tmp/hf"
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from mtranslate import translate
from langdetect import detect
from duckduckgo_search import DDGS
import re
# =============================
# Nettoyage texte
# =============================
def clean_response(text):
    """Sanitize a model-generated reply and truncate it to two sentences.

    Strips HTML/XML-like tags, leading punctuation, speaker labels
    (e.g. "Bot:"), parenthesized/bracketed asides and ASCII emoticons,
    then collapses whitespace.

    Args:
        text: raw text produced by the language model.

    Returns:
        A cleaned string containing at most the first two sentences.
    """
    # Drop any tag-like span. This also removes closing tags such as </Bot>,
    # which made the old split on r'</(Bot|name|...)>' unreachable dead code;
    # that line has been removed.
    text = re.sub(r'<[^>]+>', '', text)
    # Strip punctuation left dangling at the start after tag removal.
    text = re.sub(r'^\s*[,.:;-]*', '', text)
    # Remove a leading speaker label like "Psyche:" / "Assistant:".
    text = re.sub(r'^\s*(Psyche|Therapist|Bot|Assistant|AI):?\s*', '', text)
    # Remove parenthesized and bracketed asides, e.g. "(sigh)", "[note]".
    text = re.sub(r'\([^)]*\)', '', text)
    text = re.sub(r'\[.*?\]', '', text)
    # Remove common ASCII emoticons like :) ;-D =(
    text = re.sub(r'[:;=8][-~]?[)D(\\/*|]', '', text)
    # Collapse runs of whitespace into single spaces.
    text = re.sub(r'\s{2,}', ' ', text).strip()
    # Keep only the first two sentences (split on end-of-sentence punctuation).
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return " ".join(sentences[:2]).strip()
# =============================
# Load models
# =============================
# GPT-style conversational generator (fine-tuned model on the HF Hub).
MODEL_PATH = "fatmata/gpt-psybot"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)  # slow tokenizer — presumably no fast tokenizer files in this repo; TODO confirm
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
# BERT sequence classifier used for emotion detection (see classify_emotion).
BERT_MODEL_NAME = "fatmata/bert_model"
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_NAME)
# Small binary classifier that routes queries to web search vs. the chatbot
# (see predict_category).
CLASSIFIER_PATH = "fatmata/mini_bert"
model_c = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_PATH)
tokenizer_c = AutoTokenizer.from_pretrained(CLASSIFIER_PATH)
# =============================
# Emotion analysis
# =============================
analyzer = SentimentIntensityAnalyzer()  # VADER sentiment scorer (English text)
# Label set of the BERT emotion head; the order must match the model's output
# logits. "autre" is presumably a catch-all class — TODO confirm with training code.
GOEMOTIONS_LABELS = ["admiration","anger","approval","autre","curiosity",
"disapproval","gratitude","joy","love","neutral","sadness"]
# Emotions that trigger the deflection response in classify_and_respond.
UNACCEPTABLE_EMOTIONS = {"anger"}
def detect_language(text):
    """Detect the language of *text*, restricted to the supported set.

    Returns:
        "fr", "en" or "ar" as detected by langdetect; falls back to "en"
        for any other language or when detection fails.
    """
    try:
        detected_lang = detect(text)
        return detected_lang if detected_lang in ["fr", "en", "ar"] else "en"
    except Exception:
        # Was a bare `except:`, which would also swallow SystemExit and
        # KeyboardInterrupt. `Exception` still covers langdetect's
        # LangDetectException on empty/undetectable input.
        return "en"
def search_duckduckgo(query, max_results=3):
    """Run a DuckDuckGo text search and return the result snippets.

    Returns:
        A list of snippet strings; ["Pas trouvé."] when nothing matched,
        or a one-element list with an error message if the search failed.
    """
    try:
        hits = DDGS().text(query, max_results=max_results)
        bodies = [hit["body"] for hit in hits if "body" in hit]
        return bodies if bodies else ["Pas trouvé."]
    except Exception as err:
        # Best effort: surface the failure as a result string instead of raising.
        return [f"Erreur recherche : {str(err)}"]
def generate_response(user_input):
    """Generate a chatbot reply for *user_input* with the GPT model.

    Builds a "User: ... / Bot:" prompt, samples a completion, and returns
    the bot's part cleaned and truncated by clean_response.
    """
    prompt = f"User: {user_input}\nBot:"
    encoded = tokenizer(prompt, return_tensors="pt")
    sampling_kwargs = dict(
        max_new_tokens=150,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,            # stochastic sampling rather than greedy decoding
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
    )
    output = model.generate(input_ids=encoded["input_ids"], **sampling_kwargs)
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only what follows the last "Bot:" marker (the model's reply).
    reply = decoded.split("Bot:")[-1].strip()
    return clean_response(reply)
def classify_emotion(text):
    """Score sentiment with VADER and classify the dominant emotion with BERT.

    Returns:
        A tuple (compound, is_unacceptable, top_emotion) where `compound`
        is the VADER compound score scaled by 100, `is_unacceptable` is True
        when the predicted emotion is in UNACCEPTABLE_EMOTIONS, and
        `top_emotion` is the predicted label from GOEMOTIONS_LABELS.
    """
    compound = analyzer.polarity_scores(text)['compound'] * 100
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256)
    with torch.no_grad():
        logits = bert_model(**encoded).logits
        probs = F.softmax(logits, dim=-1).squeeze().cpu().numpy()
    top_emotion = GOEMOTIONS_LABELS[probs.argmax()]
    return compound, top_emotion in UNACCEPTABLE_EMOTIONS, top_emotion
def predict_category(text):
    """Route *text* to web search or chat generation.

    Returns:
        "recherche" when the binary classifier predicts class 1,
        otherwise "gpt".
    """
    encoded = tokenizer_c(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        logits = model_c(**encoded).logits
    predicted = torch.argmax(logits, dim=-1).item()
    return "recherche" if predicted == 1 else "gpt"
# =============================
# Fonction principale
# =============================
def classify_and_respond(text):
    """Full pipeline: route the user's message, then answer it.

    Detects the user's language, translates the message to English, routes
    it to either a DuckDuckGo search or the GPT chatbot (guarded by an
    emotion check), and translates the answer back.

    Returns:
        A dict with keys "response" (final text in the user's language),
        "response_type" ("recherche" | "non acceptable" | "gpt"),
        "emotions" (predicted label, or None for searches) and
        "steps" (a human-readable trace of the pipeline).
    """
    steps = []
    original_lang = detect_language(text)
    text_en = translate(text, "en")

    # Step 1: route the query (web search vs. chat).
    category = predict_category(text_en)
    steps.append("Catégorie détectée : " + category)
    if category == "recherche":
        results = search_duckduckgo(text_en)
        final_response = "\n".join([translate(r, original_lang) for r in results])
        steps.append("Résultats DuckDuckGo récupérés")
        return {
            "response": final_response,
            "response_type": "recherche",
            "emotions": None,
            "steps": steps,
        }

    # Step 2: emotion guard — deflect strongly negative messages.
    compound, is_unacceptable, emotion = classify_emotion(text_en)
    steps.append(f"Émotion détectée : {emotion} (score={compound:.2f})")
    if is_unacceptable and abs(compound) > 50:
        final_response = translate("Je ressens beaucoup de tension dans votre message.", original_lang)
        steps.append("Réponse émotion inacceptable envoyée")
        return {
            "response": final_response,
            "response_type": "non acceptable",
            "emotions": emotion,
            "steps": steps,
        }

    # Step 3: GPT generation, translated back to the user's language.
    reply = generate_response(text_en)
    final_response = translate(reply, original_lang)
    steps.append("Réponse GPT générée et traduite")
    return {
        "response": final_response,
        "response_type": "gpt",
        "emotions": emotion,
        "steps": steps,
    }
# =============================
# API FastAPI
# =============================
app = FastAPI()
class RequestBody(BaseModel):
    """Request body for the /predict endpoint."""
    # text: the user's message, in any language.
    text: str
@app.post("/predict")
async def predict_api(body: RequestBody):
    """POST /predict — run the full pipeline on body.text and return its result dict."""
    return classify_and_respond(body.text)
# =============================
# Local launch (uvicorn)
# =============================
if __name__ == "__main__":
    import uvicorn
    # Port 7860 is the default port exposed by Hugging Face Spaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)