# NOTE(review): removed a copy-paste artifact from the hosting UI here
# ("Spaces: Running on T4") — it was not valid Python and not part of the module.
"""Model loading and inference management."""
import gc
import os
import time
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Optional imports for when running with actual models
try:
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    torch = None

from .model_registry import BASE_MODELS, get_all_models, get_model_info
class ModelManager:
    """Loads chat models (base model + optional LoRA adapter) and runs inference.

    Keeps an LRU cache of at most ``max_cached_models`` loaded
    ``(model, tokenizer)`` pairs to bound GPU/CPU memory usage; the least
    recently used model is evicted before a new one is loaded.
    """

    def __init__(
        self,
        base_path: Optional[str] = None,
        max_cached_models: int = 2,
        use_4bit: bool = True,
        device_map: str = "auto",
    ):
        """Initialize the manager.

        Args:
            base_path: Root directory that registry-relative adapter paths are
                resolved against. Defaults to three levels above this file.
            max_cached_models: Maximum number of models kept in memory at once.
            use_4bit: Load base models with 4-bit NF4 quantization (bitsandbytes).
            device_map: Device placement strategy passed to ``from_pretrained``.

        Raises:
            ImportError: If torch/transformers/peft are not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("torch, transformers, peft are required for ModelManager. Install with: pip install torch transformers peft")
        self.base_path = Path(base_path) if base_path else Path(__file__).parent.parent.parent
        self.max_cached_models = max_cached_models
        self.use_4bit = use_4bit
        self.device_map = device_map
        # Cache of loaded models: {model_id: (model, tokenizer)}
        self._loaded_models: Dict[str, Tuple[Any, Any]] = {}
        self._load_order: List[str] = []  # LRU tracking: oldest id first
        # 4-bit NF4 quantization config; None when quantization is disabled.
        self.bnb_config = (
            BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            if use_4bit
            else None
        )

    def get_available_models(self) -> List[str]:
        """Return the ids of all models known to the registry."""
        return get_all_models()

    def _get_full_path(self, relative_path: str) -> Path:
        """Resolve a registry-relative path against ``base_path``.

        Falls back to the path as given when it does not exist under
        ``base_path`` (e.g. when an absolute path was registered).
        """
        full_path = self.base_path / relative_path
        if full_path.exists():
            return full_path
        return Path(relative_path)

    def _release(self, model_id: str) -> None:
        """Drop a cached model and aggressively reclaim memory.

        Shared teardown used by both LRU eviction and explicit unloading.
        """
        model, tokenizer = self._loaded_models.pop(model_id)
        if model_id in self._load_order:
            self._load_order.remove(model_id)
        del model
        del tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _evict_if_needed(self) -> None:
        """Evict least-recently-used models until there is room for one more."""
        while len(self._loaded_models) >= self.max_cached_models:
            if not self._load_order:
                break
            oldest_model_id = self._load_order[0]
            if oldest_model_id in self._loaded_models:
                self._release(oldest_model_id)
                print(f"Evicted model: {oldest_model_id}")
            else:
                # Stale id (already unloaded): just drop it from the LRU list.
                self._load_order.pop(0)

    def load_model(self, model_id: str) -> Tuple[Any, Any]:
        """Load (or fetch from cache) the model and tokenizer for ``model_id``.

        Returns:
            ``(model, tokenizer)``; the model is in eval mode.

        Raises:
            ValueError: If ``model_id`` is not in the registry.
        """
        # Cache hit: refresh the LRU position and return.
        if model_id in self._loaded_models:
            if model_id in self._load_order:
                self._load_order.remove(model_id)
            self._load_order.append(model_id)
            return self._loaded_models[model_id]

        info = get_model_info(model_id)
        if not info:
            raise ValueError(f"Unknown model: {model_id}")

        # Make room before loading a new model.
        self._evict_if_needed()

        print(f"Loading model: {model_id}")
        base_model_name = info["base"]
        lora_path = self._get_full_path(info["path"])

        tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            # Many causal-LM tokenizers ship without a pad token.
            tokenizer.pad_token = tokenizer.eos_token

        model_kwargs = {
            "trust_remote_code": True,
            "device_map": self.device_map,
        }
        if self.use_4bit and self.bnb_config:
            model_kwargs["quantization_config"] = self.bnb_config
        else:
            model_kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            **model_kwargs
        )

        # Apply the LoRA adapter when present; otherwise serve the base model.
        if lora_path.exists():
            print(f"Loading LoRA adapter from: {lora_path}")
            model = PeftModel.from_pretrained(model, str(lora_path))
        else:
            print(f"Warning: LoRA path not found: {lora_path}, using base model")

        model.eval()

        self._loaded_models[model_id] = (model, tokenizer)
        self._load_order.append(model_id)
        print(f"Model loaded: {model_id}")
        return model, tokenizer

    def generate_response(
        self,
        model_id: str,
        messages: List[Dict[str, str]],
        system_prompt: str = "",
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
    ) -> Tuple[str, Dict]:
        """Generate an assistant reply for a chat conversation.

        Args:
            model_id: Registry id of the model to use.
            messages: Chat history as ``[{"role": ..., "content": ...}, ...]``.
            system_prompt: Optional system message prepended to the history.
            max_new_tokens / temperature / top_p / do_sample: Sampling settings
                forwarded to ``model.generate``.

        Returns:
            ``(response_text, metadata)`` where metadata carries latency and
            token counts.
        """
        model, tokenizer = self.load_model(model_id)

        # Assemble the full message list (optional system prompt first).
        full_messages = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)

        try:
            text = tokenizer.apply_chat_template(
                full_messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Fallback for tokenizers without a chat template: manual ChatML.
            text = self._format_messages_manual(full_messages)

        inputs = tokenizer(text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # FIX: the previous `pad_token_id or eos_token_id` wrongly fell through
        # to eos when pad_token_id == 0, which is a valid token id. Compare
        # against None explicitly instead of relying on truthiness.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=pad_id,
            )
        elapsed = time.time() - start_time

        # Decode only the newly generated tokens (exclude the prompt).
        input_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][input_len:],
            skip_special_tokens=True,
        )

        metadata = {
            "model_id": model_id,
            "latency_s": elapsed,
            "input_tokens": input_len,
            "output_tokens": len(outputs[0]) - input_len,
            "total_tokens": len(outputs[0]),
        }
        return response.strip(), metadata

    def _format_messages_manual(self, messages: List[Dict[str, str]]) -> str:
        """Format messages as ChatML by hand (used when apply_chat_template fails)."""
        formatted = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                formatted += f"<|im_start|>system\n{content}<|im_end|>\n"
            elif role == "user":
                formatted += f"<|im_start|>user\n{content}<|im_end|>\n"
            elif role == "assistant":
                formatted += f"<|im_start|>assistant\n{content}<|im_end|>\n"
        # Open an assistant turn for the model to complete.
        formatted += "<|im_start|>assistant\n"
        return formatted

    def unload_model(self, model_id: str):
        """Unload a specific model; no-op if it is not loaded."""
        if model_id in self._loaded_models:
            self._release(model_id)
            print(f"Unloaded model: {model_id}")

    def unload_all(self):
        """Unload every cached model."""
        model_ids = list(self._loaded_models.keys())
        for model_id in model_ids:
            self.unload_model(model_id)

    def get_loaded_models(self) -> List[str]:
        """Return the ids of currently loaded models."""
        return list(self._loaded_models.keys())
# Module-level singleton instance (annotation quoted: ModelManager may be
# defined later/elsewhere relative to this statement's evaluation).
_model_manager: Optional["ModelManager"] = None


def get_model_manager(
    base_path: Optional[str] = None,
    max_cached_models: int = 2,
    use_4bit: bool = True,
) -> "ModelManager":
    """Return the process-wide ModelManager singleton.

    Note: the arguments are honored only on the first call, which creates the
    instance; subsequent calls return the existing instance and silently
    ignore any arguments.

    Args:
        base_path: Forwarded to ``ModelManager`` on first construction.
        max_cached_models: Forwarded to ``ModelManager`` on first construction.
        use_4bit: Forwarded to ``ModelManager`` on first construction.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager(
            base_path=base_path,
            max_cached_models=max_cached_models,
            use_4bit=use_4bit,
        )
    return _model_manager