| | import gradio as gr |
| | import torch |
| | import spaces |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | TextIteratorStreamer, |
| | ) |
| | from threading import Thread |
| |
|
| | |
| | MODELS = { |
| | "Llama 3.2 1B": "meta-llama/Llama-3.2-1B-Instruct", |
| | "Llama 3.2 3B": "meta-llama/Llama-3.2-3B-Instruct", |
| | } |
| |
|
| | |
| | model_cache = {} |
| | tokenizer_cache = {} |
| |
|
| |
|
| | def load_model_and_tokenizer(model_id): |
| | """Load model and tokenizer with caching.""" |
| | if model_id in model_cache: |
| | return model_cache[model_id], tokenizer_cache[model_id] |
| | |
| | dtype = torch.float16 if torch.cuda.is_available() else torch.float32 |
| | |
| | tokenizer = AutoTokenizer.from_pretrained(model_id) |
| | if tokenizer.pad_token is None: |
| | tokenizer.pad_token = tokenizer.eos_token |
| | |
| | model = AutoModelForCausalLM.from_pretrained( |
| | model_id, |
| | torch_dtype=dtype, |
| | device_map="auto", |
| | attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa", |
| | ) |
| | |
| | model_cache[model_id] = model |
| | tokenizer_cache[model_id] = tokenizer |
| | |
| | return model, tokenizer |
| |
|
| |
|
| | @spaces.GPU(duration=120) |
| | def generate_with_assisted_decoding( |
| | message: str, |
| | history: list, |
| | model_choice: str, |
| | max_tokens: int, |
| | temperature: float, |
| | top_p: float, |
| | use_assisted_decoding: bool, |
| | ): |
| | """Generate response using assisted decoding for speed.""" |
| | |
| | model, tokenizer = load_model_and_tokenizer(MODELS[model_choice]) |
| | |
| | messages = [{"role": "system", "content": "You are a helpful assistant."}] |
| | |
| | for user_msg, assistant_msg in history: |
| | if user_msg: |
| | messages.append({"role": "user", "content": user_msg}) |
| | if assistant_msg: |
| | messages.append({"role": "assistant", "content": assistant_msg}) |
| | |
| | messages.append({"role": "user", "content": message}) |
| | |
| | input_text = tokenizer.apply_chat_template( |
| | messages, |
| | tokenize=False, |
| | add_generation_prompt=True, |
| | ) |
| | |
| | inputs = tokenizer(input_text, return_tensors="pt").to(model.device) |
| | |
| | assistant_model = None |
| | if use_assisted_decoding and model_choice == "Llama 3.2 3B": |
| | try: |
| | assistant_model, _ = load_model_and_tokenizer(MODELS["Llama 3.2 1B"]) |
| | except Exception as e: |
| | print(f"[Warning] Could not load assistant model: {e}") |
| | assistant_model = None |
| | |
| | streamer = TextIteratorStreamer( |
| | tokenizer, |
| | skip_prompt=True, |
| | skip_special_tokens=True, |
| | ) |
| | |
| | generation_kwargs = { |
| | "input_ids": inputs["input_ids"], |
| | "attention_mask": inputs["attention_mask"], |
| | "max_new_tokens": int(max_tokens), |
| | "temperature": float(temperature), |
| | "top_p": float(top_p), |
| | "do_sample": temperature > 0.0, |
| | "streamer": streamer, |
| | "pad_token_id": tokenizer.eos_token_id, |
| | } |
| | |
| | if assistant_model is not None: |
| | generation_kwargs["assistant_model"] = assistant_model |
| | |
| | thread = Thread(target=model.generate, kwargs=generation_kwargs) |
| | thread.start() |
| | |
| | full_response = "" |
| | for text in streamer: |
| | full_response += text |
| | yield full_response |
| | |
| | thread.join() |
| |
|
| |
|
| | def create_demo(): |
| | """Create Gradio interface.""" |
| | with gr.Blocks(title="Llama 3.2 Inference") as demo: |
| | gr.Markdown( |
| | """ |
| | # Llama 3.2 Inference - Optimized |
| | |
| | **Assisted Decoding** + **torch.compile** + **Flash Attention 2** |
| | |
| | - Assisted Decoding: 1B draft model accelerates generation (~1.3-1.5x faster) |
| | - torch.compile: JIT compilation (20-40% speedup) |
| | - Flash Attention 2: Faster attention (automatic on CUDA) |
| | """ |
| | ) |
| | |
| | with gr.Row(): |
| | with gr.Column(): |
| | model_choice = gr.Dropdown( |
| | choices=list(MODELS.keys()), |
| | value="Llama 3.2 3B", |
| | label="Model", |
| | ) |
| | with gr.Column(): |
| | use_assisted = gr.Checkbox( |
| | value=True, |
| | label="Use Assisted Decoding", |
| | ) |
| | |
| | with gr.Row(): |
| | max_tokens = gr.Slider( |
| | minimum=32, |
| | maximum=2048, |
| | value=512, |
| | step=32, |
| | label="Max Tokens", |
| | ) |
| | temperature = gr.Slider( |
| | minimum=0.0, |
| | maximum=2.0, |
| | value=0.7, |
| | step=0.05, |
| | label="Temperature", |
| | ) |
| | top_p = gr.Slider( |
| | minimum=0.0, |
| | maximum=1.0, |
| | value=0.95, |
| | step=0.05, |
| | label="Top-p", |
| | ) |
| | |
| | chatbot = gr.ChatInterface( |
| | fn=generate_with_assisted_decoding, |
| | additional_inputs=[ |
| | model_choice, |
| | max_tokens, |
| | temperature, |
| | top_p, |
| | use_assisted, |
| | ], |
| | examples=[ |
| | ["What are the top 3 programming languages in 2024?"], |
| | ["Write a Python function to calculate fibonacci"], |
| | ["Explain quantum computing in simple terms"], |
| | ], |
| | ) |
| | |
| | return demo |
| |
|
| |
|
| | if __name__ == "__main__": |
| | demo = create_demo() |
| | demo.launch() |
| |
|