import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, StoppingCriteria
import torch
import torch.nn as nn
import threading


model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

| print("Loading model optimized for CPU...") |
|
|
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) |
|
|
| |
| |
| |
| |
|
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
)

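# Dynamic INT8 quantization: nn.Linear weights are converted to int8 ahead of
# time and activations are quantized on the fly at inference time, which
# typically speeds up CPU matrix multiplies and shrinks memory use with only a
# small quality cost.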
model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8
)

print("✓ Model loaded with dynamic quantization for CPU")

class StopIfFlagSet(StoppingCriteria):
    """StoppingCriteria for Hugging Face generation to interrupt if stop_event is set."""
    def __init__(self, stop_event):
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.stop_event.is_set()

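# generate() evaluates each StoppingCriteria after every decoding step, so
# setting stop_event from another thread aborts generation promptly.
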
def generate_code_stream(prompt, max_new_tokens=100, temperature=0.7, use_sampling=True):
    """
    Generator function to stream code completions token-by-token.
    CPU-optimized with caching and efficient parameters.
    """
    inputs = tokenizer(prompt, return_tensors="pt")

    # skip_prompt=True so the streamer yields only newly generated tokens,
    # not the echoed prompt.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    stop_event = threading.Event()

    gen_kwargs = {
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "stopping_criteria": [StopIfFlagSet(stop_event)],
        "use_cache": True,
        "num_beams": 1
    }

    if use_sampling:
        gen_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_p": 0.95,
            "top_k": 50
        })
    else:
        gen_kwargs["do_sample"] = False

    thread = threading.Thread(
        target=model.generate,
        args=(inputs["input_ids"],),
        kwargs=gen_kwargs
    )
    thread.start()

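    # generate() now runs in the background; the loop below pulls decoded text
    # from the streamer queue as it is produced.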
    decoded_so_far = ""
    try:
        for token in streamer:
            if stop_event.is_set():
                break
            decoded_so_far += token
            yield {"prompt": prompt, "partial_completion": decoded_so_far, "done": False}
    finally:
        stop_event.set()
        thread.join()
        yield {"prompt": prompt, "partial_completion": decoded_so_far, "done": True}

def chat_stream(message, history, temperature=0.7, max_new_tokens=200):
    try:
        def build_prompt(history, message):
            prompt = "<|system|>\nYou are a helpful assistant.\n"
            for user_msg, bot_msg in history:
                prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{bot_msg}\n"
            prompt += f"<|user|>\n{message}\n<|assistant|>\n"
            return prompt
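        # For history=[("Hi", "Hello!")] and message="How are you?", build_prompt
        # returns:
        #   <|system|>\nYou are a helpful assistant.\n<|user|>\nHi\n<|assistant|>\n
        #   Hello!\n<|user|>\nHow are you?\n<|assistant|>\n
        # Note: TinyLlama's official chat template also appends </s> after each
        # turn; the simplified prompt here omits it.
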
        conversation = build_prompt(history, message)

        inputs = tokenizer(conversation, return_tensors="pt", return_attention_mask=True)

        # Unlike generate_code_stream, the prompt is not skipped here: the
        # streamer replays the full text, and the assistant's reply is recovered
        # by splitting on the final "<|assistant|>" marker below.
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
        stop_event = threading.Event()

        gen_kwargs = {
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
            "stopping_criteria": [StopIfFlagSet(stop_event)],
            "use_cache": True,
            "do_sample": True,
            "temperature": temperature,
            "top_p": 0.95,
            "num_beams": 1,
            "stop_strings": ["<|user|>", "<|system|>", "Human:"],
            "tokenizer": tokenizer
        }
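        # "stop_strings" (together with the "tokenizer" entry) needs a fairly
        # recent transformers release; on older versions these two keys can be
        # dropped and the marker trimming below still applies.
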
        thread = threading.Thread(
            target=model.generate,
            args=(inputs["input_ids"],),
            kwargs=gen_kwargs
        )
        thread.start()

        full_response = ""

        try:
            for token in streamer:
                if stop_event.is_set():
                    break
                full_response += token

                # The streamer replays the prompt first, so only text after the
                # last "<|assistant|>" marker is the model's reply.
                if "<|assistant|>" in full_response:
                    response = full_response.split("<|assistant|>")[-1].strip()
                else:
                    response = ""

                yield {"response": response, "done": False}

        finally:
            stop_event.set()
            thread.join()

        if "<|assistant|>" in full_response:
            response = full_response.split("<|assistant|>")[-1]
        else:
            response = full_response

        # Trim anything generated past a conversation-role marker.
        for marker in ["<|user|>", "<|system|>", "Human:"]:
            if marker in response:
                response = response.split(marker)[0]

        response = response.strip()

        if not response:
            response = "Hello! How can I assist you today?"

        yield {"response": response, "done": True}

    except Exception as e:
        yield {"response": f"Error: {str(e)}", "done": True}


with gr.Blocks(title="Optimized TinyLlama Streaming Code Generator") as demo:
    gr.Markdown(
        """
        # TinyLlama (1.1B Chat) - CPU-Optimized Streaming Code Generator

        **CPU optimizations enabled:**
        - Dynamic INT8 quantization (~2x faster)
        - KV cache enabled
        - Low memory usage mode
        - Optimized generation parameters

        **For even better CPU performance, consider:**
        - An even smaller or distilled model
        - ONNX Runtime conversion
        - Intel OpenVINO toolkit
        """
    )

    gr.Markdown("**Running on:** CPU (Quantized)")

    with gr.Tab("Code Generation"):
        with gr.Row():
            with gr.Column():
                code_input = gr.Textbox(
                    label="Code Prompt",
                    placeholder="Enter your code prompt...",
                    lines=5,
                    value="def fibonacci(n):"
                )
                max_new_tokens_slider = gr.Slider(
                    minimum=1,
                    maximum=500,
                    value=30,
                    step=1,
                    label="Max New Tokens"
                )

                with gr.Row():
                    use_sampling_checkbox = gr.Checkbox(
                        label="Use Sampling (slower but more creative)",
                        value=False
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        label="Temperature (only if sampling enabled)"
                    )

                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column():
                output = gr.JSON(label="Streaming Output")

                gr.Markdown(
                    """
                    ### CPU Speed Tips:
                    - **Fastest:** Disable sampling, use 20-30 tokens
                    - **Balanced:** Disable sampling, use 30-50 tokens
                    - **For better speed:** Keep prompts short and lower the token limit

                    ### Current optimizations:
                    - ✓ INT8 dynamic quantization (~2x speedup)
                    - ✓ KV caching enabled
                    - ✓ Greedy decoding by default
                    """
                )

    generate_btn.click(
        fn=generate_code_stream,
        inputs=[code_input, max_new_tokens_slider, temperature_slider, use_sampling_checkbox],
        outputs=output
    )
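    # generate_code_stream is a generator, so Gradio streams each yielded dict
    # to the JSON output instead of waiting for generation to finish.
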
    with gr.Tab("Chat"):
        gr.Markdown("## Chat with the Model")
        chatbot = gr.Chatbot(label="Conversation")
        msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
        chat_temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.7,
            step=0.1,
            label="Temperature"
        )
        chat_max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=200,
            step=10,
            label="Max Tokens"
        )
        clear = gr.Button("Clear History")

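        # Chat handler: append a placeholder turn, then rewrite its bot side
        # with every streamed chunk so the Chatbot updates in place.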
        def respond_stream(message, chat_history, temp, max_tokens):
            chat_history.append([message, ""])

            for chunk in chat_stream(message, chat_history[:-1], temp, max_tokens):
                if chunk["done"]:
                    chat_history[-1][1] = chunk["response"]
                    yield "", chat_history
                    break
                else:
                    chat_history[-1][1] = chunk["response"]
                    yield "", chat_history

        msg.submit(respond_stream, [msg, chatbot, chat_temperature, chat_max_tokens], [msg, chatbot])
        clear.click(lambda: None, None, chatbot, queue=False)

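# Note: on older Gradio releases, generator-based streaming may require
# enabling the queue explicitly (demo.queue() before launching).
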
if __name__ == "__main__":
    demo.launch()