|
|
import os

# Redirect every Hugging Face cache location to a writable /tmp directory.
# This must run BEFORE transformers is imported, since the library reads
# these variables at import time (needed on read-only hosts such as HF Spaces).
_HF_CACHE_DIR = "/tmp/hf"
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE"):
    os.environ[_cache_var] = _HF_CACHE_DIR
|
|
|
|
|
|
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer |
|
|
from mtranslate import translate |
|
|
from langdetect import detect |
|
|
from duckduckgo_search import DDGS |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_response(text):
    """Sanitise a raw model completion and keep at most its first two sentences.

    Strips HTML-ish tags, truncates at a stray closing tag, removes leading
    punctuation and speaker labels, drops parenthesised/bracketed asides and
    text emoticons, then collapses whitespace.
    """
    # Tag removal must happen first; the closing-tag truncation below catches
    # anything the first pass left behind.
    text = re.sub(r'<[^>]+>', '', text)
    text = re.split(r'</(Bot|name|opinion|User|[a-zA-Z]*)>', text)[0]

    # Ordered cleanup passes; each pattern assumes the previous ones ran.
    cleanup_passes = (
        r'^\s*[,.:;-]*',                           # leading punctuation
        r'^\s*(Psyche|Therapist|Bot|Assistant|AI):?\s*',  # speaker label
        r'\([^)]*\)',                              # parenthesised asides
        r'\[.*?\]',                                # bracketed asides
        r'[:;=8][-~]?[)D(\\/*|]',                  # text emoticons
        r'\s{2,}',                                 # runs of whitespace (below)
    )
    for pattern in cleanup_passes[:-1]:
        text = re.sub(pattern, '', text)
    text = re.sub(cleanup_passes[-1], ' ', text).strip()

    # Sentence boundaries: punctuation followed by whitespace.
    first_two = re.split(r'(?<=[.!?])\s+', text)[:2]
    return " ".join(first_two).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model loading (runs at import time; downloads to the /tmp/hf cache) ---

# Conversational causal LM used to generate the chatbot replies.
# Slow tokenizer requested explicitly (use_fast=False).
MODEL_PATH = "fatmata/gpt-psybot"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

# Sequence classifier used by classify_emotion(); its output indices are
# mapped through GOEMOTIONS_LABELS below.
BERT_MODEL_NAME = "fatmata/bert_model"
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL_NAME)

# Binary router used by predict_category(): label 1 -> "recherche" (web
# search), anything else -> "gpt" (generated reply).
CLASSIFIER_PATH = "fatmata/mini_bert"
model_c = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_PATH)
tokenizer_c = AutoTokenizer.from_pretrained(CLASSIFIER_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# VADER sentiment analyzer; polarity_scores()["compound"] is in [-1, 1].
analyzer = SentimentIntensityAnalyzer()

# Labels the emotion classifier's output indices map to. Order matters:
# it must match the fine-tuned head of fatmata/bert_model ("autre" = other).
# NOTE(review): 11 labels here — confirm this equals the model's num_labels.
GOEMOTIONS_LABELS = ["admiration","anger","approval","autre","curiosity",
                     "disapproval","gratitude","joy","love","neutral","sadness"]

# Emotions that trigger the fixed de-escalation response instead of a reply.
UNACCEPTABLE_EMOTIONS = {"anger"}
|
|
|
|
|
def detect_language(text):
    """Detect the language of *text*, constrained to the supported set.

    Returns "fr", "en" or "ar"; any other detected language, or a detection
    failure (empty/ambiguous input), falls back to "en".
    """
    try:
        detected_lang = detect(text)
    # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt.
    # langdetect raises LangDetectException on un-detectable input.
    except Exception:
        return "en"
    return detected_lang if detected_lang in ["fr", "en", "ar"] else "en"
|
|
|
|
|
def search_duckduckgo(query, max_results=3):
    """Run a DuckDuckGo text search and return a list of result snippets.

    Returns ["Pas trouvé."] when no usable result comes back, or a
    single-element list with an error message if the search itself fails.
    """
    try:
        hits = DDGS().text(query, max_results=max_results)
        snippets = [hit["body"] for hit in hits if "body" in hit]
        return snippets if snippets else ["Pas trouvé."]
    except Exception as e:
        # Best-effort: surface the failure as the (translated-later) response.
        return [f"Erreur recherche : {str(e)}"]
|
|
|
|
|
def generate_response(user_input):
    """Generate a chatbot reply for *user_input* with the causal LM.

    The prompt follows the "User: ...\\nBot:" format the model was tuned on;
    the completion is cut after the final "Bot:" marker and cleaned up.
    """
    prompt = f"User: {user_input}\nBot:"
    encoded = tokenizer(prompt, return_tensors="pt")

    # Sampling configuration (stochastic decoding — replies are not
    # deterministic run to run).
    sampling_config = dict(
        max_new_tokens=150,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,  # discourage verbatim repetition
    )
    output = model.generate(input_ids=encoded["input_ids"], **sampling_config)

    raw_reply = tokenizer.decode(output[0], skip_special_tokens=True)
    return clean_response(raw_reply.split("Bot:")[-1].strip())
|
|
|
|
|
def classify_emotion(text):
    """Score *text* with VADER and the BERT emotion classifier.

    Returns a tuple (compound_percent, is_unacceptable, top_emotion) where
    compound_percent is VADER's compound score scaled to [-100, 100] and
    is_unacceptable is True when the top emotion is in UNACCEPTABLE_EMOTIONS.
    """
    compound = analyzer.polarity_scores(text)['compound'] * 100

    encoded = bert_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=256
    )
    with torch.no_grad():
        logits = bert_model(**encoded).logits
        probabilities = F.softmax(logits, dim=-1).squeeze().cpu().numpy()

    top_emotion = GOEMOTIONS_LABELS[probabilities.argmax()]
    return compound, top_emotion in UNACCEPTABLE_EMOTIONS, top_emotion
