# phi3-mini-chat / app_phase2.py
# Uploaded by SahilRS with huggingface_hub (commit 90d3373, verified; 5.84 kB).
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time
import json
# ─── Phase Configuration ───
PHASE = "Phase 2b: INT4-NF4 Quantization (ZeroGPU)"
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Run metadata included in every generate_response() result; generate_response
# also reads its "max_new_tokens" and "temperature" entries as sampling settings.
# Key order is kept stable because the dict is serialized to JSON as-is.
MODEL_CONFIG = dict(
    phase=PHASE,
    model_name=MODEL_NAME,
    torch_dtype="float16",
    quantization="int4-nf4",
    optimization="bitsandbytes-nf4-double-quant",
    hardware="zero-a10g",
    max_new_tokens=512,
    temperature=0.7,
)
# ─── 4-bit NF4 Quantization Config ───
# bitsandbytes settings used when loading MODEL_NAME:
#   * weights stored in 4 bits using the "nf4" (NormalFloat4) data type,
#   * nested ("double") quantization of the quantization constants,
#   * computation carried out in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
# ─── Load model and tokenizer ───
def _load_quantized_model():
    """Load the tokenizer and the 4-bit quantized model for MODEL_NAME.

    Progress is printed (flushed) before and after loading. Weight placement is
    delegated to ``device_map="auto"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    print("Loading model with INT4-NF4 quantization...", flush=True)
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    load_kwargs = {
        "quantization_config": quantization_config,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }
    mdl = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
    print("Model loaded successfully with INT4-NF4 quantization!", flush=True)
    return tok, mdl


tokenizer, model = _load_quantized_model()
@spaces.GPU
def generate_response(message, history_tuples=None):
    """Core generation logic: return the model's reply plus timing metrics.

    Args:
        message: The new user message.
        history_tuples: Optional sequence of ``(user_msg, assistant_msg)``
            pairs for earlier turns, oldest first. A ``None`` assistant entry
            is sent to the chat template as an empty string.

    Returns:
        dict with keys ``response`` (decoded reply text), ``inference_time_s``,
        ``input_tokens``, ``output_tokens``, ``tokens_per_sec`` and
        ``model_config`` (the module-level MODEL_CONFIG).
    """
    # No model.to("cuda") needed — device_map="auto" already placed the model.
    messages = []
    for user_msg, assistant_msg in history_tuples or ():
        messages.append({"role": "user", "content": user_msg})
        messages.append({
            "role": "assistant",
            # A turn with no reply may carry None; templates that concatenate
            # content as a string would fail on it, so send "" instead.
            "content": "" if assistant_msg is None else assistant_msg,
        })
    messages.append({"role": "user", "content": message})

    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # apply_chat_template may return a tensor or a BatchEncoding depending on
    # the transformers version; keep its attention mask when it provides one.
    if hasattr(encoded, "input_ids"):
        input_ids = encoded.input_ids
        attention_mask = getattr(encoded, "attention_mask", None)
    else:
        input_ids = encoded
        attention_mask = None
    input_ids = input_ids.to(model.device)
    if attention_mask is None:
        # Single unpadded prompt: every position is a real token. Passing the
        # mask explicitly matters because pad_token_id == eos_token_id below,
        # in which case generate() cannot infer the mask from the input.
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = attention_mask.to(model.device)
    input_tokens = input_ids.shape[1]

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so the measured latency cannot be skewed by wall-clock adjustments.
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MODEL_CONFIG["max_new_tokens"],
            temperature=MODEL_CONFIG["temperature"],
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    inference_time = time.perf_counter() - start_time

    # generate() returns prompt + continuation; keep only the new tokens.
    output_tokens = outputs.shape[1] - input_tokens
    response = tokenizer.decode(outputs[0][input_tokens:], skip_special_tokens=True)
    tokens_per_sec = round(output_tokens / inference_time, 2) if inference_time > 0 else 0
    return {
        "response": response,
        "inference_time_s": round(inference_time, 2),
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "tokens_per_sec": tokens_per_sec,
        "model_config": MODEL_CONFIG,
    }
def parse_history(history):
    """Convert Gradio chat history into ``(user, assistant)`` tuples.

    Accepts either format Gradio may pass:
      * messages format: dicts with ``role``/``content``; each ``user`` entry
        is paired with an immediately following ``assistant`` entry (or ``""``
        when there is none). Other unpaired dicts are skipped.
      * tuples format: 2-element lists/tuples ``[user, assistant]``.
    Anything else is ignored. ``None`` contents become ``""`` so the result can
    be fed straight into a chat template.

    Returns:
        A list of ``(user_msg, assistant_msg)`` tuples, or ``None`` when
        ``history`` is empty or yields no usable turns.
    """
    if not history:
        return None

    def _content(value):
        # Gradio represents a turn with no reply as None; templates need a str.
        return "" if value is None else value

    tuples = []
    i = 0
    while i < len(history):
        item = history[i]
        if isinstance(item, dict):
            if item.get("role") == "user":
                user_msg = _content(item.get("content", ""))
                asst_msg = ""
                nxt = history[i + 1] if i + 1 < len(history) else None
                if isinstance(nxt, dict) and nxt.get("role") == "assistant":
                    asst_msg = _content(nxt.get("content", ""))
                    i += 1  # the paired assistant message is consumed here
                tuples.append((user_msg, asst_msg))
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, asst_msg = item
            tuples.append((_content(user_msg), _content(asst_msg)))
        i += 1
    return tuples or None
# ─── Gradio Chat (for HF Spaces UI) ───
def chat(message, history):
    """ChatInterface callback: answer ``message`` in the context of ``history``.

    The reply text is followed by a markdown footer showing the latency and
    throughput reported by generate_response.
    """
    result = generate_response(message, parse_history(history))
    footer = (
        "\n\n---\n"
        f"*Inference: {result['inference_time_s']}s"
        f" | {result['tokens_per_sec']} t/s"
        " | INT4-NF4 quantized*"
    )
    return result["response"] + footer
# ─── API Endpoint (for React app + benchmark) ───
def api_chat(message, history_json="[]"):
    """JSON API wrapper around generate_response.

    Args:
        message: The new user message.
        history_json: Prior turns as a JSON string, or an already-decoded
            value. Accepts the same shapes as the chat UI (see parse_history):
            ``[[user, assistant], ...]`` pairs or ``{"role", "content"}``
            message dicts. ``None``, empty or whitespace-only input means
            "no history".

    Returns:
        A JSON string: generate_response's result dict on success, otherwise
        ``{"error": <message>, "traceback": <formatted traceback>}``.
    """
    try:
        # Check the type before calling .strip(): non-string payloads used to
        # crash here with AttributeError before reaching the non-str branch.
        if isinstance(history_json, str):
            history = json.loads(history_json) if history_json.strip() else None
        else:
            history = history_json
        if history and not isinstance(history, (list, tuple)):
            raise ValueError("history must be a JSON array")
        # Reuse the UI parser so message dicts are paired by role instead of
        # being turned into ("role", "content") key tuples by tuple(dict).
        history_tuples = parse_history(history)
        result = generate_response(message, history_tuples)
        return json.dumps(result)
    except Exception as e:
        # Deliberate API boundary: report every failure as JSON to the caller.
        import traceback
        return json.dumps({"error": str(e), "traceback": traceback.format_exc()})
# ─── Build Gradio App ───
with gr.Blocks() as demo:
    gr.Markdown(f"# Phi-3 Mini Chatbot ({PHASE})")
    gr.Markdown("Chat UI + API endpoint for benchmarking | INT4-NF4 quantized with bitsandbytes")
    with gr.Tab("Chat"):
        chatbot = gr.ChatInterface(fn=chat)
    with gr.Tab("API"):
        # Usage notes for the endpoint registered below via api_name="api_chat".
        # The arrows were previously mojibake ("β†’") from a mis-decoded "→".
        gr.Markdown("""
### API Endpoint
**Call `/gradio_api/call/api_chat`** (Gradio 5 SSE format):
```
POST /gradio_api/call/api_chat
{"data": ["your question", "[]"]}
→ returns {"event_id": "..."}
GET /gradio_api/call/api_chat/{event_id}
→ SSE stream with data: [json_result]
```
""")
        msg_input = gr.Textbox(label="Message", placeholder="Type your question...")
        # Hidden: the UI button always sends an empty history; API clients
        # supply their own value for this input.
        history_input = gr.Textbox(label="History (JSON)", value="[]", visible=False)
        api_output = gr.Textbox(label="API Response (JSON)", lines=10)
        api_btn = gr.Button("Call API")
        api_btn.click(
            fn=api_chat,
            inputs=[msg_input, history_input],
            outputs=api_output,
            api_name="api_chat",
        )

if __name__ == "__main__":
    demo.launch()