# v1-chat-1 / app.py — Hugging Face Space (author: AB498, commit 592332a)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria
import torch
import torch.nn as nn
import threading
from optimum.onnxruntime import ORTModelForCausalLM
# CPU-optimized model loading.
# Alternative checkpoints tried during development (kept for reference):
# model_name = "microsoft/phi-2"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# model_name = "microsoft/phi-1_5"
# model_name = "distilgpt2"
# model_name = "cerebras/Cerebras-GPT-111M"

print("Loading model optimized for CPU...")

# trust_remote_code allows custom modeling/tokenizer code shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Earlier ONNX Runtime export path (currently disabled):
# model = ORTModelForCausalLM.from_pretrained(
# model_name,
# export=True # Converts to ONNX format
# )

# CPU Optimization: Dynamic quantization (2x faster inference).
# Load in full fp32 first; quantization is applied afterwards.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True  # Reduces memory footprint during loading
)

# Apply dynamic quantization to Linear layers (significant CPU speedup):
# weights are stored as int8, activations are quantized on the fly at runtime.
model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # Quantize all Linear layers
    dtype=torch.qint8   # 8-bit integer quantization
)
print("✓ Model loaded with dynamic quantization for CPU")
class StopIfFlagSet(StoppingCriteria):
    """Stopping criterion driven by a shared ``threading.Event``.

    Passing an instance via ``stopping_criteria`` lets another thread abort
    an in-flight ``model.generate()`` call simply by setting the event
    (used here when a streaming client disconnects).
    """

    def __init__(self, stop_event):
        # Shared flag: set() from the consumer side means "abort generation".
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        # generate() polls this after every decoding step;
        # returning True ends generation immediately.
        return bool(self.stop_event.is_set())
