Spaces:
Running
Running
| import re | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from rag_gpt2 import RAGEngineGPT2 | |
# ───────── RAG INITIALISATION ─────────
# Retrieval engine (RAGEngineGPT2, from the local rag_gpt2 module) built on
# the "corpus/" directory. Created once at import time and shared by every
# handler below.
CORPUS_DIR = "corpus/"
rag = RAGEngineGPT2(CORPUS_DIR)
# ───────── GPT-2 MODEL ─────────
# Caveat: GPT-2 is not instruction-tuned, and the original GPT-2 was trained
# mostly on English text. The checkpoint below appears, from its name, to be a
# small French GPT-2 variant; it is used as a comparison point against FLAN-T5,
# not as the best available model.
model_name = "dbddv01/gpt2-french-small"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)

# If the tokenizer defines no padding token, reuse the end-of-sequence token
# so that padding / attention masks can be produced.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def generate_gpt2(prompt, max_new_tokens=120, temperature=0.7, max_prompt_tokens=768):
    """Sample a continuation of ``prompt`` from the module-level GPT-2 model.

    Args:
        prompt: Text to continue. Should be non-empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (combined with top_p=0.9 nucleus
            sampling and the repetition penalties below).
        max_prompt_tokens: Longest prompt fed to the model, in tokens; must be
            positive. Longer prompts lose their *beginning*, so the end of the
            prompt -- the part the continuation has to follow -- is always
            kept. Defaults to 768, the limit previously applied.

    Returns:
        Only the newly generated text (the prompt is never echoed back),
        stripped of surrounding whitespace. May be an empty string.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Left-truncate every returned tensor (input_ids, attention_mask, ...).
    # Right truncation would drop the tail of the prompt, e.g. the final
    # "Réponse :" cue of the RAG rewrite prompt.
    inputs = {name: tensor[:, -max_prompt_tokens:] for name, tensor in encoded.items()}
    prompt_length = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.25,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.eos_token_id,
    )

    # For a causal LM, generate() returns the prompt tokens followed by the new
    # ones, so decode only what comes after the prompt. The previous approach
    # (decode everything, then strip the prompt if the text starts with it)
    # failed whenever decoding does not reproduce the prompt string exactly --
    # truncated prompts, or tokenizer space clean-up around punctuation such as
    # French " ?" depending on clean_up_tokenization_spaces -- and then leaked
    # the whole prompt into the answer.
    generated_ids = outputs[0, prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Signs of a failed GPT-2 generation: echoed prompt scaffolding and links.
# "\s*:" tolerates French typography ("Texte :", "Question :"); the rewrite
# prompt in gpt2_rag_answer uses exactly such "Texte :" / "Réponse :" labels.
_BAD_GENERATION_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"questions?\s*:",
        r"réponse en français",
        r"contexte\s*:",
        r"texte\s*:",
        r"réponse\s*:",
        r"http",
    )
]


def is_bad_generation(text: str) -> bool:
    """Return True when a GPT-2 output should be discarded.

    An output is rejected when it is empty or shorter than 20 characters,
    when it contains prompt scaffolding or a link (see
    ``_BAD_GENERATION_PATTERNS``, matched case-insensitively), or when a text
    of more than 10 words has fewer than 45 % distinct words (a sign of
    repetition loops).
    """
    if not text or len(text.strip()) < 20:
        return True

    if any(pattern.search(text) for pattern in _BAD_GENERATION_PATTERNS):
        return True

    words = text.split()
    # Low lexical diversity means the model is looping on the same words.
    if len(words) > 10 and len(set(words)) / len(words) < 0.45:
        return True

    return False
# ───────── CONTROLLED EXTRACTIVE ANSWER ─────────
def _keywords_for_question(question):
    """Return the keywords that mark relevant sentences for ``question``'s intent.

    Cues are checked in order (risks, uses, benefits, definitions) and the
    first match wins; an empty list means no intent was recognised.
    """
    # Fold typographic apostrophes (’) to ASCII so "c’est quoi" and
    # "qu’est-ce" hit the same cues as "c'est quoi" / "qu'est-ce".
    q = question.lower().replace("\u2019", "'")
    if "risque" in q:
        return [
            "risque", "biais", "erreur", "erreurs",
            "confidentialité", "données", "sécurité",
            "RGPD", "discrimination", "supervision",
        ]
    if "application" in q or "utilisation" in q:
        return [
            "utilisé", "utilisée", "applications",
            "diagnostic", "analyse", "détection",
            "générer", "prédire",
        ]
    if "avantage" in q or "importance" in q:
        return [
            "améliorer", "optimiser", "aider",
            "utile", "rapidité", "permet",
        ]
    if "définition" in q or "c'est quoi" in q or "qu'est-ce" in q:
        return ["désigne", "définition", "est", "sont"]
    return []


def _mentions_keyword(sentence, keywords):
    """Return True if some word of ``sentence`` starts with one of ``keywords``.

    Matching is case-insensitive and anchored at a word start: "utile" no
    longer matches "inutile" and "est" no longer matches "question" or
    "reste", while inflections ("risques", "utilisées") and elided forms
    ("c'est", "l'analyse") still match.
    """
    lowered = sentence.lower()
    return any(
        re.search(r"(?<!\w)" + re.escape(keyword.lower()), lowered)
        for keyword in keywords
    )


def extractive_answer(question, sources_found):
    """Build an answer from up to four sentences of the first retrieved passage.

    Args:
        question: User question; its wording selects the keyword set used to
            pick sentences (see ``_keywords_for_question``).
        sources_found: Results of ``rag.search``; only ``sources_found[0]`` is
            used (its ``.text`` is cleaned with ``rag._clean_for_prompt``).

    Returns:
        The selected sentences joined with ". ", always ending with ".".

    Raises:
        IndexError: if ``sources_found`` is empty (callers check this first).
    """
    best = sources_found[0]
    text = rag._clean_for_prompt(best.text)

    # Naive sentence split on "."; fragments of 35 characters or less
    # (headings, abbreviations, pieces of numbers such as "3.5") are dropped.
    sentences = [s.strip() for s in text.split(".") if len(s.strip()) > 35]
    if not sentences:
        # Nothing long enough to count as a sentence: return the whole cleaned
        # passage instead of a lone ".".
        fallback = text.strip()
        return fallback if fallback.endswith(".") else fallback + "."

    keywords = _keywords_for_question(question)
    selected = (
        [s for s in sentences if _mentions_keyword(s, keywords)] if keywords else []
    )
    # No recognised intent, or no sentence matched: use the opening sentences.
    if not selected:
        selected = sentences[:4]

    answer = ". ".join(selected[:4]).strip()
    if not answer.endswith("."):
        answer += "."
    return answer
# ───────── RAG ANSWER + GPT-2 REWRITE WITH FALLBACK ─────────
def gpt2_rag_answer(question, sources_found, temperature):
    """Answer from the first passage, then let GPT-2 try to rephrase it.

    The extractive answer is always computed first and handed to GPT-2 for a
    rewrite. If ``is_bad_generation`` rejects the rewrite, the extractive
    answer is kept. In both cases the filename of ``sources_found[0]`` is
    appended as the cited source.
    """
    extractive = extractive_answer(question, sources_found)

    prompt = (
        "\n"
        "Réécris ce texte en français clair et naturel sans ajouter d'information.\n"
        "Texte :\n"
        f"{extractive}\n"
        "Réponse :\n"
    )
    candidate = generate_gpt2(prompt, max_new_tokens=100, temperature=temperature)

    # GPT-2 is unreliable: fall back to the extractive text when the rewrite
    # looks broken.
    answer = extractive if is_bad_generation(candidate) else candidate
    return f"{answer}\n\nSource utilisée : {sources_found[0].filename}"
# ───────── SENTIMENT ─────────
# Multilingual sentiment classifier, loaded once at import time and used by
# analyse_sentiment.
SENTIMENT_MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
sentiment = pipeline("sentiment-analysis", model=SENTIMENT_MODEL_NAME)
# ───────── RAG CHATBOT (GPT-2) ─────────
def chatbot_response(message, history, temperature, top_k, use_gpt2_rewrite):
    """Handle one chat turn: retrieve passages, answer, update the history.

    Args:
        message: User question.
        history: Chat history as a list of {"role", "content"} dicts; None is
            treated as empty. The list is extended in place and returned.
        temperature: GPT-2 sampling temperature (only used when rewriting).
        top_k: Number of passages to retrieve and list in the sources panel
            (a Gradio slider value, possibly a float). The answer itself is
            always built from the first passage returned.
        use_gpt2_rewrite: Whether GPT-2 rephrases the extractive answer.

    Returns:
        ``(history, "", sources_markdown)``; the empty string clears the
        question box.
    """
    if history is None:
        history = []
    if not message or not message.strip():
        return history, "", "Veuillez poser une question."

    # top_k used to be ignored (hard-coded to 1), which made the
    # "Nombre de passages utilisés" slider a no-op.
    sources_found = rag.search(message, top_k=max(1, int(top_k)))

    if not sources_found:
        answer = (
            "Je ne dispose pas d’assez d’informations dans le corpus "
            "pour répondre correctement à cette question."
        )
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": answer})
        return history, "", "Aucune source suffisamment pertinente trouvée."

    # For a stable answer only the first (presumably best-ranked) passage is
    # used; any further passages are shown as supporting context below.
    if use_gpt2_rewrite:
        answer = gpt2_rag_answer(message, sources_found, temperature)
    else:
        answer = extractive_answer(message, sources_found)
        answer += f"\n\nSource utilisée : {sources_found[0].filename}"

    blocks = []
    for src in sources_found:
        # Only mark an excerpt with "..." when it was actually cut.
        excerpt = src.text if len(src.text) <= 450 else src.text[:450] + "..."
        blocks.append(f"### {src.filename} — score : {src.score:.3f}\n{excerpt}\n\n")
    source_text = "".join(blocks)

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": answer})
    return history, "", source_text
# ───────── CORPUS EXPLORATION ─────────
def search_corpus(query, top_k):
    """Search the corpus and render the hits as Markdown.

    Args:
        query: Free-text search query.
        top_k: Number of results requested (a Gradio slider value, possibly a
            float; cast to int, minimum 1).

    Returns:
        One "## filename — score" section per hit with up to 900 characters
        of the passage, or a message when the query is blank or nothing is
        found.
    """
    if not query or not query.strip():
        return "Entrez une requête de recherche."

    results = rag.search(query, top_k=max(1, int(top_k)))
    if not results:
        return "Aucun passage pertinent trouvé dans le corpus."

    sections = []
    for result in results:
        # Only mark an excerpt with "..." when it was actually cut.
        excerpt = result.text if len(result.text) <= 900 else result.text[:900] + "..."
        sections.append(f"## {result.filename} — score : {result.score:.3f}\n{excerpt}\n\n")
    return "".join(sections)
# ───────── SENTIMENT ─────────
def analyse_sentiment(text):
    """Classify ``text`` with the module-level ``sentiment`` pipeline.

    Returns a short Markdown report with the predicted label and its
    confidence score, or a prompt to enter text when ``text`` is empty or
    blank.
    """
    if not text or not text.strip():
        return "Entrez un texte à analyser."
    # truncation=True: inputs longer than the model's maximum sequence length
    # are cut instead of making the pipeline raise on long pasted texts.
    result = sentiment(text, truncation=True)[0]
    return f"""
