"""Phi-3 Mini chatbot Space: Gradio chat UI plus a JSON API endpoint for benchmarking."""

import json
import re
import time
import traceback

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ─── Phase Configuration (update per phase) ───
PHASE = "Phase 1: Baseline (ZeroGPU)"
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
MODEL_CONFIG = {
    "phase": PHASE,
    "model_name": MODEL_NAME,
    "torch_dtype": "float16",
    "quantization": "none",
    "optimization": "none",
    "hardware": "zero-a10g",
    "max_new_tokens": 512,
    "temperature": 0.7,
}

# Matches the timing footer that chat() appends to every reply, so it can be
# removed again before the reply is fed back to the model as history.
_TIMING_FOOTER_RE = re.compile(r"\s*---\s*\*Inference: [^\n]* t/s\*\s*\Z")

# Help text shown in the "API" tab.
API_DOCS_MD = """
### API Endpoint
**Call `/gradio_api/call/api_chat`** (Gradio 5 SSE format):
```
POST /gradio_api/call/api_chat
{"data": ["your question", "[]"]}
→ returns {"event_id": "..."}

GET /gradio_api/call/api_chat/{event_id}
→ SSE stream with data: [json_result]
```
"""

# ─── Load model and tokenizer ───
print("Loading model...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
print("Model loaded successfully!", flush=True)


@spaces.GPU
def generate_response(message, history_tuples=None):
    """Core generation logic, returns response + metrics.

    Args:
        message: The new user message to answer.
        history_tuples: Optional iterable of ``(user_msg, assistant_msg)``
            pairs from earlier turns, oldest first.

    Returns:
        A dict with the decoded ``response``, ``inference_time_s``,
        ``input_tokens``, ``output_tokens``, ``tokens_per_sec`` and the
        module-level ``model_config``.
    """
    # Move model to GPU (ZeroGPU provides GPU only inside @spaces.GPU)
    model.to("cuda")

    messages = []
    for user_msg, assistant_msg in history_tuples or []:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a tensor or BatchEncoding depending on version
    if hasattr(input_ids, "input_ids"):
        input_ids = input_ids.input_ids
    input_ids = input_ids.to("cuda")
    input_tokens = input_ids.shape[1]

    # perf_counter is monotonic, so the measurement cannot be skewed by
    # wall-clock adjustments while generation is running.
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            # Single unpadded prompt: every position is a real token. Passing
            # the mask explicitly matters because pad_token_id is set to the
            # EOS id, so the mask cannot be inferred from the ids.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=MODEL_CONFIG["max_new_tokens"],
            temperature=MODEL_CONFIG["temperature"],
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    inference_time = time.perf_counter() - start_time

    # generate() returns prompt + completion; keep only the new tokens.
    output_tokens = outputs.shape[1] - input_tokens
    response = tokenizer.decode(outputs[0][input_tokens:], skip_special_tokens=True)
    tokens_per_sec = round(output_tokens / inference_time, 2) if inference_time > 0 else 0

    return {
        "response": response,
        "inference_time_s": round(inference_time, 2),
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "tokens_per_sec": tokens_per_sec,
        "model_config": MODEL_CONFIG,
    }


def parse_history(history):
    """Convert Gradio 5 history format to tuples.

    Accepts either messages-format entries (dicts with ``role``/``content``)
    or two-element ``[user, assistant]`` pairs, and may contain a mix of both.
    A user message is paired with the assistant message that immediately
    follows it, or with ``""`` when none does. Entries of any other shape,
    and assistant messages without a preceding user message, are skipped.

    Returns:
        A list of ``(user_msg, assistant_msg)`` tuples, or ``None`` when the
        history is empty or yields no pairs.
    """
    if not history:
        return None

    tuples = []
    i = 0
    while i < len(history):
        item = history[i]
        if isinstance(item, dict):
            if item.get("role") == "user":
                user_msg = item.get("content", "")
                asst_msg = ""
                if i + 1 < len(history):
                    next_item = history[i + 1]
                    if isinstance(next_item, dict) and next_item.get("role") == "assistant":
                        asst_msg = next_item.get("content", "")
                        i += 1  # the assistant entry is consumed with its user entry
                tuples.append((user_msg, asst_msg))
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            tuples.append(tuple(item))
        i += 1

    return tuples if tuples else None


def _strip_timing_footer(text):
    """Remove the timing footer added by chat() from an earlier reply, if present."""
    if isinstance(text, str):
        return _TIMING_FOOTER_RE.sub("", text)
    return text


# ─── Gradio Chat (for HF Spaces UI) ───
def chat(message, history):
    """Chat callback for ``gr.ChatInterface``.

    Returns the model reply followed by a markdown timing footer.
    """
    history_tuples = parse_history(history)
    if history_tuples:
        # Earlier replies come back through `history` with the timing footer
        # still attached; drop it so the model only sees its own text.
        history_tuples = [
            (user_msg, _strip_timing_footer(asst_msg))
            for user_msg, asst_msg in history_tuples
        ]
    result = generate_response(message, history_tuples)
    timing = f"\n\n---\n*Inference: {result['inference_time_s']}s | {result['tokens_per_sec']} t/s*"
    return result["response"] + timing


# ─── API Endpoint (for React app + benchmark) ───
def api_chat(message, history_json="[]"):
    """JSON API wrapper around generate_response.

    Args:
        message: The new user message.
        history_json: Conversation history, either as a JSON string or as an
            already-decoded list. Entries may be ``[user, assistant]`` pairs
            or ``{"role": ..., "content": ...}`` messages. Empty or missing
            history is treated as no history.

    Returns:
        A JSON string: the generate_response result on success, otherwise
        ``{"error": ..., "traceback": ...}``.
    """
    # Top-level API boundary: any failure is reported to the caller as JSON
    # instead of propagating.
    # NOTE(review): the traceback is returned to the client — convenient for
    # benchmarking, but confirm this is acceptable if the Space is public.
    try:
        if isinstance(history_json, str):
            history = json.loads(history_json) if history_json.strip() else []
        else:
            history = history_json
        history = history or []
        if not isinstance(history, (list, tuple)):
            raise ValueError("history must be a JSON array")

        # Same normalisation as the chat UI, so both pair-style and
        # messages-style histories are understood.
        history_tuples = parse_history(history)
        result = generate_response(message, history_tuples)
        return json.dumps(result)
    except Exception as e:
        return json.dumps({"error": str(e), "traceback": traceback.format_exc()})


# ─── Build Gradio App ───
with gr.Blocks() as demo:
    gr.Markdown(f"# Phi-3 Mini Chatbot ({PHASE})")
    gr.Markdown("Chat UI + API endpoint for benchmarking")

    with gr.Tab("Chat"):
        chatbot = gr.ChatInterface(fn=chat)

    with gr.Tab("API"):
        gr.Markdown(API_DOCS_MD)
        msg_input = gr.Textbox(label="Message", placeholder="Type your question...")
        # Hidden in the UI; API clients supply the history as the second item of "data".
        history_input = gr.Textbox(label="History (JSON)", value="[]", visible=False)
        api_output = gr.Textbox(label="API Response (JSON)", lines=10)
        api_btn = gr.Button("Call API")
        api_btn.click(
            fn=api_chat,
            inputs=[msg_input, history_input],
            outputs=api_output,
            api_name="api_chat",
        )

if __name__ == "__main__":
    demo.launch()