|
|
|
|
|
def predict_category(text):
    """Route *text* to "recherche" (web search) or "gpt" (generated reply).

    Binary classifier: label index 1 means a search-style query.
    """
    encoded = tokenizer_c(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        predicted_index = torch.argmax(model_c(**encoded).logits, dim=-1).item()
    if predicted_index == 1:
        return "recherche"
    return "gpt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_and_respond(text):
    """Full pipeline: translate to English, route, and answer in the user's language.

    Returns a dict with keys "response", "response_type" ("recherche",
    "non acceptable" or "gpt"), "emotions" and "steps" (a decision trace).
    """
    steps = []
    original_lang = detect_language(text)
    # All downstream models are English-only, so work on a translation.
    text_en = translate(text, "en")

    category = predict_category(text_en)
    steps.append("Catégorie détectée : " + category)

    # Search branch: fetch snippets and translate them back.
    if category == "recherche":
        snippets = search_duckduckgo(text_en)
        final_response = "\n".join(translate(snippet, original_lang) for snippet in snippets)
        steps.append("Résultats DuckDuckGo récupérés")
        return {
            "response": final_response,
            "response_type": "recherche",
            "emotions": None,
            "steps": steps,
        }

    compound, is_unacceptable, emotion = classify_emotion(text_en)
    steps.append(f"Émotion détectée : {emotion} (score={compound:.2f})")

    # Flagged emotion with strong intensity -> fixed de-escalation message.
    if is_unacceptable and abs(compound) > 50:
        final_response = translate("Je ressens beaucoup de tension dans votre message.", original_lang)
        steps.append("Réponse émotion inacceptable envoyée")
        return {
            "response": final_response,
            "response_type": "non acceptable",
            "emotions": emotion,
            "steps": steps,
        }

    # Default branch: generate a reply and translate it back.
    final_response = translate(generate_response(text_en), original_lang)
    steps.append("Réponse GPT générée et traduite")
    return {
        "response": final_response,
        "response_type": "gpt",
        "emotions": emotion,
        "steps": steps,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ASGI application exposing the pipeline as a REST service.
app = FastAPI()
|
|
class RequestBody(BaseModel):
    """JSON body for POST /predict."""

    # Raw user message, in any language (non fr/en/ar falls back to English).
    text: str
|
|
|
|
|
@app.post("/predict")
async def predict_api(body: RequestBody):
    """Run the full classify/search/generate pipeline on the posted message.

    NOTE(review): the pipeline is synchronous (model inference, network
    calls), so it blocks the event loop while running — confirm acceptable.
    """
    result = classify_and_respond(body.text)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Local development entry point; a production ASGI server imports `app` directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|