"""Gradio front end for searching Medium articles and answering questions.

The app talks to a backend API (base URL taken from the ``API_URL``
environment variable) to list tags, fetch articles by tag and retrieve the
chunks relevant to a free-text query.  A local Mistral instruct model turns
the retrieved chunks into an answer in the chat tab.
"""
import html
import os
import time

import gradio as gr
import requests
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces

from templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE

# Base URL of the backend Hugging Face Space.
API_URL = os.getenv("API_URL")
if not API_URL:
    # Fail with an explicit message instead of the opaque
    # "unsupported operand type(s) for +: 'NoneType' and 'str'".
    raise RuntimeError(
        "The API_URL environment variable must be set to the backend base URL."
    )
API_URL_ASK = API_URL + "/get_answer"
SPACE_URL = "https://huggingface.co/api/spaces/Loren/api_search_articles"

# (connect, read) timeouts in seconds, so a stalled backend cannot hang the UI.
REQUEST_TIMEOUT = (10, 120)

# Load the generative model.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)


def _api_get(endpoint, params=None):
    """Call ``GET {API_URL}/{endpoint}`` and return the decoded JSON payload.

    Raises:
        gr.Error: on a network failure, a non-200 status code, a body that is
            not valid JSON, or a payload whose ``status`` field is not ``"ok"``.
    """
    try:
        resp = requests.get(
            f"{API_URL}/{endpoint}", params=params, timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")
    try:
        payload = resp.json()
    except ValueError as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if payload.get("status") != "ok":
        raise gr.Error(
            f"❌ Erreur : {payload.get('code')} - {payload.get('message')}"
        )
    return payload


def _link_cell(url):
    """Return the HTML for one ``article_url`` cell.

    Values that look like http(s) URLs become an "Ouvrir" link opening in a
    new tab; anything else is shown as escaped plain text rather than being
    turned into a dead (or unsafe) link.
    """
    text = str(url)
    if text.startswith(("http://", "https://")):
        return (
            f'<a href="{html.escape(text)}" target="_blank" '
            f'rel="noopener noreferrer">Ouvrir</a>'
        )
    return html.escape(text)


def _wrap_scrollable(inner_html):
    """Wrap *inner_html* in a container that scrolls vertically."""
    return (
        '\n<div style="max-height: 400px; overflow-y: auto;">'
        f"\n{inner_html}\n"
        "</div>\n"
    )


def _render_table(records, columns, *, header):
    """Render *records* (a list of dicts) as a scrollable HTML table.

    Args:
        records: rows returned by the backend.
        columns: names of the columns to display, in order.
        header: whether to emit the table header row.

    Returns:
        An HTML string; a short "no results" message when *records* is empty.
    """
    frame = pd.DataFrame(records)
    if frame.empty:
        # An empty frame has no columns, so selecting them would raise KeyError.
        return _wrap_scrollable("<p>No results found.</p>")

    table = frame[list(columns)].copy()
    for name in columns:
        if name == "article_url":
            table[name] = table[name].map(_link_cell)
        elif not pd.api.types.is_numeric_dtype(table[name]):
            # The table is rendered with escape=False (needed for the links),
            # so backend-supplied text has to be escaped here.
            table[name] = table[name].map(lambda value: html.escape(str(value)))

    # display.max_colwidth would otherwise cut long cells, including the
    # anchor markup, leaving broken HTML.
    with pd.option_context("display.max_colwidth", None):
        table_html = table.to_html(
            escape=False, index=False, header=header, border=0
        )
    return _wrap_scrollable(table_html)


_QUERY_COLUMNS = ["chunk_text", "distance", "article_title", "tags", "article_url"]


def get_tags():
    """Return the backend's tags sorted case-insensitively.

    Raises:
        gr.Error: if the backend call fails (see ``_api_get``).
    """
    payload = _api_get("get_tags")
    return sorted(payload["tags"], key=str.lower)


def get_articles_with_tags(tags):
    """Return an HTML table (title, link) of the articles matching *tags*.

    Raises:
        gr.Error: if *tags* is empty or the backend call fails.
    """
    if not tags:
        raise gr.Error("❌ Erreur : aucun tag fourni.")
    payload = _api_get("get_articles_with_tags", {"tags": tags})
    return _render_table(
        payload["articles"], ["article_title", "article_url"], header=False
    )


def get_articles_with_query(query, use_rerank):
    """Return an HTML table of the chunks the backend retrieved for *query*.

    Raises:
        gr.Error: if *query* is empty or the backend call fails.
    """
    if not query:
        raise gr.Error("❌ Erreur : aucun query fourni.")
    payload = _api_get(
        "get_query_results", {"query": query, "use_rerank": use_rerank}
    )
    return _render_table(payload["results"], _QUERY_COLUMNS, header=True)


@spaces.GPU
def get_answer_with_query(query, use_rerank, history):
    """Answer *query* from the retrieved chunks and update the chat history.

    Args:
        query: the user's question.
        use_rerank: forwarded to the backend as the ``use_rerank`` parameter.
        history: current chatbot history (list of ``(question, answer)`` pairs).

    Returns:
        A tuple ``("", history, html)``: the cleared textbox value, the history
        with the new exchange appended, and the table of retrieved chunks.

    Raises:
        gr.Error: if *query* is empty or the backend call fails.
    """
    if not query:
        raise gr.Error("❌ Erreur : aucun query fourni.")
    payload = _api_get(
        "get_query_results", {"query": query, "use_rerank": use_rerank}
    )
    results = payload["results"]
    chunks = [item["chunk_text"] for item in results]

    if not chunks:
        answer = (
            "Je ne dispose pas d’informations sur ce sujet. "
            "Je peux uniquement répondre à des questions sur les articles "
            "du jeu de données."
        )
    else:
        # Build the prompt from the retrieved context.
        prompt = RAG_PROMPT_TEMPLATE.format(
            context="\n".join(chunks), question=query
        )
        # Generate the answer.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=500)
        # Keep only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        answer = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

    # NOTE(review): history is presumably None after "Clear" resets the
    # chatbot -- confirm against the Gradio version in use.
    history = list(history) if history else []
    history.append((query, answer))
    return "", history, _render_table(results, _QUERY_COLUMNS, header=True)


# -------- CHECK/RESTART FUNCTION --------
def check_or_restart_space():
    """Return a Markdown heading with the backend Space's runtime stage.

    Only reads the Space metadata from ``SPACE_URL``; nothing here triggers a
    restart.  Never raises: any failure is reported in the returned string.
    """
    try:
        r = requests.get(SPACE_URL, timeout=REQUEST_TIMEOUT)
        data = r.json()
        stage = (data.get("runtime") or {}).get("stage")
    except Exception as e:  # UI boundary: the status button must not crash.
        return "### Space in error : " + str(e)
    if stage is None:
        return "### Space in error : no runtime stage in response"
    return "### " + str(stage)


custom_css = """
.tab-nav button,
.tab-nav button span,
.tab-nav button p {
    color: orange !important;
    font-weight: bold !important;
    font-size: 1.3rem !important;
}
"""

css = """
#mini_btn {
    width: fit-content !important;
    display: inline-block !important;
    margin: 0 !important;
}
"""

list_tags = get_tags()

# --- Gradio interface ---
# Both style sheets are passed; with custom_css alone the #mini_btn rule used
# by the status button was never applied.
with gr.Blocks(title="Search articles", css=custom_css + css) as demo:
    with gr.Row():
        with gr.Column():
            gr.Image(value="dataset-cover.jpg", type="filepath")
        with gr.Column():
            gr.HTML(
                "<div style='text-align: center;'>"
                "<h1>190k+ Medium Articles</h1>"
                "</div>"
            )
            gr.Markdown(
                "## 🔎 Search in this dataset hosted in Kaggle "
                "([Dataset details]"
                "(https://www.kaggle.com/datasets/fabiochiusano/medium-articles))"
            )
    with gr.Row():
        status_btn = gr.Button("Check / restart API", elem_id="mini_btn")
        status_output = gr.Markdown()
        status_btn.click(fn=check_or_restart_space, outputs=status_output)
    gr.Markdown(" ")

    with gr.Tab("by tags"):
        gr.Markdown(
            "### Select one or more tags and click the button to see the "
            "articles with those tags."
        )
        dropdown = gr.Dropdown(
            choices=list_tags,
            label="Selected tags :",
            multiselect=True,
        )
        # Button that runs the search.
        submit_btn1 = gr.Button("Find articles", variant="primary")
        # List of articles.
        html_table = gr.HTML()
        # Wire the button to the function.
        submit_btn1.click(
            fn=get_articles_with_tags, inputs=dropdown, outputs=html_table
        )

    with gr.Tab("by query to find relevant articles"):
        gr.Markdown(
            "### Enter query and click the button to see the relevant articles."
        )
        search_box = gr.Textbox(
            label="Query",
            placeholder="Ex: What is Databricks?, Qui a fondé Databricks ?",
        )
        use_rerank1 = gr.Checkbox(label="Use reranker?")
        # Button that runs the search.
        submit_btn2 = gr.Button("Find articles", variant="primary")
        # List of retrieved chunks.
        html_results = gr.HTML()
        # Wire the button to the function.
        submit_btn2.click(
            fn=get_articles_with_query,
            inputs=[search_box, use_rerank1],
            outputs=html_results,
        )

    with gr.Tab("by query to get answer"):
        gr.Markdown(
            "### Enter query and click the button to see the relevant articles."
        )
        use_rerank2 = gr.Checkbox(label="Use reranker?")
        msg = gr.Textbox(
            label="Query",
            placeholder="Ex: What is Databricks?, Qui a fondé Databricks ?",
        )
        clear = gr.Button("Clear", variant="primary")
        chatbot = gr.Chatbot()
        # List of retrieved chunks.
        html_details = gr.HTML()

        msg.submit(
            get_answer_with_query,
            inputs=[msg, use_rerank2, chatbot],
            outputs=[msg, chatbot, html_details],
        )
        clear.click(
            lambda: [None, False, None, None],
            outputs=[msg, use_rerank2, chatbot, html_details],
        )

### Launch the app
demo.queue().launch(debug=True)