Résultat : **{result['label']}**
Score de confiance : **{result['score']:.2f}**
"""
# ───────── FREE GPT-2 GENERATION ─────────
def free_generation(prompt, temperature):
    """Continue ``prompt`` with GPT-2 alone (no retrieval), for comparison with RAG.

    Returns the generated text, a fallback message when GPT-2 produced
    nothing, or a prompt request when ``prompt`` is empty or blank.
    """
    if prompt and prompt.strip():
        continuation = generate_gpt2(prompt, max_new_tokens=180, temperature=temperature)
        return continuation or "GPT-2 n'a pas généré de réponse exploitable."
    return "Entrez un prompt."
# ───────── GRADIO INTERFACE ─────────
# Four-tab UI: RAG chatbot, corpus explorer, sentiment analysis, free GPT-2
# generation. Gradio Blocks lays components out in creation order inside each
# `with` context, so the statement order below defines the page layout.
with gr.Blocks(title="Assistant RAG GPT-2 — IA Générative") as app:
    gr.Markdown("# Assistant RAG GPT-2 — IA Générative")
    gr.Markdown(
        """
Cette version teste GPT-2 avec le même corpus documentaire.
**Critique importante :** GPT-2 est moins fiable que FLAN-T5 pour suivre des consignes.
Le mode RAG utilise donc une réponse extractive contrôlée, avec une option de reformulation par GPT-2.
"""
    )

    # Tab 1: RAG chatbot. chatbot_response returns (history, "", sources), so
    # each send updates the conversation, clears the question box and
    # refreshes the sources panel.
    with gr.Tab("Chatbot RAG GPT-2"):
        # NOTE(review): chatbot_response stores history as {"role", "content"}
        # dicts; depending on the installed Gradio version, gr.Chatbot may need
        # type="messages" to accept that format -- verify against the version
        # this Space pins.
        chatbot = gr.Chatbot(label="Conversation")
        question = gr.Textbox(
            label="Votre question",
            placeholder="Exemple : Quels sont les risques de l'IA dans la médecine ?",
            lines=2
        )
        with gr.Row():
            # Sampling temperature; only used when the GPT-2 rewrite is enabled.
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.5,
                step=0.1,
                label="Température GPT-2"
            )
            # Passed to chatbot_response as its top_k argument.
            top_k = gr.Slider(
                minimum=1,
                maximum=5,
                value=1,
                step=1,
                label="Nombre de passages utilisés"
            )
        # Off by default: the extractive answer is returned as-is.
        use_gpt2_rewrite = gr.Checkbox(
            label="Utiliser GPT-2 pour reformuler la réponse",
            value=False
        )
        send_btn = gr.Button("Envoyer", variant="primary")
        reset_btn = gr.Button("Réinitialiser la conversation")
        sources = gr.Markdown(label="Sources utilisées")
        send_btn.click(
            chatbot_response,
            inputs=[question, chatbot, temperature, top_k, use_gpt2_rewrite],
            outputs=[chatbot, question, sources]
        )
        # Reset empties the conversation, the question box and the sources panel.
        reset_btn.click(
            lambda: ([], "", ""),
            outputs=[chatbot, question, sources]
        )

    # Tab 2: direct retrieval over the indexed corpus.
    with gr.Tab("Explorer le corpus"):
        # Counts are computed once, when the UI is built (f-string evaluated here).
        gr.Markdown(
            f"""
