Spaces:
Paused
Paused
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| import time | |
| import json | |
# ─── Phase Configuration ───
PHASE = "Phase 2b: INT4-NF4 Quantization (ZeroGPU)"
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Descriptive settings for this phase. generate_response() reads
# "max_new_tokens" and "temperature" from here and echoes the whole dict
# back in every result under the "model_config" key.
MODEL_CONFIG = dict(
    phase=PHASE,
    model_name=MODEL_NAME,
    torch_dtype="float16",
    quantization="int4-nf4",
    optimization="bitsandbytes-nf4-double-quant",
    hardware="zero-a10g",
    max_new_tokens=512,
    temperature=0.7,
)
# ─── 4-bit NF4 Quantization Config ───
# bitsandbytes settings handed to from_pretrained() when the model is loaded:
# 4-bit loading, "nf4" quant type, double quantization on, float16 compute.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)
# ─── Load model and tokenizer ───
# Runs once at import time; progress is printed unbuffered so it shows up
# in the logs immediately.
print("Loading model with INT4-NF4 quantization...", flush=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# device_map="auto" delegates device placement to the loader, so no explicit
# .to(...) call is made on the model anywhere in this file.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)

print("Model loaded successfully with INT4-NF4 quantization!", flush=True)
def generate_response(message, history_tuples=None):
    """Generate one assistant reply and report timing/throughput metrics.

    Args:
        message: The new user message to answer.
        history_tuples: Optional iterable of ``(user_msg, assistant_msg)``
            pairs for the earlier turns, oldest first. ``None`` or empty
            means a fresh conversation.

    Returns:
        A dict with the keys ``response`` (decoded reply text),
        ``inference_time_s`` (seconds spent in ``model.generate``, rounded
        to 2 places), ``input_tokens``, ``output_tokens``,
        ``tokens_per_sec`` and ``model_config`` (the module-level
        ``MODEL_CONFIG`` dict).
    """
    # The model was loaded with device_map="auto", so placement is already
    # done; inputs are simply moved to model.device below.
    # NOTE(review): `spaces` is imported at the top of the file but nothing
    # visible applies @spaces.GPU — confirm whether this function should
    # request a ZeroGPU device.
    messages = []
    for user_msg, assistant_msg in history_tuples or ():
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Ask explicitly for a dict so the return type does not depend on the
    # installed transformers version, and so the attention mask comes with it.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_ids = inputs["input_ids"]
    input_tokens = input_ids.shape[1]

    # perf_counter is monotonic; time.time() can jump if the wall clock is
    # adjusted while generation is running.
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            # pad_token_id is set to the EOS id below, so the mask cannot be
            # inferred from padding; pass it explicitly.
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MODEL_CONFIG["max_new_tokens"],
            temperature=MODEL_CONFIG["temperature"],
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    inference_time = time.perf_counter() - start_time

    # generate() returns prompt + completion; keep only the new tokens.
    output_tokens = outputs.shape[1] - input_tokens
    response = tokenizer.decode(outputs[0][input_tokens:], skip_special_tokens=True)
    tokens_per_sec = round(output_tokens / inference_time, 2) if inference_time > 0 else 0

    return {
        "response": response,
        "inference_time_s": round(inference_time, 2),
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "tokens_per_sec": tokens_per_sec,
        "model_config": MODEL_CONFIG,
    }
def parse_history(history):
    """Normalise a chat history into ``(user, assistant)`` tuples.

    Two entry shapes are understood and may be mixed:

    * message dicts with ``"role"`` / ``"content"`` keys — a ``"user"``
      entry is paired with the entry right after it when that one is an
      ``"assistant"`` dict, otherwise with ``""``; dicts with any other
      role that are not consumed this way are ignored;
    * two-item lists/tuples, which are taken as ready-made pairs.

    Anything else is skipped. Returns ``None`` when ``history`` is empty
    or nothing usable was found.
    """
    if not history:
        return None

    pairs = []
    total = len(history)
    pos = 0
    while pos < total:
        entry = history[pos]
        pos += 1

        if not isinstance(entry, dict):
            if isinstance(entry, (list, tuple)) and len(entry) == 2:
                pairs.append(tuple(entry))
            continue

        if entry.get("role") != "user":
            continue

        # Peek at the following entry; consume it only if it is the reply.
        reply = ""
        follower = history[pos] if pos < total else None
        if isinstance(follower, dict) and follower.get("role") == "assistant":
            reply = follower.get("content", "")
            pos += 1
        pairs.append((entry.get("content", ""), reply))

    return pairs or None
# ─── Gradio Chat (for HF Spaces UI) ───
def chat(message, history):
    """Chat handler for the UI: reply text followed by a timing footer.

    ``history`` is whatever the chat widget supplies; it is normalised with
    ``parse_history`` before being passed to ``generate_response``.
    """
    result = generate_response(message, parse_history(history))
    footer = (
        f"\n\n---\n*Inference: {result['inference_time_s']}s"
        f" | {result['tokens_per_sec']} t/s | INT4-NF4 quantized*"
    )
    return result["response"] + footer
# ─── API Endpoint (for React app + benchmark) ───
def api_chat(message, history_json="[]"):
    """Run one chat turn and return the outcome as a JSON string.

    Args:
        message: The new user message.
        history_json: Earlier turns, either as a JSON string or as an
            already-decoded list. Entries may be ``[user, assistant]`` pairs
            or ``{"role": ..., "content": ...}`` message dicts; both shapes
            are normalised by ``parse_history``, which skips entries it does
            not recognise. ``None``, an empty/blank string, or any falsy
            decoded value means "no history".

    Returns:
        A JSON string. On success it is the dict returned by
        ``generate_response``. On any failure it is
        ``{"error": <message>, "traceback": <formatted traceback>}`` so that
        callers always receive parseable JSON.
    """
    try:
        # Decode first, and only call .strip() on real strings, so that an
        # already-decoded list is accepted instead of raising AttributeError.
        if isinstance(history_json, str):
            history = json.loads(history_json) if history_json.strip() else []
        else:
            history = history_json

        if not history:
            history_tuples = None
        elif isinstance(history, (list, tuple)):
            # parse_history understands both pairs and role/content dicts;
            # tuple(dict) would have produced the dict's keys, not its content.
            history_tuples = parse_history(history)
        else:
            raise TypeError(
                f"history must be a list of turns, got {type(history).__name__}"
            )

        result = generate_response(message, history_tuples)
        return json.dumps(result)
    except Exception as e:
        # API boundary: report the failure to the caller rather than raising.
        import traceback
        return json.dumps({"error": str(e), "traceback": traceback.format_exc()})
# ─── Build Gradio App ───
# Two tabs: an interactive chat backed by chat(), and a small form that calls
# api_chat() and is also exposed as the named API route "api_chat".
with gr.Blocks() as demo:
    gr.Markdown(f"# Phi-3 Mini Chatbot ({PHASE})")
    gr.Markdown("Chat UI + API endpoint for benchmarking | INT4-NF4 quantized with bitsandbytes")

    with gr.Tab("Chat"):
        chatbot = gr.ChatInterface(fn=chat)

    with gr.Tab("API"):
        # Help text is kept flush-left inside the literal so its rendering
        # does not depend on how leading indentation is treated.
        gr.Markdown("""
### API Endpoint
**Call `/gradio_api/call/api_chat`** (Gradio 5 SSE format):
```
POST /gradio_api/call/api_chat
{"data": ["your question", "[]"]}
→ returns {"event_id": "..."}
GET /gradio_api/call/api_chat/{event_id}
→ SSE stream with data: [json_result]
```
""")
        msg_input = gr.Textbox(label="Message", placeholder="Type your question...")
        # Hidden field: API callers supply history as the second data item;
        # the button in the UI always sends the default "[]".
        history_input = gr.Textbox(label="History (JSON)", value="[]", visible=False)
        api_output = gr.Textbox(label="API Response (JSON)", lines=10)
        api_btn = gr.Button("Call API")
        api_btn.click(
            fn=api_chat,
            inputs=[msg_input, history_input],
            outputs=api_output,
            api_name="api_chat",
        )

if __name__ == "__main__":
    demo.launch()