Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import pandas as pd | |
| import os | |
| import time | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE | |
| import torch | |
| import spaces | |
# Base URL of the backend API (a Hugging Face Space), read from the environment.
API_URL = os.getenv('API_URL')
if not API_URL:
    # Fail fast with an explicit message: without this guard the string
    # concatenation below raises an opaque "NoneType + str" TypeError.
    raise RuntimeError(
        "The API_URL environment variable must be set to the backend base URL."
    )
# Drop trailing slashes so that f"{API_URL}/endpoint" never contains "//".
API_URL = API_URL.rstrip("/")
API_URL_ASK = API_URL + "/get_answer"
# Hub API endpoint queried by check_or_restart_space() for the runtime stage.
SPACE_URL = "https://huggingface.co/api/spaces/Loren/api_search_articles"

# Generative model, loaded once at import time in half precision (float16).
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)
def get_tags():
    """Fetch the available tags from the backend API.

    Returns:
        The list found under the "tags" key of the JSON response, sorted
        case-insensitively.

    Raises:
        gr.Error: if the request cannot be completed, the HTTP status is not
            200, or the response's "status" field is not "ok".
    """
    try:
        # A timeout keeps the caller from hanging forever on a stalled backend.
        resp = requests.get(f"{API_URL}/get_tags", timeout=30)
    except requests.RequestException as exc:
        # Report network failures the same way as the other error paths.
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")
    dict_resp = resp.json()
    if dict_resp.get("status") != "ok":
        # .get() so that a malformed error payload cannot itself raise KeyError.
        raise gr.Error(
            f"❌ Erreur : {dict_resp.get('code')} - {dict_resp.get('message')}"
        )
    return sorted(dict_resp["tags"], key=str.lower)