**Nombre de fichiers chargés :** {len(set(rag.chunk_sources))}
**Nombre de passages indexés :** {len(rag.chunks)}
"""
        )
        search_query = gr.Textbox(
            label="Recherche dans la base de connaissances",
            placeholder="Exemple : IA et finance"
        )
        search_top_k = gr.Slider(
            minimum=1,
            maximum=5,
            value=3,
            step=1,
            label="Nombre de résultats"
        )
        search_btn = gr.Button("Rechercher")
        search_output = gr.Markdown()
        search_btn.click(
            search_corpus,
            inputs=[search_query, search_top_k],
            outputs=search_output
        )

    # Tab 3: sentiment classification of arbitrary text.
    with gr.Tab("Analyse de sentiment"):
        sent_text = gr.Textbox(
            label="Texte à analyser",
            placeholder="Exemple : Ce projet est très intéressant et utile.",
            lines=4
        )
        sent_btn = gr.Button("Analyser")
        sent_output = gr.Markdown()
        sent_btn.click(
            analyse_sentiment,
            inputs=sent_text,
            outputs=sent_output
        )

    # Tab 4: GPT-2 without retrieval, to contrast with the grounded answers.
    with gr.Tab("Génération libre GPT-2"):
        gr.Markdown(
            """
Cet onglet montre GPT-2 sans RAG.
Il permet de comparer une génération libre avec une réponse appuyée sur le corpus.
"""
        )
        free_prompt = gr.Textbox(
            label="Prompt libre",
            placeholder="Exemple : Artificial intelligence in medicine is",
            lines=4
        )
        free_temp = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Température"
        )
        free_btn = gr.Button("Générer")
        free_output = gr.Textbox(label="Réponse GPT-2", lines=8)
        free_btn.click(
            free_generation,
            inputs=[free_prompt, free_temp],
            outputs=free_output
        )

# Start the Gradio server at import time (script entry point of the Space).
app.launch()