model-gateway-api / models /gemma4_e2b.py
autumnssuns's picture
✨ Implement lazy loading for models and correct tokens counting
e4b3020
from typing import Any
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer
from . import config, Model
from .lazy_model import LazyModel
# Identifier passed to ``from_pretrained`` for both the processor and the model.
MODEL_ID = Model.GEMMA_4_E2B.model_id
# Lifecycle helper; its ``load()``, ``unload()`` and ``entry()`` decorators are
# applied to the functions below.
# NOTE(review): when each hook actually fires is decided inside
# ``lazy_model.LazyModel``, which is not visible from this module.
lazy = LazyModel(MODEL_ID)
# Module-level handles. They start as ``None`` and are assigned by ``load()``;
# ``generate()`` checks both against ``None`` before using them.
processor = None
model = None
@lazy.unload()
def clean_up():
    """Release this module's references to the processor and the model.

    The globals are rebound to ``None`` instead of being removed with
    ``del``. Deleting a global unbinds the *name*, so a second call to this
    function, or any read of ``processor``/``model`` before the next
    ``load()``, would raise ``NameError``. Resetting to ``None`` restores the
    state the module starts in and keeps the ``is not None`` guards in
    ``generate()`` meaningful, while still dropping the references so the
    objects can be reclaimed.
    """
    global processor, model
    processor = None
    model = None
@lazy.load()
def load():
    """Instantiate the processor and the model and publish them as module globals.

    Both objects are built from ``MODEL_ID``; the keyword arguments come from
    ``config.tokenizer_config`` and ``config.model_config`` respectively.
    """
    global processor, model

    # Processor first, then the model weights.
    processor_kwargs = config.tokenizer_config
    processor = AutoProcessor.from_pretrained(MODEL_ID, **processor_kwargs)

    weights_kwargs = config.model_config
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **weights_kwargs)
def _truncate_at_stop(text: str, stop: list[str] | None) -> str:
    """Return *text* cut just before the earliest occurrence of any stop string."""
    if not stop:
        return text
    cut = len(text)
    for marker in stop:
        if not marker:
            # str.find("") returns 0 and would wipe the whole text.
            continue
        position = text.find(marker)
        if position != -1:
            cut = min(cut, position)
    return text[:cut]


@lazy.entry()
def generate(
    messages: list[dict[str, str]],
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: list[str] | None = None,
) -> dict[str, Any]:
    """Run one chat completion against the loaded model.

    Args:
        messages: Chat history handed to the processor's chat template.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature. A value of ``0`` (or below)
            switches to greedy decoding.
        top_p: Nucleus-sampling threshold; only used when sampling.
        stop: Optional strings that end generation. The returned content is
            cut just before the first one that appears.

    Returns:
        A dict with ``model``, ``content`` and a ``usage`` dict holding
        ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.

    Raises:
        AssertionError: If the processor or model has not been loaded.
    """
    assert processor is not None, "Processor is not initialized."
    assert model is not None, "Model is not loaded."

    # Render the conversation into the prompt string and tokenize it.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=text, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    # TextStreamer echoes the decoded tokens to stdout while generating.
    streamer = TextStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    do_sample = temperature > 0
    generation_kwargs: dict[str, Any] = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "streamer": streamer,
    }
    if do_sample:
        # Sampling knobs are meaningless for greedy decoding, so they are
        # passed only when sampling is actually enabled.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    stop_strings = [marker for marker in (stop or []) if marker]
    if stop_strings:
        # ``stop_strings`` requires the tokenizer to be passed alongside it.
        generation_kwargs["stop_strings"] = stop_strings
        generation_kwargs["tokenizer"] = processor.tokenizer

    with torch.inference_mode():
        outputs = model.generate(**inputs, **generation_kwargs)  # type: ignore

    # Everything after the prompt is what the model produced.
    generated_ids = outputs[0][input_len:]
    response = processor.decode(generated_ids, skip_special_tokens=False)
    content = processor.parse_response(response)
    if isinstance(content, dict) and "content" in content:
        content = content["content"]
    if isinstance(content, str):
        # Generation halts *after* emitting a stop string; drop it from the
        # text handed back to the caller.
        content = _truncate_at_stop(content, stop_strings)

    # Count the tokens that were really fed to and produced by the model,
    # instead of re-tokenizing a differently rendered prompt and the parsed
    # text (which also fails when the parsed content is not a string).
    prompt_tokens = int(input_len)
    completion_tokens = int(generated_ids.shape[-1])

    print(
        f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
    )
    print(f"Generated content: {content}")

    return {
        "model": MODEL_ID,
        "content": content,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }