"""Gradio chat demo for Llama 3.2 with optional assisted (speculative) decoding.

The 3B model can be accelerated by using the 1B model as a draft model
(transformers "assisted generation"). Responses are streamed token-by-token
via TextIteratorStreamer running generation on a background thread.
"""

import importlib.util
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Model configurations for Llama 3.2
MODELS = {
    "Llama 3.2 1B": "meta-llama/Llama-3.2-1B-Instruct",
    "Llama 3.2 3B": "meta-llama/Llama-3.2-3B-Instruct",
}

# Global caches so each checkpoint is loaded at most once per process.
model_cache = {}
tokenizer_cache = {}


def _pick_attn_implementation() -> str:
    """Return the fastest attention backend that is actually usable.

    flash_attention_2 requires both a CUDA device AND the flash-attn
    package; checking only for CUDA (as before) makes model loading raise
    ImportError when flash-attn is not installed. Fall back to PyTorch SDPA.
    """
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


def load_model_and_tokenizer(model_id):
    """Load model and tokenizer for ``model_id``, caching both globally.

    Returns:
        (model, tokenizer) tuple; subsequent calls with the same id return
        the cached pair without touching the hub.
    """
    if model_id in model_cache:
        return model_cache[model_id], tokenizer_cache[model_id]

    # fp16 on GPU, fp32 on CPU (fp16 matmuls are slow/unsupported on CPU).
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        attn_implementation=_pick_attn_implementation(),
    )

    model_cache[model_id] = model
    tokenizer_cache[model_id] = tokenizer
    return model, tokenizer


def _history_to_messages(history):
    """Normalize Gradio chat history into OpenAI-style message dicts.

    Gradio's ChatInterface passes history either as (user, assistant)
    tuple pairs (legacy ``type="tuples"``) or as ``{"role", "content"}``
    dicts (``type="messages"``, the modern default). The old code only
    handled tuples and would silently unpack dict *keys* otherwise.
    """
    messages = []
    for item in history:
        if isinstance(item, dict):
            if item.get("content"):
                messages.append({"role": item["role"], "content": item["content"]})
        else:
            user_msg, assistant_msg = item
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    return messages


@spaces.GPU(duration=120)
def generate_with_assisted_decoding(
    message: str,
    history: list,
    model_choice: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    use_assisted_decoding: bool,
):
    """Stream a chat completion, optionally with speculative decoding.

    Yields the accumulated response text after each new token so Gradio can
    render a live-updating message.

    Args:
        message: Latest user message.
        history: Prior chat turns (tuple pairs or role/content dicts).
        model_choice: Key into MODELS selecting the target model.
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature; <= 0 selects greedy decoding.
        top_p: Nucleus-sampling cutoff (only used when sampling).
        use_assisted_decoding: If True and the 3B model is selected, use the
            1B model as a draft model for assisted generation.
    """
    model, tokenizer = load_model_and_tokenizer(MODELS[model_choice])

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Assisted decoding only makes sense when a smaller draft model exists
    # for the selected target — here, 1B drafting for 3B.
    assistant_model = None
    if use_assisted_decoding and model_choice == "Llama 3.2 3B":
        try:
            assistant_model, _ = load_model_and_tokenizer(MODELS["Llama 3.2 1B"])
        except Exception as e:
            print(f"[Warning] Could not load assistant model: {e}")
            assistant_model = None

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = temperature > 0.0
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # Only pass sampling params when actually sampling: temperature=0.0 with
    # greedy decoding makes transformers emit warnings / reject the value.
    if do_sample:
        generation_kwargs["temperature"] = float(temperature)
        generation_kwargs["top_p"] = float(top_p)
    if assistant_model is not None:
        generation_kwargs["assistant_model"] = assistant_model

    # generate() blocks, so run it on a worker thread and consume the
    # streamer on this one, yielding partial text to Gradio.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for text in streamer:
        full_response += text
        yield full_response

    thread.join()


def create_demo():
    """Build and return the Gradio Blocks interface."""
    with gr.Blocks(title="Llama 3.2 Inference") as demo:
        gr.Markdown(
            """
            # Llama 3.2 Inference - Optimized

            **Assisted Decoding** + **Flash Attention 2**
            - Assisted Decoding: 1B draft model accelerates generation (~1.3-1.5x faster)
            - Flash Attention 2: Faster attention (automatic on CUDA when flash-attn is installed)
            """
        )

        with gr.Row():
            with gr.Column():
                model_choice = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value="Llama 3.2 3B",
                    label="Model",
                )
            with gr.Column():
                use_assisted = gr.Checkbox(
                    value=True,
                    label="Use Assisted Decoding",
                )

        with gr.Row():
            max_tokens = gr.Slider(
                minimum=32,
                maximum=2048,
                value=512,
                step=32,
                label="Max Tokens",
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.05,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            )

        chatbot = gr.ChatInterface(
            fn=generate_with_assisted_decoding,
            additional_inputs=[
                model_choice,
                max_tokens,
                temperature,
                top_p,
                use_assisted,
            ],
            examples=[
                ["What are the top 3 programming languages in 2024?"],
                ["Write a Python function to calculate fibonacci"],
                ["Explain quantum computing in simple terms"],
            ],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()