# Hugging Face Space status when this file was captured: Runtime error
"""
app.py — HuggingFace Space: multi-model AI chat.

- Authentication via HF OAuth (without gr.LoginButton)
- Per-user conversation history (SQLite)
- Dynamic model selection and loading
- Inference via ZeroGPU
- Parameters: max_tokens, temperature, top_p, repetition_penalty, CPU mode
"""
# Standard library
import os
import json

# Third-party
import requests
import gradio as gr

# Local modules
import db
import model_runner

# Models offered in the model dropdown; the dropdown also accepts any other
# model_id typed in by the user.
SUGGESTED_MODELS = [
    "microsoft/Phi-3.5-mini-instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "google/gemma-2-2b-it",
    "Qwen/Qwen2.5-7B-Instruct",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]

# Sidebar width limits and a scrollable conversation list; the selectors
# match the elem_id values set on those components in build_app().
CSS = """
#sidebar { min-width: 260px; max-width: 280px; }
#conv-list { max-height: 60vh; overflow-y: auto; }
"""

# Initialise the db module once, at import time, before any handler uses it.
db.init_db()
# ── OAuth helpers ─────────────────────────────────────────────
def get_user_from_request(request: gr.Request) -> dict | None:
    """Identify the user behind a Gradio request.

    Resolution order:

    1. The ``X-HF-User`` request header. A JSON object carrying a
       ``"username"`` is used as-is (``"name"`` defaults to the username);
       any other value is treated as a bare username.
    2. Local development only — skipped when the ``SPACE_ID`` environment
       variable is set, which HuggingFace sets on deployed Spaces: the
       account that owns ``HF_TOKEN``, as reported by the Hub
       ``/api/whoami`` endpoint.

    Args:
        request: The incoming Gradio request, or ``None``.

    Returns:
        A dict with ``"username"`` and ``"name"`` keys, or ``None`` when no
        user can be determined.
    """
    if request is None:
        return None

    headers = request.headers or {}
    # NOTE(review): this header is trusted verbatim. Confirm the hosting proxy
    # sets it (and strips any client-supplied copy); otherwise any client can
    # impersonate any user and read that user's conversations.
    raw_user = headers.get("x-hf-user") or headers.get("X-HF-User")
    if raw_user:
        try:
            parsed = json.loads(raw_user)
        except ValueError:  # not JSON: fall through to the bare-username form
            parsed = None
        if isinstance(parsed, dict) and parsed.get("username"):
            # Callers read user["name"]; default it so they never KeyError.
            parsed.setdefault("name", parsed["username"])
            return parsed
        return {"username": raw_user, "name": raw_user}

    # On a deployed Space HF_TOKEN is the owner's secret. Resolving it would
    # sign every anonymous visitor in as the owner and expose the owner's
    # history, so this fallback is restricted to local runs.
    if os.environ.get("SPACE_ID"):
        return None
    token = os.environ.get("HF_TOKEN")
    if not token:
        return None
    try:
        resp = requests.get(
            "https://huggingface.co/api/whoami",
            headers={"Authorization": f"Bearer {token}"},
            timeout=5,
        )
        data = resp.json() if resp.ok else None
    except (requests.RequestException, ValueError):
        # Best effort: a network failure or non-JSON body means "no user".
        return None
    if not isinstance(data, dict):
        return None
    username = data.get("name", "user")
    return {"username": username, "name": data.get("fullname", username)}
# ── Helpers ───────────────────────────────────────────────────
def fmt_conv_label(c: dict) -> str:
    """Build the sidebar label for one conversation row.

    The label has the form ``"<title> [<model>] (<n> msg)"``, where
    ``<model>`` is the part of ``model_id`` after its last ``/``. The
    bracketed tag is left out when ``model_id`` is missing or empty.
    """
    count = c["msg_count"]
    model_id = c.get("model_id")
    tag = ""
    if model_id:
        tag = " [" + model_id.rsplit("/", 1)[-1] + "]"
    return f"{c['title']}{tag} ({count} msg)"
def messages_to_gradio(msgs: list[dict]) -> list[tuple]:
    """Pair chronological messages into Chatbot ``(user, bot)`` tuples.

    Each ``"user"`` message is paired with the immediately following message
    when that one is an ``"assistant"`` message, otherwise with ``None``.
    Any message that is not a user turn and was not consumed as a reply
    (for example an assistant message with no preceding user message) is
    skipped.
    """
    pairs: list[tuple] = []
    pos, total = 0, len(msgs)
    while pos < total:
        current = msgs[pos]
        if current["role"] != "user":
            pos += 1
            continue
        reply = None
        if pos + 1 < total and msgs[pos + 1]["role"] == "assistant":
            reply = msgs[pos + 1]["content"]
        pairs.append((current["content"], reply))
        # Consume the reply too, unless there was none (or it was None).
        pos += 1 if reply is None else 2
    return pairs
# ── App ───────────────────────────────────────────────────────
def _conv_samples(uid: str) -> list[list[str]]:
    """Return the sidebar ``gr.Dataset`` samples (one label per row) for *uid*."""
    return [[fmt_conv_label(c)] for c in db.list_conversations(uid)]


def _load_conversation(conv_id) -> tuple[list[tuple], list[tuple]]:
    """Fetch a stored conversation as ``(history, chatbot_pairs)``.

    ``history`` is a list of ``(role, content)`` tuples (the ``state_history``
    format); ``chatbot_pairs`` is the ``(user, bot)`` list shown in the Chatbot.
    """
    msgs = db.get_messages(conv_id)
    return [(m["role"], m["content"]) for m in msgs], messages_to_gradio(msgs)


def _history_to_chatbot(history: list[tuple]) -> list[tuple]:
    """Convert ``(role, content)`` history tuples into Chatbot ``(user, bot)`` pairs."""
    return messages_to_gradio(
        [{"role": role, "content": content} for role, content in history]
    )


def _context_window(history: list[tuple], limit: int = 20) -> list[dict]:
    """Build the message list sent to the model from the tail of *history*.

    Keeps at most the last *limit* ``(role, content)`` entries, then:

    * drops leading entries up to the first ``"user"`` turn. Once an
      alternating history that ends on a user turn grows past *limit*, an
      even-sized tail slice always starts on an assistant turn.
    * merges consecutive entries of the same role, joined by a blank line.
      An example is a user message left unanswered because no model was
      loaded.

    The result therefore opens with a user turn and strictly alternates
    roles, which chat templates that enforce user/assistant alternation
    require.
    """
    window = history[-limit:] if limit > 0 else []
    first_user = next(
        (i for i, (role, _) in enumerate(window) if role == "user"), len(window)
    )
    context: list[dict] = []
    for role, content in window[first_user:]:
        if context and context[-1]["role"] == role:
            context[-1]["content"] = f"{context[-1]['content']}\n\n{content}"
        else:
            context.append({"role": role, "content": content})
    return context


def _selected_index(index):
    """Return the row number from a ``gr.SelectData.index`` value.

    NOTE(review): depending on the component and Gradio version, ``index``
    may be a bare int or a ``[row, col]`` sequence. Both forms are accepted
    here.
    """
    if isinstance(index, (list, tuple)):
        return index[0] if index else None
    return index


