search_articles / app.py
Loren's picture
Update app.py
2f3974b verified
import gradio as gr
import requests
import pandas as pd
import os
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from templates.prompt_mistral_rag import RAG_PROMPT_TEMPLATE
import torch
import spaces
# Base URL of the backend Hugging Face Space, read from the environment
API_URL = os.environ.get("API_URL")
API_URL_ASK = API_URL + "/get_answer"
SPACE_URL = "https://huggingface.co/api/spaces/Loren/api_search_articles"

# Generative model: tokenizer and weights are loaded once, at import time
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)
def get_tags():
    """Fetch the list of available tags from the backend API.

    Returns:
        The tags from the ``"tags"`` field of the response, sorted
        case-insensitively.

    Raises:
        gr.Error: if the request fails, the HTTP status is not 200, or the
            response's ``"status"`` field is not ``"ok"``.
    """
    try:
        # Bounded wait so a stalled backend cannot hang the caller forever.
        resp = requests.get(f"{API_URL}/get_tags", timeout=60)
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")

    dict_resp = resp.json()
    if dict_resp["status"] != "ok":
        raise gr.Error(f"❌ Erreur : {dict_resp['code']} - {dict_resp['message']}")

    return sorted(dict_resp["tags"], key=str.lower)
def get_articles_with_tags(tags):
    """Return a styled HTML table of the articles matching the given tags.

    Args:
        tags: the selected tags, forwarded to the backend as the ``tags``
            query parameter.

    Returns:
        An HTML string: a ``<style>`` block followed by a scrollable
        two-column table (article title, "Ouvrir" link). When the backend
        returns no article, a short "no articles" paragraph is returned.

    Raises:
        gr.Error: if no tag is given, the request fails, the HTTP status is
            not 200, or the response's ``"status"`` field is not ``"ok"``.
    """
    # Local import: keeps the block self-contained (no new file-level import).
    from html import escape

    if not tags:
        raise gr.Error("❌ Erreur : aucun tag fourni.")

    try:
        # Bounded wait so a stalled backend cannot hang the UI callback.
        resp = requests.get(
            f"{API_URL}/get_articles_with_tags",
            params={"tags": tags},
            timeout=60,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")

    dict_resp = resp.json()
    if dict_resp["status"] != "ok":
        raise gr.Error(f"❌ Erreur : {dict_resp['code']} - {dict_resp['message']}")

    df = pd.DataFrame(dict_resp["articles"])
    if df.empty:
        # An empty frame has no columns; indexing it below would raise KeyError.
        return "<p>No articles found.</p>"

    # The table is rendered with escape=False (the link column is real HTML),
    # so text coming from the API must be escaped by hand.
    df["article_title"] = df["article_title"].apply(
        lambda t: escape(t) if isinstance(t, str) else t
    )
    # Turn each URL into a clickable link
    df["article_url"] = df["article_url"].apply(
        lambda x: f'<a href="{escape(str(x))}" target="_blank">Ouvrir</a>'
    )

    # Render the two displayed columns as a bare table (no header, no border)
    table_html = df[["article_title", "article_url"]].to_html(
        escape=False,
        index=False,
        header=False,
        border=0,
    )

    # Minimal styling + scrollable container
    styled_html = f"""
<style>
table {{
border-collapse: collapse;
width: 100%;
font-family: sans-serif;
font-size: 14px;
}}
td {{
padding: 6px 0;
border: none;
}}
tr:hover {{
background-color: #f9fafb;
}}
a {{
color: #2563eb;
text-decoration: none;
}}
a:hover {{
text-decoration: underline;
color: #1d4ed8;
}}
</style>
<div style="max-height: 300px; overflow-y: auto;">
{table_html}
</div>
"""
    return styled_html
def get_articles_with_query(query, use_rerank):
    """Return a styled HTML table of the backend search results for a query.

    Args:
        query: the search text, forwarded as the ``query`` parameter.
        use_rerank: forwarded unchanged as the ``use_rerank`` parameter.

    Returns:
        An HTML string: a ``<style>`` block followed by a scrollable table
        with the columns chunk_text, distance, article_title, tags and
        article_url (rendered as an "Ouvrir" link). When the backend returns
        no result, a short "no results" paragraph is returned.

    Raises:
        gr.Error: if the query is empty, the request fails, the HTTP status
            is not 200, or the response's ``"status"`` field is not ``"ok"``.
    """
    # Local import: keeps the block self-contained (no new file-level import).
    from html import escape

    if not query:
        raise gr.Error("❌ Erreur : aucun query fourni.")

    try:
        # Bounded wait so a stalled backend cannot hang the UI callback.
        resp = requests.get(
            f"{API_URL}/get_query_results",
            params={"query": query, "use_rerank": use_rerank},
            timeout=60,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")

    dict_resp = resp.json()
    if dict_resp["status"] != "ok":
        raise gr.Error(f"❌ Erreur : {dict_resp['code']} - {dict_resp['message']}")

    df = pd.DataFrame(dict_resp["results"])
    if df.empty:
        # An empty frame has no columns; indexing it below would raise KeyError.
        return "<p>No results found.</p>"

    # The table is rendered with escape=False (the link column is real HTML),
    # so string cells coming from the API must be escaped by hand. Non-string
    # cells are left untouched so pandas formats them as before.
    for column in ("chunk_text", "article_title", "tags"):
        df[column] = df[column].apply(
            lambda v: escape(v) if isinstance(v, str) else v
        )
    # Turn each URL into a clickable link
    df["article_url"] = df["article_url"].apply(
        lambda x: f'<a href="{escape(str(x))}" target="_blank">Ouvrir</a>'
    )

    # Render the displayed columns as a borderless table with a header row
    columns = ["chunk_text", "distance", "article_title", "tags", "article_url"]
    table_html = df[columns].to_html(
        escape=False,
        index=False,
        header=True,
        border=0,
    )

    # Minimal styling + scrollable container
    styled_html = f"""
<style>
table {{
border-collapse: collapse;
width: 100%;
font-family: sans-serif;
font-size: 14px;
white-space: nowrap;
}}
th {{
text-align: left;
padding: 6px 0;
border-bottom: 1px solid #555; /* couleur adaptée au dark */
color: #eee; /* texte clair pour dark */
}}
td {{
padding: 6px 0;
border: none;
color: #eee; /* texte clair */
}}
tr:hover {{
background-color: rgba(255, 255, 255, 0.1); /* léger survol adapté au dark */
}}
a {{
color: #4ea3ff; /* lien visible en dark */
text-decoration: none;
}}
a:hover {{
text-decoration: underline;
color: #1d4ed8;
}}
</style>
<div style="max-height: 300px; overflow-y: auto; overflow-x: auto;">
{table_html}
</div>
"""
    return styled_html
@spaces.GPU
def get_answer_with_query(query, use_rerank, history):
    """Answer a question with the generative model, using retrieved chunks as context.

    The backend is queried for results matching ``query``; their
    ``chunk_text`` values are joined into the prompt template and the model
    generates up to 500 new tokens. When the backend returns no result, a
    fixed fallback answer is used and no generation takes place.

    Args:
        query: the user's question, forwarded as the ``query`` parameter.
        use_rerank: forwarded unchanged as the ``use_rerank`` parameter.
        history: the chat history; ``(query, answer)`` is appended to it.
            ``None`` is treated as an empty history.

    Returns:
        A 3-tuple ``("", history, styled_html)``: an empty string (wired to
        the query textbox, which it clears), the updated history, and an
        HTML table of the retrieved results (empty string when there are
        none).

    Raises:
        gr.Error: if the query is empty, the request fails, the HTTP status
            is not 200, or the response's ``"status"`` field is not ``"ok"``.
    """
    # Local import: keeps the block self-contained (no new file-level import).
    from html import escape

    if not query:
        raise gr.Error("❌ Erreur : aucun query fourni.")

    try:
        # Bounded wait so a stalled backend cannot hang the UI callback.
        resp = requests.get(
            f"{API_URL}/get_query_results",
            params={"query": query, "use_rerank": use_rerank},
            timeout=60,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"❌ Erreur : {exc}") from exc
    if resp.status_code != 200:
        raise gr.Error(f"❌ Erreur : {resp.status_code}")

    dict_resp = resp.json()
    if dict_resp["status"] != "ok":
        raise gr.Error(f"❌ Erreur : {dict_resp['code']} - {dict_resp['message']}")

    results = dict_resp["results"]
    list_chunks = [item["chunk_text"] for item in results]

    if not list_chunks:
        answer = ("Je ne dispose pas d’informations sur ce sujet. "
                  "Je peux uniquement répondre à des questions sur les articles "
                  "du jeu de données.")
        # Nothing to tabulate: an empty DataFrame has no columns, so building
        # the table would raise KeyError.
        styled_html = ""
    else:
        # Build the prompt
        prompt = RAG_PROMPT_TEMPLATE.format(
            context="\n".join(list_chunks),
            question=query,
        )
        # Generate the answer
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=500)
        # Keep only the newly generated tokens (drop the echoed prompt)
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        df = pd.DataFrame(results)
        # The table is rendered with escape=False (the link column is real
        # HTML), so string cells coming from the API must be escaped by hand.
        for column in ("chunk_text", "article_title", "tags"):
            df[column] = df[column].apply(
                lambda v: escape(v) if isinstance(v, str) else v
            )
        # Turn each URL into a clickable link
        df["article_url"] = df["article_url"].apply(
            lambda x: f'<a href="{escape(str(x))}" target="_blank">Ouvrir</a>'
        )

        # Render the displayed columns as a borderless table with a header row
        columns = ["chunk_text", "distance", "article_title", "tags", "article_url"]
        table_html = df[columns].to_html(
            escape=False,
            index=False,
            header=True,
            border=0,
        )

        # Minimal styling + scrollable container
        styled_html = f"""
<style>
table {{
border-collapse: collapse;
width: 100%;
font-family: sans-serif;
font-size: 14px;
white-space: nowrap;
}}
th {{
text-align: left;
padding: 6px 0;
border-bottom: 1px solid #555; /* couleur adaptée au dark */
color: #eee; /* texte clair pour dark */
}}
td {{
padding: 6px 0;
border: none;
color: #eee; /* texte clair */
}}
tr:hover {{
background-color: rgba(255, 255, 255, 0.1); /* léger survol adapté au dark */
}}
a {{
color: #4ea3ff; /* lien visible en dark */
text-decoration: none;
}}
a:hover {{
text-decoration: underline;
color: #1d4ed8;
}}
</style>
<div style="max-height: 300px; overflow-y: auto; overflow-x: auto;">
{table_html}
</div>
"""

    # The chat component may hand over None (e.g. after being cleared).
    if history is None:
        history = []
    history.append((query, answer))
    return "", history, styled_html
# -------- CHECK/RESTART FUNCTION --------
def check_or_restart_space():
    """Report the runtime stage of the backend Space as a Markdown heading.

    Sends a GET request to ``SPACE_URL`` and reads ``runtime.stage`` from the
    JSON response. Despite its name, this function issues no explicit restart
    request; it only reads the status.

    Returns:
        ``"### <stage>"`` on success, or ``"### Space in error : <reason>"``
        when the request fails, the body is not valid JSON, or the stage is
        missing from the response.
    """
    try:
        # Bounded wait so the status button cannot hang indefinitely.
        r = requests.get(SPACE_URL, timeout=10)
        data = r.json()
        # "runtime" may be absent or null; treat both as "no stage available".
        stage = (data.get("runtime") or {}).get("stage")
        if stage is None:
            return "### Space in error : runtime stage missing from API response"
        return "### " + str(stage)
    except Exception as e:
        # UI boundary: any failure is shown to the user rather than raised.
        return "### Space in error : " + str(e)
## Stylesheets injected into the Gradio page
# Tab labels: orange, bold, larger font
custom_css = """
.tab-nav button, .tab-nav button span, .tab-nav button p {
color: orange !important;
font-weight: bold !important;
font-size: 1.3rem !important;
}
"""
# Keeps the status button (elem_id="mini_btn") as narrow as its label
css = """
#mini_btn {
width: fit-content !important;
display: inline-block !important;
margin: 0 !important;
}
"""

# Tags are fetched once at start-up to populate the dropdown
list_tags = get_tags()

# --- Gradio interface ---
# Both stylesheets are passed to Blocks; previously only `custom_css` was,
# so the #mini_btn rule was never applied.
with gr.Blocks(title="Search articles", css=custom_css + css) as demo:
    # NOTE(review): nesting below was reconstructed from a copy of the file
    # whose indentation was lost — confirm the placement of the status row.
    with gr.Row():
        with gr.Column():
            gr.Image(value="dataset-cover.jpg", type="filepath")
        with gr.Column():
            gr.HTML("<h1 style='text-align: center; font-weight: bold;'>190k+ Medium Articles</h1>")
            # The link target must be a bare URL: wrapping it in {} (left over
            # from an f-string) makes the Markdown link point nowhere useful.
            gr.Markdown("## 🔎 Search in this dataset hosted in Kaggle ([Dataset details](https://www.kaggle.com/datasets/fabiochiusano/medium-articles))")
            with gr.Row():
                status_btn = gr.Button("Check / restart API", elem_id="mini_btn")
                status_output = gr.Markdown()
            status_btn.click(fn=check_or_restart_space, outputs=status_output)
    gr.Markdown(" ")

    with gr.Tab("by tags"):
        gr.Markdown("### Select one or more tags and click the button to see the articles with those tags.")
        dropdown = gr.Dropdown(
            choices=list_tags,
            label="Selected tags :",
            multiselect=True,
        )
        # Button that triggers the search
        submit_btn1 = gr.Button("Find articles", variant="primary")
        # The list of articles
        html_table = gr.HTML()
        # Wire the button to the function
        submit_btn1.click(fn=get_articles_with_tags, inputs=dropdown, outputs=html_table)

    with gr.Tab("by query to find relevant articles"):
        gr.Markdown("### Enter query and click the button to see the relevant articles.")
        search_box = gr.Textbox(label="Query", placeholder="Ex: What is Databricks?, Qui a fondé Databricks ?")
        use_rerank1 = gr.Checkbox(label="Use reranker?")
        # Button that triggers the search
        submit_btn2 = gr.Button("Find articles", variant="primary")
        # The list of articles
        html_results = gr.HTML()
        # Wire the button to the function
        submit_btn2.click(fn=get_articles_with_query, inputs=[search_box, use_rerank1], outputs=html_results)

    with gr.Tab("by query to get answer"):
        # This tab has no search button: the query is sent with the Enter key
        # (msg.submit below), so the instruction says so.
        gr.Markdown("### Enter query and press Enter to get an answer based on the relevant articles.")
        use_rerank2 = gr.Checkbox(label="Use reranker?")
        msg = gr.Textbox(label="Query", placeholder="Ex: What is Databricks?, Qui a fondé Databricks ?")
        clear = gr.Button("Clear", variant="primary")
        chatbot = gr.Chatbot()
        # The list of articles used as context
        html_details = gr.HTML()
        msg.submit(
            get_answer_with_query,
            inputs=[msg, use_rerank2, chatbot],
            outputs=[msg, chatbot, html_details],
        )
        # Reset the textbox, the checkbox, the chat and the details table
        clear.click(
            lambda: [None, False, None, None],
            outputs=[msg, use_rerank2, chatbot, html_details],
        )

### Launch the app
demo.queue().launch(debug=True)