def get_articles_with_tags(tags):
    """Return an HTML table of the articles matching the given tags.

    Args:
        tags: the selected tags, sent as the "tags" query parameter to the
            backend's /get_articles_with_tags endpoint.

    Returns:
        An HTML string: a <style> block followed by a scrollable <div> holding
        one row per article (title + "Ouvrir" link). When the backend returns
        no article, a short message paragraph is returned instead.

    Raises:
        gr.Error: if no tag is given, the request cannot be completed, the
            HTTP status is not 200, or the response's "status" is not "ok".
    """
    # Local import: `html` is only needed here to escape API-supplied text.
    from html import escape

    if not tags:
        raise gr.Error("❌ Erreur : aucun tag fourni.")
    try:
        resp = requests.get(
            f"{API_URL}/get_articles_with_tags",
            params={"tags": tags},
            timeout=60,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")
    dict_resp = resp.json()
    if dict_resp.get("status") != "ok":
        raise gr.Error(
            f"❌ Erreur : {dict_resp.get('code')} - {dict_resp.get('message')}"
        )

    articles = dict_resp["articles"]
    if not articles:
        # An empty list would yield a DataFrame without the expected columns.
        return "<p>No articles found for the selected tags.</p>"

    df = pd.DataFrame(articles)
    # The table is rendered with escape=False (the link column is real HTML),
    # so every API-supplied value must be escaped by hand.
    df["article_title"] = df["article_title"].map(lambda title: escape(str(title)))
    df["article_url"] = df["article_url"].map(
        lambda url: f'<a href="{escape(str(url), quote=True)}" target="_blank">Ouvrir</a>'
    )
    # Lift pandas' display.max_colwidth limit while rendering so that long
    # cells (the <a> tags in particular) are not cut short.
    with pd.option_context("display.max_colwidth", None):
        table_html = df[["article_title", "article_url"]].to_html(
            escape=False,
            index=False,
            header=False,  # no header row
            border=0,  # no border
        )

    # Minimal styling + scrollable container.
    return f"""
<style>
table {{
    border-collapse: collapse;
    width: 100%;
    font-family: sans-serif;
    font-size: 14px;
}}
td {{
    padding: 6px 0;
    border: none;
}}
tr:hover {{
    background-color: #f9fafb;
}}
a {{
    color: #2563eb;
    text-decoration: none;
}}
a:hover {{
    text-decoration: underline;
    color: #1d4ed8;
}}
</style>
<div style="max-height: 300px; overflow-y: auto;">
{table_html}
</div>
"""
def _fetch_query_results(query, use_rerank):
    """Query the backend /get_query_results endpoint and return its "results" list.

    Raises:
        gr.Error: if the query is empty, the request cannot be completed, the
            HTTP status is not 200, or the response's "status" is not "ok".
    """
    if not query:
        raise gr.Error("❌ Erreur : aucun query fourni.")
    try:
        resp = requests.get(
            f"{API_URL}/get_query_results",
            params={"query": query, "use_rerank": use_rerank},
            timeout=120,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")
    dict_resp = resp.json()
    if dict_resp.get("status") != "ok":
        raise gr.Error(
            f"❌ Erreur : {dict_resp.get('code')} - {dict_resp.get('message')}"
        )
    return dict_resp["results"]


def _render_results_table(results):
    """Render the query results as a styled, scrollable HTML table."""
    # Local import: `html` is only needed here to escape API-supplied text.
    from html import escape

    columns = ['chunk_text', 'distance', 'article_title', 'tags', 'article_url']
    # Passing `columns` keeps the frame well-formed (header only) when
    # `results` is empty, instead of failing on a missing column.
    df = pd.DataFrame(results, columns=columns)
    # The table is rendered with escape=False (the link column is real HTML),
    # so the textual columns coming from the API are escaped by hand.
    for text_col in ('chunk_text', 'article_title', 'tags'):
        df[text_col] = df[text_col].map(lambda value: escape(str(value)))
    df["article_url"] = df["article_url"].map(
        lambda url: f'<a href="{escape(str(url), quote=True)}" target="_blank">Ouvrir</a>'
    )
    # Lift pandas' display.max_colwidth limit while rendering so that long
    # cells (text chunks, <a> tags) are not cut short.
    with pd.option_context("display.max_colwidth", None):
        table_html = df.to_html(
            escape=False,
            index=False,
            header=True,
            border=0,  # no border
        )

    # Minimal styling (colours chosen for a dark theme) + scrollable container.
    return f"""
<style>
table {{
    border-collapse: collapse;
    width: 100%;
    font-family: sans-serif;
    font-size: 14px;
    white-space: nowrap;
}}
th {{
    text-align: left;
    padding: 6px 0;
    border-bottom: 1px solid #555; /* suits a dark background */
    color: #eee; /* light text for dark mode */
}}
td {{
    padding: 6px 0;
    border: none;
    color: #eee; /* light text */
}}
tr:hover {{
    background-color: rgba(255, 255, 255, 0.1); /* subtle hover for dark mode */
}}
a {{
    color: #4ea3ff; /* link colour visible in dark mode */
    text-decoration: none;
}}
a:hover {{
    text-decoration: underline;
    color: #1d4ed8;
}}
</style>
<div style="max-height: 300px; overflow-y: auto; overflow-x: auto;">
{table_html}
</div>
"""


def get_articles_with_query(query, use_rerank):
    """Return an HTML table of the passages the backend finds for a query.

    Args:
        query: the user's search text.
        use_rerank: forwarded as the "use_rerank" query parameter.

    Returns:
        An HTML string with the columns chunk_text, distance, article_title,
        tags and a link built from article_url.

    Raises:
        gr.Error: on an empty query or any backend failure.
    """
    return _render_results_table(_fetch_query_results(query, use_rerank))


def get_answer_with_query(query, use_rerank, history):
    """Answer a query with the generative model, using retrieved passages as context.

    Args:
        query: the user's question.
        use_rerank: forwarded as the "use_rerank" query parameter.
        history: the chat history as a list of (question, answer) pairs; may
            be None or empty.

    Returns:
        A 3-tuple: an empty string (clears the input box), the history with
        the new (query, answer) pair appended, and the HTML table of the
        retrieved passages.

    Raises:
        gr.Error: on an empty query or any backend failure.
    """
    results = _fetch_query_results(query, use_rerank)
    list_chunks = [item['chunk_text'] for item in results]
    if not list_chunks:
        # Nothing retrieved: answer with a fixed message instead of prompting
        # the model without context.
        answer = ("Je ne dispose pas d’informations sur ce sujet. "
                  "Je peux uniquement répondre à des questions sur les articles "
                  "du jeu de données.")
    else:
        # Build the prompt from the retrieved chunks.
        prompt = RAG_PROMPT_TEMPLATE.format(
            context="\n".join(list_chunks),
            question=query,
        )
        # Generate the answer.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=500)
        # Keep only the tokens produced after the prompt.
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Work on a copy so the caller's list is not mutated, and tolerate None.
    history = list(history or [])
    history.append((query, answer))
    return "", history, _render_results_table(results)
# -------- SPACE STATUS CHECK --------
def check_or_restart_space():
    """Return a Markdown heading describing the backend Space's runtime stage.

    Sends a GET request to SPACE_URL and reads `runtime.stage` from the JSON
    reply. Despite its name, this function only reads the status; it sends no
    restart request.

    Returns:
        "### <stage>" on success, otherwise a "### Space in error : ..."
        heading. Failures are reported as text and never raised, so the
        status button always displays something.
    """
    try:
        r = requests.get(SPACE_URL, timeout=30)
        data = r.json()
        # `or {}` also covers an explicit null "runtime" value.
        stage = (data.get("runtime") or {}).get("stage")
    except Exception as e:  # UI boundary: show the failure instead of raising
        return "### Space in error : "+str(e)
    if stage is None:
        # Without this guard, "### " + None raised a TypeError whose message
        # was then shown to the user as the error text.
        return "### Space in error : "+str(data.get("error", "runtime stage unavailable"))
    return "### "+str(stage)
## Custom CSS
# Tab labels: orange, bold, larger font.
custom_css = """
.tab-nav button, .tab-nav button span, .tab-nav button p {
    color: orange !important;
    font-weight: bold !important;
    font-size: 1.3rem !important;
}
"""
# Compact sizing for the element with elem_id="mini_btn" (the status button).
css = """
#mini_btn {
    width: fit-content !important;
    display: inline-block !important;
    margin: 0 !important;
}
"""

# Load the tag list once at start-up. The backend may be unavailable at that
# moment; fall back to an empty list so the UI (including the status button)
# still starts instead of crashing at import time.
try:
    list_tags = get_tags()
except Exception as exc:  # start-up boundary; also covers gr.Error from get_tags
    print(f"Could not load tags at start-up: {exc}")
    list_tags = []


def _refresh_tag_choices():
    """Reload the dropdown choices from the backend; leave them unchanged on failure."""
    try:
        return gr.update(choices=get_tags())
    except Exception:
        return gr.update()


# --- Gradio interface ---
# Both stylesheets are passed: previously `css` was defined but never applied.
with gr.Blocks(title="Search articles", css=custom_css + css) as demo:
    with gr.Row():
        with gr.Column():
            gr.Image(value="dataset-cover.jpg", type="filepath")
        with gr.Column():
            gr.HTML("<h1 style='text-align: center; font-weight: bold;'>190k+ Medium Articles</h1>")
            # Link target written without the stray braces that broke the URL.
            gr.Markdown("## 🔎 Search in this dataset hosted in Kaggle ([Dataset details](https://www.kaggle.com/datasets/fabiochiusano/medium-articles))")
    with gr.Row():
        status_btn = gr.Button("Check / restart API", elem_id="mini_btn")
        status_output = gr.Markdown()
    status_btn.click(fn=check_or_restart_space, outputs=status_output)
    gr.Markdown(" ")

    with gr.Tab("by tags"):
        gr.Markdown("### Select one or more tags and click the button to see the articles with those tags.")
        dropdown = gr.Dropdown(
            choices=list_tags,
            label="Selected tags :",
            multiselect=True,
        )
        # The status button also reloads the tags, so a start-up that happened
        # while the backend was down can be recovered without a restart.
        status_btn.click(fn=_refresh_tag_choices, outputs=dropdown)
        # Button that triggers the search
        submit_btn1 = gr.Button("Find articles", variant="primary")
        # The list of articles
        html_table = gr.HTML()
        # Wire the button to the function
        submit_btn1.click(fn=get_articles_with_tags, inputs=dropdown, outputs=html_table)

    with gr.Tab("by query to find relevant articles"):
        gr.Markdown("### Enter query and click the button to see the relevant articles.")
        search_box = gr.Textbox(label="Query", placeholder="Ex: What is Databricks?, Qui a fondé Databricks ?")
        use_rerank1 = gr.Checkbox(label="Use reranker?")
        # Button that triggers the search
        submit_btn2 = gr.Button("Find articles", variant="primary")
        # The list of retrieved passages
        html_results = gr.HTML()
        # Wire the button to the function
        submit_btn2.click(fn=get_articles_with_query, inputs=[search_box, use_rerank1], outputs=html_results)

    with gr.Tab("by query to get answer"):
        # This tab has no submit button: the query is sent with the Enter key.
        gr.Markdown("### Enter a query and press Enter to get an answer and the passages it is based on.")
        use_rerank2 = gr.Checkbox(label="Use reranker?")
        msg = gr.Textbox(label="Query", placeholder="Ex: What is Databricks?, Qui a fondé Databricks ?")
        clear = gr.Button("Clear", variant="primary")
        chatbot = gr.Chatbot()
        # The list of retrieved passages
        html_details = gr.HTML()
        msg.submit(get_answer_with_query, inputs=[msg, use_rerank2, chatbot],
                   outputs=[msg, chatbot, html_details])
        # Reset the query box, the checkbox, the chat and the passages table.
        clear.click(lambda: [None, False, None, None],
                    outputs=[msg, use_rerank2, chatbot, html_details])

### Launch the app
demo.queue().launch(debug=True)