Spaces:
Running on T4
Running on T4
| """ | |
| Shared local Llama model for Layer 2 inference. | |
| Loads Llama 3.1 8B Instruct once via transformers and provides a | |
| chat_completion-like interface used by both CustomerSimulator and HFAgent, | |
| eliminating HF Inference API dependency and its transient errors. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any | |
logger = logging.getLogger(__name__)

# Process-wide cache of the loaded model; filled lazily by get_shared_model().
_shared_instance: LocalLlamaModel | None = None


def get_shared_model(
    model_id: str = "unsloth/Meta-Llama-3.1-8B-Instruct",
    hf_token: str | None = None,
    device: str = "auto",
) -> LocalLlamaModel:
    """Get or create the singleton local model instance.

    The first call constructs a ``LocalLlamaModel`` from the given arguments
    and caches it in this module. Every later call returns that same cached
    instance and ignores its own arguments. If a later call asks for a
    different ``model_id``, a warning is logged so the mismatch is not silent.

    Args:
        model_id: Hub repo id (or path) used only when no model is loaded yet.
        hf_token: Optional Hub token, used only on the first (loading) call.
        device: ``device_map`` value, used only on the first (loading) call.

    Returns:
        The shared ``LocalLlamaModel`` instance.
    """
    global _shared_instance
    if _shared_instance is None:
        _shared_instance = LocalLlamaModel(model_id=model_id, hf_token=hf_token, device=device)
    elif _shared_instance.model_id != model_id:
        # The cached model wins; make the ignored request visible in the logs.
        logger.warning(
            "get_shared_model(model_id=%r) ignored: shared model %r is already loaded",
            model_id,
            _shared_instance.model_id,
        )
    return _shared_instance
class LocalLlamaModel:
    """
    Local Llama model loaded via transformers.

    Provides a generate() method with the same input format (list of
    message dicts) as the HF Inference API, so callers need minimal changes.
    """

    def __init__(
        self,
        model_id: str = "unsloth/Meta-Llama-3.1-8B-Instruct",
        hf_token: str | None = None,
        device: str = "auto",
    ):
        """Load the tokenizer and model weights for ``model_id``.

        Args:
            model_id: Hugging Face Hub repo id (or local path) of the model.
            hf_token: Optional Hub access token, forwarded to ``from_pretrained``.
            device: Forwarded to ``from_pretrained`` as ``device_map``
                (default ``"auto"``).
        """
        # Imported here, so importing this module does not pull in
        # torch/transformers until a model is actually constructed.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("Loading local model: %s", model_id)
        self.model_id = model_id
        # NOTE(review): trust_remote_code=True allows the repo to run arbitrary
        # Python at load time. Llama is natively supported by transformers, so
        # this is likely unnecessary for the default model_id; confirm before
        # pointing model_id at repos you do not trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, token=hf_token, trust_remote_code=True,
        )
        # NOTE(review): pre-Ampere GPUs such as the T4 lack native bfloat16
        # support; float16 may be faster there. Confirm on the target hardware.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        self.model.eval()
        logger.info("Local model loaded on %s", self.model.device)

    def generate(
        self,
        messages: list[dict[str, str]],
        max_tokens: int = 200,
        temperature: float = 0.7,
    ) -> str:
        """Generate a completion from a list of chat messages.

        Args:
            messages: Chat history, rendered with the tokenizer's chat template
                (presumably ``{"role": ..., "content": ...}`` dicts, as with
                the HF Inference API).
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature. Values ``<= 0`` select greedy
                decoding.

        Returns:
            The decoded new tokens only (the prompt is excluded), with special
            tokens skipped and surrounding whitespace stripped.
        """
        import torch

        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        # The rendered template already contains the BOS token
        # (<|begin_of_text|> for Llama 3). Letting the tokenizer add special
        # tokens again would prepend a duplicate BOS to the prompt.
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False,
        ).to(self.model.device)

        do_sample = temperature > 0
        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                # Temperature only matters when sampling, so greedy passes None.
                temperature=temperature if do_sample else None,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the new tokens (exclude the prompt).
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = output_ids[0][prompt_len:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()