"""
Shared local Llama model for Layer 2 inference.
Loads Llama 3.1 8B Instruct once via transformers and provides a
chat_completion-like interface used by both CustomerSimulator and HFAgent,
eliminating HF Inference API dependency and its transient errors.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
# Process-wide singleton; created lazily by get_shared_model() so the heavy
# model load happens at most once per process.
_shared_instance: LocalLlamaModel | None = None
def get_shared_model(
    model_id: str = "unsloth/Meta-Llama-3.1-8B-Instruct",
    hf_token: str | None = None,
    device: str = "auto",
) -> LocalLlamaModel:
    """Return the process-wide LocalLlamaModel, constructing it on first use.

    NOTE(review): the arguments only take effect on the first call; later
    calls return the already-built singleton regardless of what is passed.
    """
    global _shared_instance
    if _shared_instance is not None:
        return _shared_instance
    _shared_instance = LocalLlamaModel(
        model_id=model_id, hf_token=hf_token, device=device
    )
    return _shared_instance
class LocalLlamaModel:
    """
    Llama chat model hosted in-process via transformers.

    Exposes generate(), which accepts the same list-of-message-dicts
    payload as the HF Inference API chat interface, so existing callers
    need minimal changes.
    """

    def __init__(
        self,
        model_id: str = "unsloth/Meta-Llama-3.1-8B-Instruct",
        hf_token: str | None = None,
        device: str = "auto",
    ):
        # Imports are deferred so merely importing this module stays cheap.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("Loading local model: %s", model_id)
        self.model_id = model_id
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, token=hf_token, trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        # Inference-only: disable dropout / training-mode layers.
        self.model.eval()
        logger.info("Local model loaded on %s", self.model.device)

    def generate(
        self,
        messages: list[dict[str, str]],
        max_tokens: int = 200,
        temperature: float = 0.7,
    ) -> str:
        """Generate a completion string from a list of chat message dicts.

        temperature <= 0 switches to greedy decoding (do_sample=False).
        """
        import torch

        sampling = temperature > 0
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                max_new_tokens=max_tokens,
                temperature=temperature if sampling else None,
                do_sample=sampling,
                # Llama has no pad token; reuse EOS to silence the warning.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Strip the prompt tokens so only the newly generated text remains.
        prompt_len = encoded["input_ids"].shape[1]
        completion = generated[0][prompt_len:]
        return self.tokenizer.decode(completion, skip_special_tokens=True).strip()