# policy-analysis / app_reserve.py
# Author: kaburia — "updated app" (commit 5d99375)
import os
import uuid
import time
import json
import requests
import gradio as gr
import time
import utils.helpers as helpers
from utils.helpers import retrieve_context, log_interaction_hf, upload_log_to_hf
# ========= Config & Globals =========
# Secrets live in config.json (kept out of the repo); read it once at startup.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

DO_API_KEY = config["do_token"]   # DigitalOcean Gradient inference API key
token_ = config['token']          # HF token body, stored without its prefix
HF_TOKEN = 'hf_' + token_         # reassembled Hugging Face token

# Unique id for this app process; helpers tag log rows with it.
session_id = f"{int(time.time())}-{uuid.uuid4().hex[:8]}"
helpers.session_id = session_id

BASE_URL = "https://inference.do-ai.run/v1"  # DO serverless inference endpoint
UPLOAD_INTERVAL = 5                          # upload logs every N chat turns
# ========= Inference Utilities =========
def _auth_headers():
    """Build the HTTP headers for every call to the Gradient API."""
    headers = {"Content-Type": "application/json"}
    headers["Authorization"] = f"Bearer {DO_API_KEY}"
    return headers
def list_models():
    """Return the live model IDs from the inference API.

    Falls back to a single known-good model ID when the request fails
    or the API reports no models.
    """
    fallback = ["llama3.3-70b-instruct"]
    try:
        response = requests.get(f"{BASE_URL}/models", headers=_auth_headers(), timeout=15)
        response.raise_for_status()
        model_ids = [entry["id"] for entry in response.json().get("data", [])]
    except Exception as exc:
        print(f"⚠️ list_models failed: {exc}")
        return fallback
    return model_ids if model_ids else fallback
