Spaces:
Paused
| import os | |
| import uuid | |
| import time | |
| import json | |
| import requests | |
| import gradio as gr | |
| import time | |
| import utils.helpers as helpers | |
| from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf | |
# ========= Config & Globals =========
# Load API credentials and app settings from a local JSON config file.
with open("config.json") as f:
    config = json.load(f)

# DigitalOcean Gradient inference API key (bearer token for BASE_URL).
DO_API_KEY = config["do_token"]
# The HF token is stored in config without its "hf_" prefix; reattach it here.
token_ = config['token']
HF_TOKEN = 'hf_' + token_

# Unique id for this app run; shared with the helpers module so that
# interaction logs written there can be correlated to this session.
session_id = f"{int(time.time())}-{uuid.uuid4().hex[:8]}"
helpers.session_id = session_id

# OpenAI-compatible inference endpoint base URL.
BASE_URL = "https://inference.do-ai.run/v1"
# Upload accumulated interaction logs to HF every N successful chat turns.
UPLOAD_INTERVAL = 5
| # ========= Inference Utilities ========= | |
def _auth_headers():
    """Return the HTTP headers for authenticated JSON calls to the API."""
    headers = {"Authorization": f"Bearer {DO_API_KEY}"}
    headers["Content-Type"] = "application/json"
    return headers
def list_models():
    """Fetch the available model ids from the inference API.

    Falls back to a single known-good model id when the request fails or
    the API returns an empty list.
    """
    fallback = ["llama3.3-70b-instruct"]
    try:
        response = requests.get(
            f"{BASE_URL}/models", headers=_auth_headers(), timeout=15
        )
        response.raise_for_status()
        models = [entry["id"] for entry in response.json().get("data", [])]
    except Exception as e:
        print(f"⚠️ list_models failed: {e}")
        return fallback
    return models if models else fallback
def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    """Run a single (non-streaming) chat completion against the API.

    Args:
        model_id: Model to use; when falsy, the first id from list_models().
        prompt: User prompt, sent as a single "user" message.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant message content, stripped of surrounding whitespace.

    Raises:
        RuntimeError: On an HTTP error response or when retries are exhausted.
        requests.RequestException: When the final retry fails at transport level.
    """
    url = f"{BASE_URL}/chat/completions"
    if not model_id:
        model_id = list_models()[0]
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    for attempt in range(3):
        try:
            resp = requests.post(url, headers=_auth_headers(), json=payload, timeout=30)
            if resp.status_code == 404:
                # The requested model may have been retired; swap in a live id
                # and retry. BUGFIX: also update the local model_id so the
                # membership check stays meaningful on subsequent attempts
                # (previously only payload["model"] was updated).
                ids = list_models()
                if model_id not in ids and ids:
                    model_id = ids[0]
                    payload["model"] = model_id
                    continue
            resp.raise_for_status()
            j = resp.json()
            return j["choices"][0]["message"]["content"].strip()
        except requests.HTTPError as e:
            msg = getattr(e.response, "text", str(e))
            raise RuntimeError(f"Inference error ({e.response.status_code}): {msg}") from e
        except requests.RequestException:
            if attempt == 2:
                raise
            # BUGFIX: brief linear backoff before retrying transient transport
            # errors instead of hammering the endpoint immediately.
            time.sleep(1 + attempt)
    raise RuntimeError("Exhausted retries")
def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    """Stream a chat completion, yielding content deltas as they arrive.

    Parses the server-sent-event ("data: {...}") lines emitted by the
    OpenAI-compatible endpoint and yields each `delta.content` chunk. If the
    stream completes without producing any content, a placeholder string is
    yielded instead.

    Args:
        model_id: Model to use; when falsy, the first id from list_models().
        prompt: User prompt, sent as a single "user" message.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: Successive content chunks of the assistant reply.

    Raises:
        RuntimeError: On a non-200 response or any transport/parse failure.
    """
    url = f"{BASE_URL}/chat/completions"
    if not model_id:
        model_id = list_models()[0]
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }
    # Create a generator that yields tokens
    try:
        with requests.post(url, headers=_auth_headers(), json=payload, stream=True, timeout=120) as r:
            if r.status_code != 200:
                try:
                    err_txt = r.text
                except Exception:
                    err_txt = "<no body>"
                raise RuntimeError(f"HTTP {r.status_code}: {err_txt}")
            buffer = ""
            for line in r.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode('utf-8')
                if not decoded_line.startswith('data:'):
                    continue
                data = decoded_line[5:].strip()
                if data == '[DONE]':
                    break
                try:
                    json_data = json.loads(data)
                except json.JSONDecodeError:
                    # Skip keep-alives / frames that aren't valid JSON.
                    continue
                for choice in json_data.get('choices', []):
                    if 'delta' in choice and 'content' in choice['delta']:
                        content = choice['delta']['content']
                        buffer += content
                        yield content
            if not buffer:
                yield "No response received from the model."
    except Exception as e:
        # BUGFIX: chain the original exception (`from e`) so the root cause
        # (timeout, connection reset, JSON error, ...) survives in the
        # traceback instead of being flattened into a bare message.
        raise RuntimeError(f"Streaming error: {str(e)}") from e
def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    """Blocking (non-streaming) chat completion with no retries or fallback.

    Returns the assistant message content, stripped. Raises RuntimeError on
    any non-200 response.
    """
    body = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    endpoint = f"{BASE_URL}/chat/completions"
    r = requests.post(endpoint, headers=_auth_headers(), json=body, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"HTTP {r.status_code}: {r.text}")
    reply = r.json()
    return reply["choices"][0]["message"]["content"].strip()
