"""Model loading and inference management."""
import os
import gc
import time
from typing import Dict, List, Tuple, Optional, Any
from functools import lru_cache
from pathlib import Path

# Optional imports for when running with actual models
try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None

from .model_registry import get_model_info, get_all_models, BASE_MODELS


class ModelManager:
    """Load registered chat models on demand and run generation with them.

    Loaded ``(model, tokenizer)`` pairs are cached in memory. At most
    ``max_cached_models`` entries are kept; when a new model must be loaded
    and the cache is full, the least recently used entry is evicted first.
    """

    def __init__(
        self,
        base_path: Optional[str] = None,
        max_cached_models: int = 2,
        use_4bit: bool = True,
        device_map: str = "auto",
    ):
        """Create a manager; no model is loaded until ``load_model`` is called.

        Args:
            base_path: Directory that registry adapter paths are resolved
                against. Defaults to ``Path(__file__).parent.parent.parent``.
            max_cached_models: Maximum number of models kept loaded at once.
            use_4bit: If True, base models are loaded with a 4-bit NF4
                bitsandbytes quantization config; otherwise in bfloat16.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

        Raises:
            ImportError: If torch / transformers / peft could not be imported.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("torch, transformers, peft are required for ModelManager. Install with: pip install torch transformers peft")

        self.base_path = Path(base_path) if base_path else Path(__file__).parent.parent.parent
        self.max_cached_models = max_cached_models
        self.use_4bit = use_4bit
        self.device_map = device_map

        # Loaded model cache: {model_id: (model, tokenizer)}
        self._loaded_models: Dict[str, Tuple[Any, Any]] = {}
        # Model ids ordered from least to most recently used (LRU tracking).
        self._load_order: List[str] = []

        # Quantization config, only built when 4-bit loading is requested.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ) if use_4bit else None

    def get_available_models(self) -> List[str]:
        """Return the ids of all models known to the model registry."""
        return get_all_models()

    def _get_full_path(self, relative_path: str) -> Path:
        """Resolve ``relative_path`` against ``base_path``.

        If the joined path does not exist, ``relative_path`` is returned
        unchanged (so it is interpreted relative to the current directory,
        or used as-is if it is already absolute).
        """
        full_path = self.base_path / relative_path
        if full_path.exists():
            return full_path
        return Path(relative_path)

    def _release(self, model_id: str) -> bool:
        """Remove ``model_id`` from the cache and free the memory it held.

        Returns:
            True if the model was loaded and has been released, else False.
        """
        entry = self._loaded_models.pop(model_id, None)
        if entry is None:
            return False
        if model_id in self._load_order:
            self._load_order.remove(model_id)
        # Drop the last reference held here so gc can reclaim the weights
        # before asking the CUDA caching allocator to return freed blocks.
        del entry
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return True

    def _evict_if_needed(self):
        """Evict least recently used models until there is room for one more."""
        while len(self._loaded_models) >= self.max_cached_models:
            if not self._load_order:
                # Nothing left to evict in LRU order; avoid looping forever.
                break
            oldest_model_id = self._load_order.pop(0)
            if self._release(oldest_model_id):
                print(f"Evicted model: {oldest_model_id}")

    def load_model(self, model_id: str) -> Tuple[Any, Any]:
        """Return ``(model, tokenizer)`` for ``model_id``, loading it if needed.

        The registry entry's ``"base"`` value is used as the base model /
        tokenizer name and its ``"path"`` value as the LoRA adapter location.
        If the adapter directory does not exist, the base model is used alone.

        Raises:
            ValueError: If ``model_id`` is not found in the model registry.
        """
        # Cache hit: mark as most recently used and return it.
        if model_id in self._loaded_models:
            if model_id in self._load_order:
                self._load_order.remove(model_id)
            self._load_order.append(model_id)
            return self._loaded_models[model_id]

        info = get_model_info(model_id)
        if not info:
            raise ValueError(f"Unknown model: {model_id}")

        # Make room before allocating a new model.
        self._evict_if_needed()

        print(f"Loading model: {model_id}")
        base_model_name = info["base"]
        lora_path = self._get_full_path(info["path"])

        tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            trust_remote_code=True,
        )
        # Generation needs a pad token; reuse EOS when the tokenizer has none.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model_kwargs = {
            "trust_remote_code": True,
            "device_map": self.device_map,
        }
        if self.use_4bit and self.bnb_config:
            model_kwargs["quantization_config"] = self.bnb_config
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            **model_kwargs,
        )

        # Apply the LoRA adapter on top of the base model when present.
        if lora_path.exists():
            print(f"Loading LoRA adapter from: {lora_path}")
            model = PeftModel.from_pretrained(model, str(lora_path))
        else:
            print(f"Warning: LoRA path not found: {lora_path}, using base model")

        model.eval()

        self._loaded_models[model_id] = (model, tokenizer)
        self._load_order.append(model_id)

        print(f"Model loaded: {model_id}")
        return model, tokenizer

    def generate_response(
        self,
        model_id: str,
        messages: List[Dict[str, str]],
        system_prompt: str = "",
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> Tuple[str, Dict]:
        """Generate an assistant reply for a chat conversation.

        Args:
            model_id: Registry id of the model to use (loaded on demand).
            messages: Chat history as ``{"role": ..., "content": ...}`` dicts.
            system_prompt: Optional system message prepended to ``messages``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; a value <= 0 selects greedy
                decoding instead of sampling.
            top_p: Nucleus sampling threshold (only used when sampling).
            do_sample: Whether to sample; if False, greedy decoding is used.

        Returns:
            ``(response_text, metadata)`` where metadata holds ``model_id``,
            ``latency_s``, ``input_tokens``, ``output_tokens``, ``total_tokens``.
        """
        model, tokenizer = self.load_model(model_id)

        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        try:
            text = tokenizer.apply_chat_template(
                full_messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            # The rendered template already contains whatever BOS/special
            # tokens the model expects; adding them again would duplicate them.
            add_special_tokens = False
        except Exception:
            # Fall back to manual ChatML formatting if the tokenizer cannot
            # render its chat template.
            text = self._format_messages_manual(full_messages)
            add_special_tokens = True

        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=add_special_tokens)
        # Always place inputs on the model's device (no-op on CPU).
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # A pad id of 0 is valid, so test for None rather than truthiness.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": pad_token_id,
        }
        if do_sample and temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # Sampling knobs are meaningless for greedy decoding; a
            # non-positive temperature is treated as a request for greedy.
            gen_kwargs["do_sample"] = False

        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        elapsed = time.perf_counter() - start_time

        # Decode only the newly generated tokens (skip the prompt).
        input_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][input_len:],
            skip_special_tokens=True,
        )

        metadata = {
            "model_id": model_id,
            "latency_s": elapsed,
            "input_tokens": input_len,
            "output_tokens": len(outputs[0]) - input_len,
            "total_tokens": len(outputs[0]),
        }

        return response.strip(), metadata

    def _format_messages_manual(self, messages: List[Dict[str, str]]) -> str:
        """Render messages in ChatML format (fallback when no chat template works).

        Messages whose role is not system/user/assistant are skipped. The
        result ends with an open assistant turn as the generation prompt.
        """
        parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role in ("system", "user", "assistant"):
                parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
        parts.append("<|im_start|>assistant\n")
        return "".join(parts)

    def unload_model(self, model_id: str):
        """Unload ``model_id`` if it is loaded; otherwise do nothing."""
        if self._release(model_id):
            print(f"Unloaded model: {model_id}")

    def unload_all(self):
        """Unload every currently loaded model."""
        for model_id in list(self._loaded_models):
            self.unload_model(model_id)

    def get_loaded_models(self) -> List[str]:
        """Return the ids of the models currently held in the cache."""
        return list(self._loaded_models.keys())


# Singleton instance
_model_manager: Optional[ModelManager] = None


def get_model_manager(
    base_path: Optional[str] = None,
    max_cached_models: int = 2,
    use_4bit: bool = True,
) -> ModelManager:
    """Return the process-wide ModelManager, creating it on first call.

    Note: the arguments only take effect on the first call; later calls
    return the existing instance and ignore them.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager(
            base_path=base_path,
            max_cached_models=max_cached_models,
            use_4bit=use_4bit,
        )
    return _model_manager