# Uploaded by chmielvu via huggingface_hub (commit d0e7b23, verified)
import gradio as gr
import torch
import spaces
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
from threading import Thread
# Model configurations for Llama 3.2.
# Display name -> Hugging Face hub id. The 1B model also serves as the
# draft model for assisted decoding when the 3B model is selected.
MODELS = {
    "Llama 3.2 1B": "meta-llama/Llama-3.2-1B-Instruct",
    "Llama 3.2 3B": "meta-llama/Llama-3.2-3B-Instruct",
}

# Global model cache: process-wide dicts keyed by hub model id so repeat
# requests (and the assisted-decoding draft model) skip reloading weights.
model_cache = {}
tokenizer_cache = {}
def load_model_and_tokenizer(model_id):
    """Load a causal LM and its tokenizer, caching both by hub model id.

    Args:
        model_id: Hugging Face hub id (e.g. "meta-llama/Llama-3.2-3B-Instruct").

    Returns:
        Tuple of (model, tokenizer). Repeat calls with the same id return
        the cached objects instead of reloading weights.
    """
    if model_id in model_cache:
        return model_cache[model_id], tokenizer_cache[model_id]

    # fp16 halves GPU memory; CPU kernels generally need fp32.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto",
            # Flash Attention 2 requires the optional `flash-attn` package.
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
        )
    except (ImportError, ValueError) as e:
        # CUDA present but flash-attn not installed (or unsupported for this
        # model) -- fall back to PyTorch's built-in SDPA implementation
        # instead of crashing the whole Space.
        print(f"[Warning] Flash Attention 2 unavailable ({e}); falling back to SDPA.")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto",
            attn_implementation="sdpa",
        )

    model_cache[model_id] = model
    tokenizer_cache[model_id] = tokenizer
    return model, tokenizer
@spaces.GPU(duration=120)
def generate_with_assisted_decoding(
    message: str,
    history: list,
    model_choice: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    use_assisted_decoding: bool,
):
    """Stream a chat completion, optionally sped up by assisted decoding.

    Args:
        message: Latest user message.
        history: Prior turns from Gradio. Accepts both legacy
            (user, assistant) tuples and OpenAI-style {"role", "content"}
            dicts (the default for ChatInterface in Gradio 5).
        model_choice: Key into MODELS selecting the target model.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        top_p: Nucleus-sampling probability mass (only used when sampling).
        use_assisted_decoding: If True and the 3B model is selected, use the
            1B model as a draft model to speculatively accelerate generation.

    Yields:
        The accumulated response text, growing as tokens stream in.
    """
    model, tokenizer = load_model_and_tokenizer(MODELS[model_choice])

    # Rebuild the full conversation for the chat template.
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for turn in history:
        if isinstance(turn, dict):
            # Messages-format history: pass roles through unchanged.
            if turn.get("content"):
                messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Legacy tuple-format history: (user, assistant) pairs.
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Assisted (speculative) decoding: the 1B model drafts tokens that the
    # 3B model verifies. Only meaningful when the larger model is selected.
    assistant_model = None
    if use_assisted_decoding and model_choice == "Llama 3.2 3B":
        try:
            assistant_model, _ = load_model_and_tokenizer(MODELS["Llama 3.2 1B"])
        except Exception as e:
            # Best-effort: fall back to plain decoding rather than failing.
            print(f"[Warning] Could not load assistant model: {e}")
            assistant_model = None

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = temperature > 0.0
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # Only pass sampling knobs when sampling; passing them with
    # do_sample=False triggers transformers "unused flag" warnings.
    if do_sample:
        generation_kwargs["temperature"] = float(temperature)
        generation_kwargs["top_p"] = float(top_p)
    if assistant_model is not None:
        generation_kwargs["assistant_model"] = assistant_model

    # Run generation on a worker thread so we can stream tokens here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for text in streamer:
        full_response += text
        yield full_response
    thread.join()
def create_demo():
    """Assemble and return the Gradio Blocks UI for the chat demo."""
    with gr.Blocks(title="Llama 3.2 Inference") as demo:
        # Header / feature summary.
        gr.Markdown(
            """
            # Llama 3.2 Inference - Optimized
            **Assisted Decoding** + **torch.compile** + **Flash Attention 2**
            - Assisted Decoding: 1B draft model accelerates generation (~1.3-1.5x faster)
            - torch.compile: JIT compilation (20-40% speedup)
            - Flash Attention 2: Faster attention (automatic on CUDA)
            """
        )

        # Row 1: model selection and the assisted-decoding toggle.
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=list(MODELS.keys()),
                    value="Llama 3.2 3B",
                )
            with gr.Column():
                assisted_checkbox = gr.Checkbox(
                    label="Use Assisted Decoding",
                    value=True,
                )

        # Row 2: generation hyper-parameters.
        with gr.Row():
            max_tokens_slider = gr.Slider(
                label="Max Tokens",
                minimum=32,
                maximum=2048,
                value=512,
                step=32,
            )
            temperature_slider = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.05,
            )
            top_p_slider = gr.Slider(
                label="Top-p",
                minimum=0.0,
                maximum=1.0,
                value=0.95,
                step=0.05,
            )

        # Chat surface; extra inputs are forwarded to the generator in order:
        # (model_choice, max_tokens, temperature, top_p, use_assisted_decoding).
        gr.ChatInterface(
            fn=generate_with_assisted_decoding,
            additional_inputs=[
                model_dropdown,
                max_tokens_slider,
                temperature_slider,
                top_p_slider,
                assisted_checkbox,
            ],
            examples=[
                ["What are the top 3 programming languages in 2024?"],
                ["Write a Python function to calculate fibonacci"],
                ["Explain quantum computing in simple terms"],
            ],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the local web server.
    create_demo().launch()