def generate_code_stream(prompt, max_new_tokens=100, temperature=0.7, use_sampling=True):
    """Stream code completions token-by-token.

    Runs ``model.generate()`` on a background thread and consumes a
    ``TextIteratorStreamer``, yielding one dict per decoded chunk:
    ``{"prompt": ..., "partial_completion": ..., "done": bool}``.

    Args:
        prompt: Text prompt to complete.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (only used when use_sampling=True).
        use_sampling: If True use top-p/top-k sampling; otherwise greedy.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    stop_event = threading.Event()

    # CPU-optimized generation kwargs
    gen_kwargs = {
        # BUGFIX: the tokenizer produces an attention mask but it was never
        # forwarded to generate(); pass it explicitly so generate() does not
        # have to infer it (avoids warnings and wrong masking when pad==eos).
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "stopping_criteria": [StopIfFlagSet(stop_event)],
        "use_cache": True,  # Enable KV cache (faster)
        "num_beams": 1      # Beam search is slow on CPU
    }

    # Add sampling parameters only if sampling is enabled
    if use_sampling:
        gen_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_p": 0.95,
            "top_k": 50  # Limits sampling pool for speed
        })
    else:
        # Greedy decoding (faster)
        gen_kwargs["do_sample"] = False

    # Start generation in a separate thread so we can consume the streamer here.
    thread = threading.Thread(
        target=model.generate,
        args=(inputs["input_ids"],),
        kwargs=gen_kwargs
    )
    thread.start()

    decoded_so_far = ""
    try:
        for token in streamer:
            # Stop if the client disconnected
            if stop_event.is_set():
                break
            decoded_so_far += token
            yield {"prompt": prompt, "partial_completion": decoded_so_far, "done": False}
    finally:
        # Signal the generation thread to stop and wait for it — this also
        # runs when the consumer closes this generator early (GeneratorExit).
        stop_event.set()
        thread.join()
    # BUGFIX: final yield moved OUT of the finally block — yielding inside
    # finally while the generator is being closed raises
    # "RuntimeError: generator ignored GeneratorExit".
    yield {"prompt": prompt, "partial_completion": decoded_so_far, "done": True}
def chat_stream(message, history, temperature=0.7, max_new_tokens=200):
    """Stream a chat reply as ``{"response": str, "done": bool}`` dicts.

    Builds a ChatML-style prompt from *history* plus the new *message*,
    generates on a background thread, and yields growing partial responses
    followed by one final "done" payload.

    Args:
        message: Latest user message.
        history: Prior turns as (user_msg, bot_msg) pairs.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on generated tokens for the reply.
    """
    try:
        # Build ChatML-style prompt (the chat checkpoint expects this format)
        def build_prompt(history, message):
            prompt = "<|system|>\nYou are a helpful assistant.\n"
            for user_msg, bot_msg in history:
                prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{bot_msg}\n"
            prompt += f"<|user|>\n{message}\n<|assistant|>\n"
            return prompt

        conversation = build_prompt(history, message)

        # Tokenize input
        inputs = tokenizer(conversation, return_tensors="pt", return_attention_mask=True)

        # Setup streaming
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
        stop_event = threading.Event()

        # Generation parameters
        gen_kwargs = {
            # BUGFIX: return_attention_mask=True was requested above but the
            # mask was never forwarded to generate(); pass it explicitly.
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
            "stopping_criteria": [StopIfFlagSet(stop_event)],
            "use_cache": True,
            "do_sample": True,
            "temperature": temperature,
            "top_p": 0.95,
            "num_beams": 1,
            # IMPORTANT: prevent model from continuing conversation
            "stop_strings": ["<|user|>", "<|system|>", "Human:"],
            "tokenizer": tokenizer  # REQUIRED by generate() when stop_strings is set
        }

        # Start generation in background
        thread = threading.Thread(
            target=model.generate,
            args=(inputs["input_ids"],),
            kwargs=gen_kwargs
        )
        thread.start()

        full_response = ""
        try:
            for token in streamer:
                if stop_event.is_set():
                    break
                full_response += token
                # The streamer replays the prompt as well (skip_prompt defaults
                # to False), so keep only text after the last "<|assistant|>".
                if "<|assistant|>" in full_response:
                    response = full_response.split("<|assistant|>")[-1].strip()
                else:
                    response = ""
                yield {"response": response, "done": False}
        finally:
            # Always stop the generation thread, even when the consumer
            # closes this generator early (GeneratorExit).
            stop_event.set()
            thread.join()

        # BUGFIX: final cleanup + yield moved OUT of the finally block —
        # yielding inside finally while the generator is being closed raises
        # "RuntimeError: generator ignored GeneratorExit".
        if "<|assistant|>" in full_response:
            response = full_response.split("<|assistant|>")[-1]
        else:
            response = full_response
        # Trim at first sign of model starting a new turn
        for marker in ["<|user|>", "<|system|>", "Human:"]:
            if marker in response:
                response = response.split(marker)[0]
        response = response.strip()
        if not response:
            response = "Hello! How can I assist you today?"
        yield {"response": response, "done": True}
    except Exception as e:
        # Surface any failure to the UI instead of crashing the stream.
        yield {"response": f"Error: {str(e)}", "done": True}
# --- Gradio App ---
with gr.Blocks(title="Optimized Phi-2 Streaming Code Generator") as demo:
    # NOTE(review): the title and header below still say "Phi-2 (2.7B)" but
    # model_name above is TinyLlama-1.1B-Chat-v1.0 — UI labels look stale;
    # confirm which model is actually intended.
    gr.Markdown(
        """
        # Phi-2 (2.7B) - CPU-Optimized Streaming Code Generator
        **CPU optimizations enabled:**
        - Dynamic INT8 quantization (~2x faster)
        - KV cache enabled
        - Low memory usage mode
        - Optimized generation parameters
        **For even better CPU performance, consider:**
        - Using a smaller model (Phi-1.5, TinyLlama)
        - ONNX Runtime conversion
        - Intel OpenVINO toolkit
        """
    )
    gr.Markdown("**Running on:** CPU (Quantized)")

    # --- Tab 1: raw code completion; streams JSON dicts from generate_code_stream ---
    with gr.Tab("Code Generation"):
        with gr.Row():
            with gr.Column():
                # Prompt and generation controls
                code_input = gr.Textbox(
                    label="Code Prompt",
                    placeholder="Enter your code prompt...",
                    lines=5,
                    value="def fibonacci(n):"
                )
                max_new_tokens_slider = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=30,  # Lower default for CPU
                    step=1,
                    label="Max New Tokens"
                )
                with gr.Row():
                    use_sampling_checkbox = gr.Checkbox(
                        label="Use Sampling (slower but more creative)",
                        value=False  # Default to greedy for speed
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        label="Temperature (only if sampling enabled)"
                    )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                # Raw streaming output rendered as JSON
                # ({"prompt", "partial_completion", "done"} per yield).
                output = gr.JSON(label="Streaming Output")
                gr.Markdown(
                    """
                    ### CPU Speed Tips:
                    - **Fastest:** Disable sampling, use 20-30 tokens
                    - **Balanced:** Disable sampling, use 30-50 tokens
                    - **For better speed:** Consider switching to Phi-1.5 (1.3B) or TinyLlama (1.1B)
                    ### Current optimizations:
                    - ✓ INT8 dynamic quantization (~2x speedup)
                    - ✓ KV caching enabled
                    - ✓ Greedy decoding by default
                    """
                )
        # generate_code_stream is a generator, so Gradio streams each yield.
        # Input order matches the (prompt, max_new_tokens, temperature,
        # use_sampling) signature.
        generate_btn.click(
            fn=generate_code_stream,
            inputs=[code_input, max_new_tokens_slider, temperature_slider, use_sampling_checkbox],
            outputs=output
        )

    # --- Tab 2: chat UI backed by chat_stream ---
    with gr.Tab("Chat"):
        gr.Markdown("## Chat with the Model")
        chatbot = gr.Chatbot(label="Conversation")
        msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
        chat_temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.7,
            step=0.1,
            label="Temperature"
        )
        chat_max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=200,
            step=10,
            label="Max Tokens"
        )
        clear = gr.Button("Clear History")

        def respond_stream(message, chat_history, temp, max_tokens):
            # Add a placeholder [user, ""] pair, then fill the bot half in
            # place as chunks stream in from chat_stream().
            chat_history.append([message, ""])
            # Stream the response; pass history WITHOUT the placeholder pair
            # so the prompt does not include the empty assistant turn.
            for chunk in chat_stream(message, chat_history[:-1], temp, max_tokens):
                if chunk["done"]:
                    # Final response
                    chat_history[-1][1] = chunk["response"]
                    yield "", chat_history  # "" clears the input textbox
                    break
                else:
                    # Partial response
                    chat_history[-1][1] = chunk["response"]
                    yield "", chat_history

        msg.submit(respond_stream, [msg, chatbot, chat_temperature, chat_max_tokens], [msg, chatbot])
        # Reset the Chatbot component by setting its value to None.
        clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()