def gradient_request(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    """Single (non-streaming) chat completion with up to 3 attempts.

    Args:
        model_id: Target model; if falsy, the first live model is used.
        prompt: User message, sent as a one-turn chat.
        max_tokens, temperature, top_p: Standard sampling controls.

    Returns:
        The assistant's reply text, stripped.

    Raises:
        RuntimeError: on HTTP errors (with status and body) or exhausted retries.
        requests.RequestException: if the final transport attempt still fails.
    """
    url = f"{BASE_URL}/chat/completions"
    if not model_id:
        model_id = list_models()[0]
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    for attempt in range(3):
        try:
            resp = requests.post(url, headers=_auth_headers(), json=payload, timeout=30)
            # A 404 usually means the model was retired; swap in a live one
            # and retry. On the last attempt, fall through so raise_for_status
            # surfaces the real HTTP error instead of a vague retry message.
            if resp.status_code == 404 and attempt < 2:
                ids = list_models()
                if payload["model"] not in ids and ids:
                    payload["model"] = ids[0]
                continue
            resp.raise_for_status()
            j = resp.json()
            return j["choices"][0]["message"]["content"].strip()
        except requests.HTTPError as e:
            # Guard e.response: it can be None for synthesized HTTPErrors.
            status = e.response.status_code if e.response is not None else "?"
            msg = getattr(e.response, "text", str(e))
            raise RuntimeError(f"Inference error ({status}): {msg}") from e
        except requests.RequestException as e:
            if attempt == 2:
                raise
            time.sleep(2 ** attempt)  # brief backoff before retrying transport errors
    raise RuntimeError("Exhausted retries")
def gradient_stream(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    """Stream a chat completion, yielding content tokens as they arrive.

    Parses the SSE ("data: ...") lines emitted by the OpenAI-compatible
    endpoint. Yields each delta string as produced; if the stream finishes
    without any content, yields a single placeholder message instead.

    Raises:
        RuntimeError: on non-200 HTTP responses or any transport failure.
    """
    url = f"{BASE_URL}/chat/completions"
    if not model_id:
        model_id = list_models()[0]
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }
    # Create a generator that yields tokens
    try:
        with requests.post(url, headers=_auth_headers(), json=payload, stream=True, timeout=120) as r:
            if r.status_code != 200:
                try:
                    err_txt = r.text
                except Exception:
                    err_txt = "<no body>"
                raise RuntimeError(f"HTTP {r.status_code}: {err_txt}")
            buffer = ""
            for line in r.iter_lines():
                if not line:
                    continue  # skip SSE keep-alive blank lines
                decoded_line = line.decode('utf-8')
                if not decoded_line.startswith('data:'):
                    continue
                data = decoded_line[5:].strip()
                if data == '[DONE]':
                    break  # server signals end of stream
                try:
                    json_data = json.loads(data)
                except json.JSONDecodeError:
                    continue  # ignore malformed frames
                for choice in json_data.get('choices', []):
                    # delta.content may be absent or null (e.g. role-only frames).
                    content = choice.get('delta', {}).get('content')
                    if content is not None:
                        buffer += content
                        yield content
            if not buffer:
                yield "No response received from the model."
    except RuntimeError:
        raise  # already descriptive; don't double-wrap our own error
    except Exception as e:
        raise RuntimeError(f"Streaming error: {str(e)}") from e
def gradient_complete(model_id, prompt, max_tokens=512, temperature=0.7, top_p=0.95):
    """Blocking (non-streaming) chat completion; returns the reply text.

    Raises RuntimeError with the status code and body on any non-200 response.
    """
    body = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    response = requests.post(
        f"{BASE_URL}/chat/completions",
        headers=_auth_headers(),
        json=body,
        timeout=60,
    )
    if response.status_code != 200:
        raise RuntimeError(f"HTTP {response.status_code}: {response.text}")
    parsed = response.json()
    return parsed["choices"][0]["message"]["content"].strip()
# ========= Lightweight Intent Detection =========
def detect_intent(model_id, message: str) -> str:
    """Cheap LLM classification of a user message.

    Returns 'small_talk' or 'info_query'; defaults to 'info_query'
    whenever classification fails or the label is ambiguous.
    """
    try:
        label = gradient_request(
            model_id,
            f"Classify as 'small_talk' or 'info_query': {message}",
            max_tokens=8,
            temperature=0.0,
            top_p=1.0,
        )
    except Exception as exc:
        print(f"⚠️ detect_intent failed: {exc}")
        return "info_query"
    return "small_talk" if "small_talk" in label.lower() else "info_query"
# ========= App Logic (Gradio Blocks) =========
# ========= App Logic (Gradio Blocks) =========
# UI layout, event handlers, and wiring. `bot` is a generator handler:
# Gradio re-renders the chatbot on every yield, which is how streaming works.
with gr.Blocks(title="Gradient AI Chat") as demo:
    # Keep a reactive turn counter in session state
    turn_counter = gr.State(0)
    gr.Markdown("## Gradient AI Chat")
    gr.Markdown("Select a model and ask your question.")
    # Model dropdown will be populated at runtime with live IDs
    with gr.Row():
        model_drop = gr.Dropdown(choices=[], label="Select Model")
        system_msg = gr.Textbox(
            value="You are a faithful assistant. Use only the provided context.",
            label="System message"
        )
    # Sampling controls exposed to the user
    with gr.Row():
        max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")
        temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
    # Use tuples to silence deprecation warning in current Gradio
    # NOTE(review): newer Gradio deprecates type="tuples" in favor of
    # type="messages" — confirm against the pinned Gradio version.
    chatbot = gr.Chatbot(height=500, type="tuples")
    msg = gr.Textbox(label="Your message")
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.ClearButton([msg, chatbot])
    # Canned example prompts shown under the input box
    examples = gr.Examples(
        examples=[
            ["What are the advantages of llama3.3-70b-instruct?"],
            ["Explain how DeepSeek R1 Distill Llama 70B handles reasoning tasks."],
            ["What is the difference between llama3.3-70b-instruct and qwen2.5-32b-instruct?"],
        ],
        inputs=[msg]
    )

    # --- Load models into dropdown at startup
    def load_models():
        # Fetch live model IDs and pick the first as the default selection.
        ids = list_models()
        default = ids[0] if ids else None
        return gr.Dropdown(choices=ids, value=default)
    demo.load(load_models, outputs=[model_drop])

    # Optional warm-up so first user doesn't pay cold start cost
    def warmup():
        try:
            _ = retrieve_context("warmup", p=1, threshold=0.0)
        except Exception as e:
            # Best-effort: a failed warm-up should never block the UI.
            print(f"⚠️ warmup failed: {e}")
    demo.load(warmup, outputs=None)

    # --- Event handlers
    def user(user_message, chat_history):
        # Seed a new assistant message for streaming
        # (clears the textbox and appends a [user, ""] pair for `bot` to fill).
        return "", (chat_history + [[user_message, ""]])

    def bot(chat_history, current_turn_count, model_id, system_message, max_tokens, temperature, top_p):
        # Generator: streams the reply into the last history entry, falling
        # back to a blocking request when streaming fails or yields nothing.
        user_message = chat_history[-1][0]
        # Build prompt
        intent = detect_intent(model_id, user_message)
        if intent == "small_talk":
            full_prompt = f"[System]: Friendly chat.\n[User]: {user_message}\n[Assistant]: "
        else:
            try:
                context = retrieve_context(user_message, p=5, threshold=0.5)
            except Exception as e:
                # Degrade gracefully to an empty context rather than failing the turn.
                print(f"⚠️ retrieve_context failed: {e}")
                context = ""
            full_prompt = (
                f"[System]: {system_message}\n"
                "Use only the provided context. Quote verbatim; no inference.\n\n"
                f"Context:\n{context}\n\nQuestion: {user_message}\n"
            )
        # Initialize assistant message to empty string and update chat history
        chat_history[-1][1] = ""
        yield chat_history, current_turn_count
        # Attempt to stream the response
        try:
            received_any = False
            for token in gradient_stream(model_id, full_prompt, max_tokens, temperature, top_p):
                if token:  # Skip empty tokens
                    received_any = True
                    chat_history[-1][1] += token
                    yield chat_history, current_turn_count
            # If we didn't receive any tokens, fall back to non-streaming
            if not received_any:
                raise RuntimeError("Streaming returned no tokens; falling back.")
        except Exception as e:
            print(f"⚠️ Streaming failed: {e}")
            try:
                # Fall back to non-streaming
                response = gradient_complete(model_id, full_prompt, max_tokens, temperature, top_p)
                chat_history[-1][1] = response
                yield chat_history, current_turn_count
            except Exception as e2:
                # Both paths failed: surface the error in the chat and stop
                # (skips logging and the turn-counter update below).
                chat_history[-1][1] = f"⚠️ Inference failed: {e2}"
                yield chat_history, current_turn_count
                return
        # After successful response, log and update turn counter
        try:
            log_interaction_hf(user_message, chat_history[-1][1])
        except Exception as e:
            print(f"⚠️ log_interaction_hf failed: {e}")
        new_turn_count = (current_turn_count or 0) + 1
        # Periodically upload logs
        if new_turn_count % UPLOAD_INTERVAL == 0:
            try:
                upload_log_to_hf(HF_TOKEN)
            except Exception as e:
                print(f"❌ Log upload failed: {e}")
        # Update the state with the new turn count
        yield chat_history, new_turn_count

    # Wiring (streaming generators supported)
    # Both Enter-in-textbox and the Submit button run the same two-step
    # chain: `user` appends the message, then `bot` streams the reply.
    msg.submit(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=True
    ).then(
        bot,
        [chatbot, turn_counter, model_drop, system_msg, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot, turn_counter],
        queue=True
    )
    submit_btn.click(
        user,
        [msg, chatbot],
        [msg, chatbot],
        queue=True
    ).then(
        bot,
        [chatbot, turn_counter, model_drop, system_msg, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot, turn_counter],
        queue=True
    )
# Entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    # On HF Spaces, don't use share=True. Also disable API page to avoid schema churn.
    demo.launch(show_api=False)