Spaces:
Sleeping
Sleeping
| from typing import Any | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer | |
| from . import config, Model | |
| from .lazy_model import LazyModel | |
| MODEL_ID = Model.GEMMA_4_E2B.model_id | |
| lazy = LazyModel(MODEL_ID) | |
| processor = None | |
| model = None | |
| def clean_up(): | |
| global processor, model | |
| del processor | |
| del model | |
| def load(): | |
| global processor, model | |
| processor = AutoProcessor.from_pretrained(MODEL_ID, **config.tokenizer_config) | |
| model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config) | |
| def generate( | |
| messages: list[dict[str, str]], | |
| max_tokens: int = 512, | |
| temperature: float = 0.7, | |
| top_p: float = 0.9, | |
| stop: list[str] | None = None, | |
| ) -> dict[str, Any]: | |
| global processor, model | |
| assert processor is not None, "Processor is not initialized." | |
| assert model is not None, "Model is not loaded." | |
| # Process input | |
| text = processor.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| enable_thinking=False, | |
| ) | |
| inputs = processor(text=text, return_tensors="pt").to(model.device) | |
| input_len = inputs["input_ids"].shape[-1] | |
| streamer = TextStreamer( | |
| processor.tokenizer, skip_prompt=True, skip_special_tokens=True | |
| ) | |
| with torch.inference_mode(): | |
| outputs = model.generate( # type: ignore | |
| **inputs, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| do_sample=temperature > 0, | |
| streamer=streamer, | |
| ) | |
| response = processor.decode(outputs[0][input_len:], skip_special_tokens=False) | |
| content = processor.parse_response(response) | |
| if isinstance(content, dict) and "content" in content: | |
| content = content["content"] | |
| prompt_tokens = len(processor.tokenizer.apply_chat_template(messages)) | |
| completion_tokens = len( | |
| processor.tokenizer.encode(content, add_special_tokens=False) | |
| ) | |
| print( | |
| f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}" | |
| ) | |
| print(f"Generated content: {content}") | |
| return { | |
| "model": MODEL_ID, | |
| "content": content, | |
| "usage": { | |
| "prompt_tokens": prompt_tokens, | |
| "completion_tokens": completion_tokens, | |
| "total_tokens": prompt_tokens + completion_tokens, | |
| }, | |
| } | |