| # ========= Lightweight Intent Detection ========= | |
def detect_intent(model_id, message: str) -> str:
    """Cheap LLM-based classifier: 'small_talk' vs 'info_query'.

    Defaults to 'info_query' on any failure or ambiguous reply, so the RAG
    path is taken when in doubt.
    """
    classification_prompt = f"Classify as 'small_talk' or 'info_query': {message}"
    try:
        reply = gradient_request(
            model_id,
            classification_prompt,
            max_tokens=8,
            temperature=0.0,
            top_p=1.0,
        )
    except Exception as e:
        print(f"⚠️ detect_intent failed: {e}")
        return "info_query"
    if "small_talk" in reply.lower():
        return "small_talk"
    return "info_query"
| # ========= App Logic (Gradio Blocks) ========= | |
with gr.Blocks(title="Gradient AI Chat") as demo:
    # Keep a reactive turn counter in session state
    turn_counter = gr.State(0)

    gr.Markdown("## Gradient AI Chat")
    gr.Markdown("Select a model and ask your question.")

    # Model dropdown will be populated at runtime with live IDs
    with gr.Row():
        model_drop = gr.Dropdown(choices=[], label="Select Model")
        system_msg = gr.Textbox(
            value="You are a faithful assistant. Use only the provided context.",
            label="System message"
        )

    with gr.Row():
        max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")
        temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")

    # Use tuples to silence deprecation warning in current Gradio
    chatbot = gr.Chatbot(height=500, type="tuples")
    msg = gr.Textbox(label="Your message")

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.ClearButton([msg, chatbot])

    examples = gr.Examples(
        examples=[
            ["What are the advantages of llama3.3-70b-instruct?"],
            ["Explain how DeepSeek R1 Distill Llama 70B handles reasoning tasks."],
            ["What is the difference between llama3.3-70b-instruct and qwen2.5-32b-instruct?"],
        ],
        inputs=[msg]
    )

    # --- Load models into dropdown at startup
    def load_models():
        # Returns a Dropdown update carrying the live model ids;
        # list_models() already falls back to a default id on failure.
        ids = list_models()
        default = ids[0] if ids else None
        return gr.Dropdown(choices=ids, value=default)

    demo.load(load_models, outputs=[model_drop])

    # Optional warm-up so first user doesn't pay cold start cost
    def warmup():
        # Best-effort: failures are logged but never block app startup.
        try:
            _ = retrieve_context("warmup", p=1, threshold=0.0)
        except Exception as e:
            print(f"⚠️ warmup failed: {e}")

    demo.load(warmup, outputs=None)

    # --- Event handlers
    def user(user_message, chat_history):
        # Seed a new assistant message for streaming; clears the input box
        # by returning "" for it.
        return "", (chat_history + [[user_message, ""]])

    def bot(chat_history, current_turn_count, model_id, system_message, max_tokens, temperature, top_p):
        # Generator handler: yields (chat_history, turn_count) after every
        # streamed token so Gradio re-renders the chat incrementally.
        user_message = chat_history[-1][0]

        # Build prompt
        intent = detect_intent(model_id, user_message)
        if intent == "small_talk":
            full_prompt = f"[System]: Friendly chat.\n[User]: {user_message}\n[Assistant]: "
        else:
            try:
                context = retrieve_context(user_message, p=5, threshold=0.5)
            except Exception as e:
                print(f"⚠️ retrieve_context failed: {e}")
                context = ""
            full_prompt = (
                f"[System]: {system_message}\n"
                "Use only the provided context. Quote verbatim; no inference.\n\n"
                f"Context:\n{context}\n\nQuestion: {user_message}\n"
            )

        # Initialize assistant message to empty string and update chat history
        chat_history[-1][1] = ""
        yield chat_history, current_turn_count

        # Attempt to stream the response
        try:
            received_any = False
            for token in gradient_stream(model_id, full_prompt, max_tokens, temperature, top_p):
                if token:  # Skip empty tokens
                    received_any = True
                    chat_history[-1][1] += token
                    yield chat_history, current_turn_count
            # If we didn't receive any tokens, fall back to non-streaming
            if not received_any:
                raise RuntimeError("Streaming returned no tokens; falling back.")
        except Exception as e:
            print(f"⚠️ Streaming failed: {e}")
            try:
                # Fall back to non-streaming
                response = gradient_complete(model_id, full_prompt, max_tokens, temperature, top_p)
                chat_history[-1][1] = response
                yield chat_history, current_turn_count
            except Exception as e2:
                chat_history[-1][1] = f"⚠️ Inference failed: {e2}"
                yield chat_history, current_turn_count
            # NOTE(review): this early return skips logging and the turn
            # counter for the fallback path — presumably intentional; confirm.
            return

        # After successful response, log and update turn counter
        try:
            log_interaction_hf(user_message, chat_history[-1][1])
        except Exception as e:
            print(f"⚠️ log_interaction_hf failed: {e}")

        new_turn_count = (current_turn_count or 0) + 1

        # Periodically upload logs
        if new_turn_count % UPLOAD_INTERVAL == 0:
            try:
                upload_log_to_hf(HF_TOKEN)
            except Exception as e:
                print(f"❌ Log upload failed: {e}")

        # Update the state with the new turn count
        yield chat_history, new_turn_count

    # Wiring (streaming generators supported)
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=True
    ).then(
        bot,
        [chatbot, turn_counter, model_drop, system_msg, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot, turn_counter],
        queue=True
    )
    submit_btn.click(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=True
    ).then(
        bot,
        [chatbot, turn_counter, model_drop, system_msg, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot, turn_counter],
        queue=True
    )
if __name__ == "__main__":
    # On HF Spaces, don't use share=True. Also disable API page to avoid schema churn.
    # Queueing is enabled implicitly via the queue=True event wiring above.
    demo.launch(show_api=False)