Chat-With-AI / app.py
NathMen12's picture
Upload 2 files
076222d verified
Raw
History Blame Contribute Delete
14.4 kB
"""
app.py — HuggingFace Space : Chat AI multi-modèle
- Authentification via HF OAuth (sans gr.LoginButton)
- Historique de conversation par utilisateur (SQLite)
- Sélection et chargement dynamique de modèles
- Inférence via ZeroGPU
- Paramètres : max_tokens, temperature, top_p, repetition_penalty, CPU mode
"""
import os
import json
import requests
import gradio as gr
import db
import model_runner
# Executed at import time, before any UI object is created.
# NOTE(review): presumably creates the SQLite schema when it does not exist yet —
# confirm in db.py that the call is safe to repeat on every start.
db.init_db()
# Model ids offered in the dropdown. The first entry is the default selection;
# users may still type any other model id (the dropdown allows custom values).
SUGGESTED_MODELS = [
    "microsoft/Phi-3.5-mini-instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "google/gemma-2-2b-it",
    "Qwen/Qwen2.5-7B-Instruct",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Extra stylesheet handed to gr.Blocks: bounds the sidebar width and makes the
# conversation list scroll instead of growing without limit.
CSS = """
#sidebar { min-width: 260px; max-width: 280px; }
#conv-list { max-height: 60vh; overflow-y: auto; }
"""
# ── OAuth helpers ─────────────────────────────────────────────
def get_user_from_request(request: gr.Request) -> dict | None:
"""
HF injects the logged-in user info via the X-HF-User header (when
hf_oauth: true is set in README). Fall back to HF_TOKEN env var
for local dev.
"""
if request is None:
return None
# HuggingFace Space OAuth: header injected by the platform
hf_user = (request.headers or {}).get("x-hf-user") or \
(request.headers or {}).get("X-HF-User")
if hf_user:
try:
return json.loads(hf_user) # {"username": "...", "name": "..."}
except Exception:
return {"username": hf_user, "name": hf_user}
# Local dev fallback: use HF_TOKEN
token = os.environ.get("HF_TOKEN")
if token:
try:
resp = requests.get(
"https://huggingface.co/api/whoami",
headers={"Authorization": f"Bearer {token}"},
timeout=5,
)
if resp.ok:
data = resp.json()
return {"username": data.get("name", "user"), "name": data.get("fullname", data.get("name", "user"))}
except Exception:
pass
return None
# ── Helpers ───────────────────────────────────────────────────
def fmt_conv_label(c: dict) -> str:
    """Build the sidebar label of a conversation row.

    The label is the title, then the last path segment of the model id in
    square brackets when a model id is set, then the message count.
    """
    count = c["msg_count"]
    model_id = c.get("model_id")
    suffix = ""
    if model_id:
        # "org/model" -> "model"; an id without "/" is used unchanged.
        short_name = model_id.rsplit("/", 1)[-1]
        suffix = f" [{short_name}]"
    return f"{c['title']}{suffix} ({count} msg)"
def messages_to_gradio(msgs: list[dict]) -> list[tuple]:
    """Convert a flat message list into (user, assistant) display pairs.

    Every "user" message opens a pair. It is completed by the message that
    directly follows it when that one has the "assistant" role; otherwise the
    second element is None. Messages with any other role, and assistant
    messages that do not directly follow a user message, are skipped.
    """
    pairs = []
    pos, total = 0, len(msgs)
    while pos < total:
        current = msgs[pos]
        pos += 1
        if current["role"] != "user":
            continue
        question = current["content"]
        reply = None
        if pos < total and msgs[pos]["role"] == "assistant":
            reply = msgs[pos]["content"]
        pairs.append((question, reply))
        if reply is not None:
            # The assistant message was consumed together with its question.
            pos += 1
    return pairs
# ── App ───────────────────────────────────────────────────────
def build_app():
    """Assemble the chat UI and wire its event handlers.

    The page consists of a header (title and login status), a sidebar
    (conversation list, model picker, inference settings, load/unload
    buttons) and a chat column (chatbot, message box, rename/delete controls).
    Three ``gr.State`` values carry the logged-in user dict, the id of the
    open conversation and its history as ``(role, content)`` pairs.

    Returns:
        The configured ``gr.Blocks`` instance; launching it is left to the
        caller.
    """

    def conv_samples(uid):
        """Return the sidebar rows (one label per conversation) for *uid*."""
        return [[fmt_conv_label(c)] for c in db.list_conversations(uid)]

    # NOTE(review): the nesting of the layout containers below follows the
    # scale/size hints of the components — confirm against the deployed UI.
    with gr.Blocks(
        title="🤗 AI Chat — Multi-Model",
        css=CSS,
        theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    ) as demo:
        state_user = gr.State(None)
        state_conv_id = gr.State(None)
        state_history = gr.State([])

        # ── Header ─────────────────────────────────────────────
        with gr.Row(equal_height=True):
            gr.Markdown("## 🤗 AI Chat — Multi-Modèle")
            user_display = gr.Markdown("*Non connecté*")
        gr.Markdown("---")

        with gr.Row(equal_height=False):
            # ── Sidebar ────────────────────────────────────────
            with gr.Column(scale=1, elem_id="sidebar"):
                gr.Markdown("### 💬 Conversations")
                btn_new_conv = gr.Button("➕ Nouvelle conversation", variant="primary", size="sm")
                conv_list = gr.Dataset(
                    components=["text"],
                    label="",
                    elem_id="conv-list",
                    headers=["Conversations"],
                    samples=[],
                )
                gr.Markdown("---")
                gr.Markdown("### 🔧 Modèle")
                model_input = gr.Dropdown(
                    choices=SUGGESTED_MODELS,
                    value=SUGGESTED_MODELS[0],
                    label="Modèle HuggingFace",
                    allow_custom_value=True,
                    info="Saisissez n'importe quel model_id HF",
                )
                with gr.Accordion("⚙️ Paramètres inférence", open=False):
                    use_cpu = gr.Checkbox(label="Utiliser le CPU (pas de GPU)", value=False)
                    use_4bit = gr.Checkbox(label="Quantification 4-bit (économise VRAM)", value=True)
                    max_tokens = gr.Slider(64, 4096, value=512, step=64, label="Max new tokens")
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty")
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="Tu es un assistant utile...",
                        lines=3,
                    )
                model_status = gr.Markdown("**Statut :** aucun modèle chargé")
                btn_load = gr.Button("⬇️ Charger le modèle", variant="secondary", size="sm")
                btn_unload = gr.Button("🗑️ Décharger", size="sm")

            # ── Chat ───────────────────────────────────────────
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    label="",
                    height=520,
                    show_label=False,
                    bubble_full_width=False,
                    show_copy_button=True,
                    avatar_images=(None, "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"),
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Écris ton message ici… (Entrée pour envoyer)",
                        show_label=False,
                        scale=9,
                        lines=1,
                        max_lines=6,
                        autofocus=True,
                    )
                    send_btn = gr.Button("Envoyer", variant="primary", scale=1)
                with gr.Row():
                    conv_title_input = gr.Textbox(placeholder="Renommer la conversation…", show_label=False, scale=4)
                    btn_rename = gr.Button("✏️ Renommer", size="sm", scale=1)
                    btn_delete = gr.Button("🗑️ Supprimer", size="sm", scale=1, variant="stop")

        # ── Handlers ───────────────────────────────────────────
        def on_load(request: gr.Request):
            """Identify the visitor and restore the first listed conversation."""
            user = get_user_from_request(request)
            if user is None:
                return (
                    None,
                    None,
                    [],
                    [],
                    gr.update(samples=[]),
                    "*Non connecté — connectez-vous via HuggingFace*",
                )
            uid = user["username"]
            convs = db.list_conversations(uid)
            conv_id, history, chatbot_data = None, [], []
            if convs:
                conv_id = convs[0]["id"]
                msgs = db.get_messages(conv_id)
                history = [(m["role"], m["content"]) for m in msgs]
                chatbot_data = messages_to_gradio(msgs)
            return (
                user,
                conv_id,
                history,
                # Show the restored conversation; otherwise it would be the
                # active one while the chat panel stays empty.
                chatbot_data,
                gr.update(samples=[[fmt_conv_label(c)] for c in convs]),
                f"**Connecté :** {user['name']} (`{uid}`)",
            )

        demo.load(
            on_load,
            inputs=None,
            outputs=[state_user, state_conv_id, state_history, chatbot, conv_list, user_display],
        )

        def refresh_conv_list(user):
            """Re-read the user's conversations into the sidebar."""
            if user is None:
                return gr.update()
            return gr.update(samples=conv_samples(user["username"]))

        # New conversation
        def new_conv(user):
            """Create an empty conversation and make it the active one."""
            if user is None:
                return None, [], [], gr.update()
            uid = user["username"]
            conv_id = db.create_conversation(uid)
            return conv_id, [], [], gr.update(samples=conv_samples(uid))

        btn_new_conv.click(new_conv, [state_user], [state_conv_id, state_history, chatbot, conv_list])

        # Select conversation
        def select_conv(evt: gr.SelectData, user):
            """Open the conversation whose sidebar row was clicked."""
            if user is None:
                return None, [], []
            # The event index may be a bare row number or a [row, col] pair
            # depending on the component; accept both.
            index = evt.index
            if isinstance(index, (list, tuple)):
                index = index[0] if index else None
            convs = db.list_conversations(user["username"])
            if not isinstance(index, int) or not 0 <= index < len(convs):
                return None, [], []
            conv = convs[index]
            msgs = db.get_messages(conv["id"])
            history = [(m["role"], m["content"]) for m in msgs]
            return conv["id"], history, messages_to_gradio(msgs)

        conv_list.select(select_conv, [state_user], [state_conv_id, state_history, chatbot])

        # Load model
        def load_model_fn(model_id, cpu, quant4, user):
            """Load the requested model, streaming status messages."""
            if user is None:
                yield "⚠️ Connectez-vous d'abord."
                return
            # The dropdown value can be None when the field has been cleared.
            model_id = (model_id or "").strip()
            if not model_id:
                yield "⚠️ Veuillez entrer un model_id valide."
                return
            try:
                yield f"⏳ Chargement de `{model_id}`…"
                model_runner.load_model(model_id, use_4bit=quant4, use_cpu=cpu)
                yield f"✅ **Modèle chargé :** `{model_id}`"
            except Exception as e:
                # UI boundary: report the failure in the status line.
                yield f"❌ Erreur : {e}"

        btn_load.click(load_model_fn, [model_input, use_cpu, use_4bit, state_user], [model_status])

        def unload_fn():
            """Unload the current model and reset the status line."""
            model_runner._unload()
            return "**Statut :** aucun modèle chargé"

        btn_unload.click(unload_fn, outputs=[model_status])

        # Send message
        def user_message(user_msg, history, conv_id, user, model_id):
            """Persist the user's message and show it in the chat.

            Returns the chatbot pairs, the updated history, the conversation
            id (created on first message) and an update clearing the input.
            """
            history = history or []
            if user is None or not (user_msg or "").strip():
                # Nothing to send: keep chat and history as they are and only
                # clear the input box.
                return gr.update(), history, conv_id, gr.update(value="")
            uid = user["username"]
            if conv_id is None:
                conv_id = db.create_conversation(uid, db.auto_title_from_message(user_msg), model_id)
            else:
                db.update_conversation_model(conv_id, model_id)
            db.add_message(conv_id, "user", user_msg)
            history = history + [("user", user_msg)]
            chatbot_data = messages_to_gradio([{"role": r, "content": c} for r, c in history])
            # The history is written back to the state so the reply step sees
            # the message that was just sent.
            return chatbot_data, history, conv_id, gr.update(value="")

        def bot_response(chatbot_data, history, conv_id, user, max_tok, temp, tp, rp, sysp):
            """Stream the model reply for the pending user message."""
            chatbot_data = chatbot_data or []
            history = history or []
            # Only answer when a user turn is actually waiting for a reply;
            # this step also runs after an empty submit.
            awaiting_reply = bool(chatbot_data) and bool(history) and history[-1][0] == "user"
            if user is None or conv_id is None or not awaiting_reply:
                yield chatbot_data, history
                return
            question = chatbot_data[-1][0]
            if not model_runner.is_loaded():
                chatbot_data[-1] = (question, "⚠️ Aucun modèle chargé.")
                yield chatbot_data, history
                return
            # Only the 20 most recent messages are sent to the model.
            context = [{"role": r, "content": c} for r, c in history[-20:]]
            partial = ""
            for chunk in model_runner.generate_stream(
                messages=context,
                max_new_tokens=int(max_tok),
                temperature=float(temp),
                top_p=float(tp),
                repetition_penalty=float(rp),
                system_prompt=sysp,
            ):
                partial += chunk
                chatbot_data[-1] = (question, partial)
                yield chatbot_data, history
            db.add_message(conv_id, "assistant", partial)
            history = history + [("assistant", partial)]
            yield chatbot_data, history

        send_inputs = [msg_input, state_history, state_conv_id, state_user, model_input]
        send_outputs = [chatbot, state_history, state_conv_id, msg_input]
        reply_inputs = [
            chatbot, state_history, state_conv_id, state_user,
            max_tokens, temperature, top_p, rep_penalty, system_prompt,
        ]
        reply_outputs = [chatbot, state_history]
        # Pressing Enter and clicking the button run the same pipeline. The
        # final step refreshes the sidebar so a conversation created by the
        # first message appears and the rows match what select_conv reads.
        for trigger in (msg_input.submit, send_btn.click):
            trigger(user_message, send_inputs, send_outputs).then(
                bot_response, reply_inputs, reply_outputs
            ).then(refresh_conv_list, [state_user], [conv_list])

        # Rename
        def rename_conv(conv_id, title, user):
            """Rename the active conversation and refresh the sidebar."""
            if conv_id and user and title.strip():
                db.rename_conversation(conv_id, user["username"], title.strip())
                return gr.update(samples=conv_samples(user["username"])), gr.update(value="")
            return gr.update(), gr.update()

        btn_rename.click(rename_conv, [state_conv_id, conv_title_input, state_user], [conv_list, conv_title_input])

        # Delete
        def delete_conv(conv_id, user):
            """Delete the active conversation and clear the chat panel."""
            if conv_id and user:
                db.delete_conversation(conv_id, user["username"])
                return None, [], [], gr.update(samples=conv_samples(user["username"]))
            return conv_id, gr.update(), gr.update(), gr.update()

        btn_delete.click(delete_conv, [state_conv_id, state_user], [state_conv_id, state_history, chatbot, conv_list])

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all interfaces.
    demo = build_app()
    # The port can be overridden through GRADIO_SERVER_PORT; 7860 otherwise.
    listen_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port, share=False)