# Provenance: uploaded by developer-lunark via huggingface_hub (commit 9c7416c, verified)
"""๋ชจ๋ธ ๋กœ๋”ฉ ๋ฐ ์ถ”๋ก  ๊ด€๋ฆฌ"""
import os
import gc
from typing import Dict, List, Tuple, Optional, Any
from functools import lru_cache
from pathlib import Path
# Optional imports for when running with actual models
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
torch = None
from .model_registry import get_model_info, get_all_models, BASE_MODELS
class ModelManager:
    """Loads, caches, and runs inference on LoRA-adapted causal LMs.

    Keeps at most ``max_cached_models`` (model, tokenizer) pairs in an
    LRU cache; the least-recently-used model is evicted to free GPU/CPU
    memory before a new one is loaded.
    """

    def __init__(
        self,
        base_path: Optional[str] = None,
        max_cached_models: int = 2,
        use_4bit: bool = True,
        device_map: str = "auto",
    ):
        """Initialize the manager.

        Args:
            base_path: Root directory used to resolve registry-relative LoRA
                adapter paths. Defaults to three directories above this file.
            max_cached_models: Maximum number of models kept loaded at once.
            use_4bit: Load base models with 4-bit NF4 quantization.
            device_map: Device placement strategy passed to ``from_pretrained``.

        Raises:
            ImportError: If torch/transformers/peft are not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("torch, transformers, peft are required for ModelManager. Install with: pip install torch transformers peft")
        self.base_path = Path(base_path) if base_path else Path(__file__).parent.parent.parent
        self.max_cached_models = max_cached_models
        self.use_4bit = use_4bit
        self.device_map = device_map
        # Cache of loaded models: {model_id: (model, tokenizer)}
        self._loaded_models: Dict[str, Tuple[Any, Any]] = {}
        self._load_order: List[str] = []  # LRU tracking, oldest first
        # 4-bit NF4 quantization settings (None disables quantization)
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ) if use_4bit else None

    def get_available_models(self) -> List[str]:
        """Return the list of model ids known to the registry."""
        return get_all_models()

    def _get_full_path(self, relative_path: str) -> Path:
        """Resolve a registry-relative path against ``base_path``.

        Falls back to the path as given when it does not exist under
        ``base_path`` (e.g. an absolute or CWD-relative path).
        """
        full_path = self.base_path / relative_path
        if full_path.exists():
            return full_path
        return Path(relative_path)

    def _evict_if_needed(self):
        """Evict least-recently-used models until there is room for one more."""
        while len(self._loaded_models) >= self.max_cached_models:
            if not self._load_order:
                break
            oldest_model_id = self._load_order.pop(0)
            if oldest_model_id in self._loaded_models:
                model, tokenizer = self._loaded_models.pop(oldest_model_id)
                del model
                del tokenizer
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print(f"Evicted model: {oldest_model_id}")

    def load_model(self, model_id: str) -> Tuple[Any, Any]:
        """Load (or fetch from cache) the model and tokenizer for ``model_id``.

        Returns:
            ``(model, tokenizer)`` tuple.

        Raises:
            ValueError: If ``model_id`` is not in the registry.
        """
        # Cache hit: refresh LRU position and return.
        if model_id in self._loaded_models:
            if model_id in self._load_order:
                self._load_order.remove(model_id)
            self._load_order.append(model_id)
            return self._loaded_models[model_id]
        # Look up registry info.
        info = get_model_info(model_id)
        if not info:
            raise ValueError(f"Unknown model: {model_id}")
        # Make room in the cache before loading the new model.
        self._evict_if_needed()
        print(f"Loading model: {model_id}")
        base_model_name = info["base"]
        lora_path = self._get_full_path(info["path"])
        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Base model (4-bit quantized when enabled, else bf16)
        model_kwargs = {
            "trust_remote_code": True,
            "device_map": self.device_map,
        }
        if self.use_4bit and self.bnb_config:
            model_kwargs["quantization_config"] = self.bnb_config
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            **model_kwargs
        )
        # Apply the LoRA adapter when present; otherwise use the base model.
        if lora_path.exists():
            print(f"Loading LoRA adapter from: {lora_path}")
            model = PeftModel.from_pretrained(model, str(lora_path))
        else:
            print(f"Warning: LoRA path not found: {lora_path}, using base model")
        model.eval()
        # Store in cache.
        self._loaded_models[model_id] = (model, tokenizer)
        self._load_order.append(model_id)
        print(f"Model loaded: {model_id}")
        return model, tokenizer

    def generate_response(
        self,
        model_id: str,
        messages: List[Dict[str, str]],
        system_prompt: str = "",
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> Tuple[str, Dict]:
        """Generate a chat response with the given model.

        Args:
            model_id: Registry id of the model to use.
            messages: Chat history as ``[{"role": ..., "content": ...}, ...]``.
            system_prompt: Optional system message prepended to the history.
            max_new_tokens: Generation length cap.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            do_sample: Sample vs. greedy decoding.

        Returns:
            ``(response_text, metadata)`` where metadata contains the model id,
            latency in seconds, and input/output/total token counts.
        """
        import time
        model, tokenizer = self.load_model(model_id)
        # Build the full message list (system prompt first, if any).
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        # Render via the model's chat template, with a manual fallback.
        try:
            text = tokenizer.apply_chat_template(
                full_messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            text = self._format_messages_manual(full_messages)
        inputs = tokenizer(text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Generate
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                # BUGFIX: explicit None check — the previous
                # `pad_token_id or eos_token_id` silently discarded a
                # valid pad token id of 0 (falsy).
                pad_token_id=(
                    tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id
                ),
            )
        elapsed = time.time() - start_time
        # Decode only the newly generated tokens (skip the prompt).
        input_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][input_len:],
            skip_special_tokens=True,
        )
        metadata = {
            "model_id": model_id,
            "latency_s": elapsed,
            "input_tokens": input_len,
            "output_tokens": len(outputs[0]) - input_len,
            "total_tokens": len(outputs[0]),
        }
        return response.strip(), metadata

    def _format_messages_manual(self, messages: List[Dict[str, str]]) -> str:
        """Format messages in ChatML style (fallback when apply_chat_template fails).

        Ends with an open assistant turn so the model continues from there.
        """
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"<|im_start|>system\n{content}<|im_end|>\n"
            elif role == "user":
                formatted += f"<|im_start|>user\n{content}<|im_end|>\n"
            elif role == "assistant":
                formatted += f"<|im_start|>assistant\n{content}<|im_end|>\n"
        formatted += "<|im_start|>assistant\n"
        return formatted

    def unload_model(self, model_id: str):
        """Unload a specific model and release its memory (no-op if not loaded)."""
        if model_id in self._loaded_models:
            model, tokenizer = self._loaded_models.pop(model_id)
            if model_id in self._load_order:
                self._load_order.remove(model_id)
            del model
            del tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"Unloaded model: {model_id}")

    def unload_all(self):
        """Unload every cached model."""
        model_ids = list(self._loaded_models.keys())
        for model_id in model_ids:
            self.unload_model(model_id)

    def get_loaded_models(self) -> List[str]:
        """Return the ids of currently loaded models."""
        return list(self._loaded_models.keys())
# Singleton instance (created lazily by get_model_manager)
_model_manager: Optional[ModelManager] = None
def get_model_manager(
    base_path: Optional[str] = None,
    max_cached_models: int = 2,
    use_4bit: bool = True,
) -> ModelManager:
    """Return the process-wide ModelManager singleton.

    NOTE: the arguments are only honored on the first call, which creates
    the instance; subsequent calls return the existing manager unchanged.

    Args:
        base_path: Root directory for resolving LoRA adapter paths.
        max_cached_models: Maximum number of models kept loaded at once.
        use_4bit: Load base models with 4-bit quantization.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager(
            base_path=base_path,
            max_cached_models=max_cached_models,
            use_4bit=use_4bit,
        )
    return _model_manager