def build_app():
    """Build the chat UI and wire all of its event handlers.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(
        title="🤗 AI Chat — Multi-Model",
        css=CSS,
        theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
    ) as demo:
        # Per-session state: the logged-in user dict, the open conversation id,
        # and that conversation as a list of (role, content) tuples.
        state_user = gr.State(None)
        state_conv_id = gr.State(None)
        state_history = gr.State([])

        # ── Header ─────────────────────────────────────────────
        with gr.Row(equal_height=True):
            gr.Markdown("## 🤗 AI Chat — Multi-Modèle")
            user_display = gr.Markdown("*Non connecté*")
        gr.Markdown("---")

        with gr.Row(equal_height=False):
            # ── Sidebar ────────────────────────────────────────
            with gr.Column(scale=1, elem_id="sidebar"):
                gr.Markdown("### 💬 Conversations")
                btn_new_conv = gr.Button("➕ Nouvelle conversation", variant="primary", size="sm")
                conv_list = gr.Dataset(
                    components=["text"],
                    label="",
                    elem_id="conv-list",
                    headers=["Conversations"],
                    samples=[],
                )
                gr.Markdown("---")
                gr.Markdown("### 🔧 Modèle")
                model_input = gr.Dropdown(
                    choices=SUGGESTED_MODELS,
                    value=SUGGESTED_MODELS[0],
                    label="Modèle HuggingFace",
                    allow_custom_value=True,
                    info="Saisissez n'importe quel model_id HF",
                )
                with gr.Accordion("⚙️ Paramètres inférence", open=False):
                    use_cpu = gr.Checkbox(label="Utiliser le CPU (pas de GPU)", value=False)
                    use_4bit = gr.Checkbox(label="Quantification 4-bit (économise VRAM)", value=True)
                    max_tokens = gr.Slider(64, 4096, value=512, step=64, label="Max new tokens")
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                    rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty")
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="Tu es un assistant utile...",
                        lines=3,
                    )
                model_status = gr.Markdown("**Statut :** aucun modèle chargé")
                btn_load = gr.Button("⬇️ Charger le modèle", variant="secondary", size="sm")
                btn_unload = gr.Button("🗑️ Décharger", size="sm")

            # ── Chat ───────────────────────────────────────────
            with gr.Column(scale=4):
                # The Chatbot is fed (user, bot) tuple pairs from messages_to_gradio.
                # NOTE(review): `bubble_full_width` / `show_copy_button` are not
                # accepted by every Gradio release; check them against the
                # Space's pinned Gradio version if the Space fails at startup.
                chatbot = gr.Chatbot(
                    label="",
                    height=520,
                    show_label=False,
                    bubble_full_width=False,
                    show_copy_button=True,
                    avatar_images=(None, "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"),
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Écris ton message ici… (Entrée pour envoyer)",
                        show_label=False,
                        scale=9,
                        lines=1,
                        max_lines=6,
                        autofocus=True,
                    )
                    send_btn = gr.Button("Envoyer", variant="primary", scale=1)
                with gr.Row():
                    conv_title_input = gr.Textbox(placeholder="Renommer la conversation…", show_label=False, scale=4)
                    btn_rename = gr.Button("✏️ Renommer", size="sm", scale=1)
                    btn_delete = gr.Button("🗑️ Supprimer", size="sm", scale=1, variant="stop")

        # ── Handlers ───────────────────────────────────────────
        def on_load(request: gr.Request):
            """Identify the user and open their first listed conversation."""
            user = get_user_from_request(request)
            if user is None:
                return None, None, [], [], gr.update(samples=[]), "*Non connecté — connectez-vous via HuggingFace*"
            uid = user["username"]
            convs = db.list_conversations(uid)
            conv_id, history, chatbot_data = None, [], []
            if convs:
                conv_id = convs[0]["id"]
                history, chatbot_data = _load_conversation(conv_id)
            return (
                user,
                conv_id,
                history,
                chatbot_data,  # previously computed but never shown
                gr.update(samples=[[fmt_conv_label(c)] for c in convs]),
                f"**Connecté :** {user.get('name', uid)} (`{uid}`)",
            )

        demo.load(
            on_load,
            inputs=None,
            outputs=[state_user, state_conv_id, state_history, chatbot, conv_list, user_display],
        )

        def new_conv(user):
            """Create an empty conversation and make it the current one."""
            if user is None:
                return None, [], [], gr.update()
            uid = user["username"]
            conv_id = db.create_conversation(uid)
            return conv_id, [], [], gr.update(samples=_conv_samples(uid))

        btn_new_conv.click(new_conv, [state_user], [state_conv_id, state_history, chatbot, conv_list])

        def select_conv(evt: gr.SelectData, user):
            """Open the conversation at the clicked sidebar row."""
            if user is None:
                return None, [], []
            convs = db.list_conversations(user["username"])
            idx = _selected_index(evt.index)
            # Reject negative/out-of-range rows instead of wrapping around.
            if not isinstance(idx, int) or not 0 <= idx < len(convs):
                return None, [], []
            conv_id = convs[idx]["id"]
            history, chatbot_data = _load_conversation(conv_id)
            return conv_id, history, chatbot_data

        conv_list.select(select_conv, [state_user], [state_conv_id, state_history, chatbot])

        def load_model_fn(model_id, cpu, quant4, user):
            """Load *model_id* via model_runner, streaming status messages."""
            if user is None:
                yield "⚠️ Connectez-vous d'abord."
                return
            model_id = (model_id or "").strip()  # a cleared dropdown yields None
            if not model_id:
                yield "⚠️ Veuillez entrer un model_id valide."
                return
            yield f"⏳ Chargement de `{model_id}`…"
            try:
                model_runner.load_model(model_id, use_4bit=quant4, use_cpu=cpu)
            except Exception as e:  # UI boundary: report any loader failure
                yield f"❌ Erreur : {e}"
                return
            yield f"✅ **Modèle chargé :** `{model_id}`"

        btn_load.click(load_model_fn, [model_input, use_cpu, use_4bit, state_user], [model_status])

        def unload_fn():
            """Unload the current model and reset the status line."""
            model_runner._unload()
            return "**Statut :** aucun modèle chargé"

        btn_unload.click(unload_fn, outputs=[model_status])

        def user_message(user_msg, history, conv_id, user, model_id):
            """Persist the user's message and show it immediately.

            Outputs: chatbot, state_history, state_conv_id, msg_input, conv_list.
            The updated history MUST be written back to state_history, because
            bot_response builds the model context from it.
            """
            if user is None or not (user_msg or "").strip():
                # Leave the chat and history untouched; just clear the input.
                return gr.update(), history, conv_id, gr.update(value=""), gr.update()
            uid = user["username"]
            if conv_id is None:
                conv_id = db.create_conversation(uid, db.auto_title_from_message(user_msg), model_id)
            else:
                db.update_conversation_model(conv_id, model_id)
            db.add_message(conv_id, "user", user_msg)
            history = history + [("user", user_msg)]
            return (
                _history_to_chatbot(history),
                history,
                conv_id,
                gr.update(value=""),
                # Refresh so new conversations appear and row indices match the DB order.
                gr.update(samples=_conv_samples(uid)),
            )

        def bot_response(history, conv_id, user, max_tok, temp, tp, rp, sysp):
            """Stream the model's reply to the pending user turn and persist it.

            Outputs: chatbot, state_history, conv_list.
            """
            # Only answer when the last turn is a pending user message; this
            # skips empty submissions and avoids overwriting the last answer.
            if user is None or conv_id is None or not history or history[-1][0] != "user":
                yield gr.update(), history, gr.update()
                return
            display = _history_to_chatbot(history)  # last pair is (pending msg, None)
            if not model_runner.is_loaded():
                display[-1] = (display[-1][0], "⚠️ Aucun modèle chargé.")
                yield display, history, gr.update()
                return
            partial = ""
            try:
                for chunk in model_runner.generate_stream(
                    messages=_context_window(history),
                    max_new_tokens=int(max_tok),
                    temperature=float(temp),
                    top_p=float(tp),
                    repetition_penalty=float(rp),
                    system_prompt=sysp,
                ):
                    partial += chunk
                    display[-1] = (display[-1][0], partial)
                    yield display, history, gr.update()
            except Exception as e:  # UI boundary: show the failure, persist nothing
                display[-1] = (display[-1][0], f"❌ Erreur : {e}")
                yield display, history, gr.update()
                return
            db.add_message(conv_id, "assistant", partial)
            history = history + [("assistant", partial)]
            # Refresh the sidebar so message counts include the reply.
            yield display, history, gr.update(samples=_conv_samples(user["username"]))

        # Enter in the textbox and the send button share the same pipeline.
        for trigger in (msg_input.submit, send_btn.click):
            trigger(
                user_message,
                [msg_input, state_history, state_conv_id, state_user, model_input],
                [chatbot, state_history, state_conv_id, msg_input, conv_list],
            ).then(
                bot_response,
                [state_history, state_conv_id, state_user, max_tokens, temperature, top_p, rep_penalty, system_prompt],
                [chatbot, state_history, conv_list],
            )

        def rename_conv(conv_id, title, user):
            """Rename the current conversation and refresh the sidebar."""
            title = (title or "").strip()
            if conv_id and user and title:
                db.rename_conversation(conv_id, user["username"], title)
                return gr.update(samples=_conv_samples(user["username"])), gr.update(value="")
            return gr.update(), gr.update()

        btn_rename.click(rename_conv, [state_conv_id, conv_title_input, state_user], [conv_list, conv_title_input])

        def delete_conv(conv_id, user):
            """Delete the current conversation and clear the chat."""
            if conv_id and user:
                db.delete_conversation(conv_id, user["username"])
                return None, [], [], gr.update(samples=_conv_samples(user["username"]))
            return conv_id, gr.update(), gr.update(), gr.update()

        btn_delete.click(delete_conv, [state_conv_id, state_user], [state_conv_id, state_history, chatbot, conv_list])

    return demo
if __name__ == "__main__":
    # Build the UI first, then serve it on every interface. The port comes
    # from GRADIO_SERVER_PORT when set, falling back to 7860.
    demo = build_app()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
        "share": False,
    }
    demo.launch(**launch_options)