# lily_fast_api/lily_llm_api/services/model_service.py
# Origin: gbrabbit — auto commit at 25-2025-08 3:12:15, revision 0e9a45c
"""
Model service for Lily LLM API
"""
import logging
import os
import asyncio
import concurrent.futures
from typing import Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Global variables — shared model state for the whole API process.
current_model = None  # 🔄 currently loaded model instance
current_profile = None  # 🔄 currently selected model profile
model_loaded = False  # 🔄 model load status flag
model = None  # raw model object (mirrors current_model; used by the unload path)
tokenizer = None  # tokenizer extracted from the processor after loading
processor = None  # processor returned by the profile's load_model()
# Thread pool used to run blocking model loads off the event loop.
executor = concurrent.futures.ThreadPoolExecutor()
def get_current_model():
    """Return the model instance that is currently loaded (None if nothing is loaded)."""
    global current_model
    return current_model
def get_current_profile():
    """Return the model profile that is currently selected (None if none chosen yet)."""
    global current_profile
    return current_profile
def is_model_loaded():
    """Report whether a model has finished loading successfully."""
    global model_loaded
    return model_loaded
async def load_model_async(model_id: str):
    """Asynchronously load a model by running the blocking load in the thread pool.

    Args:
        model_id: Identifier of the model profile to load.

    Raises:
        Whatever load_model_sync raises on failure (it re-raises after logging).
    """
    # get_running_loop() is the supported way to obtain the loop from inside a
    # coroutine; get_event_loop() is deprecated in this context since Python 3.10
    # and can bind the wrong/non-running loop.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, load_model_sync, model_id)
def load_model_sync(model_id: str):
    """Synchronously load the model and its processor/tokenizer (final revision).

    Steps: unload any previously loaded model, resolve the profile for
    *model_id* and load it, force the model dtype at startup, publish the
    results through the module globals, then attempt LoRA setup. On failure,
    logs the full traceback, clears the loaded flag, and re-raises.
    """
    global model, tokenizer, processor, current_profile, current_model, model_loaded
    try:
        # Drop any previously loaded model before loading the new one.
        if model is not None:
            logger.info("🗑️ 기존 모델 언로드 중...")
            del model
            del tokenizer
            del processor
            model, tokenizer, processor = None, None, None
            import gc
            gc.collect()  # encourage immediate release of the old model's memory
            logger.info("✅ 기존 모델 언로드 완료")
        logger.info(f"📥 '{model_id}' 모델 로딩 시작...")
        from ..models import get_model_profile
        current_profile = get_model_profile(model_id)
        # load_model() now returns a (model, processor) pair.
        model, processor = current_profile.load_model()
        # 🔧 Force the dtype at server startup (avoids first-request latency).
        try:
            import torch as _torch
            # Choose the target dtype per device (default: CPU=float32, CUDA=bfloat16);
            # LILY_FORCE_DTYPE overrides either, else the per-device env var applies.
            if hasattr(model, 'device') and str(model.device) == 'cpu':
                desired = (os.getenv('LILY_FORCE_DTYPE') or os.getenv('LILY_CPU_DTYPE') or 'float32').lower()
                default_target = _torch.float32
            else:
                desired = (os.getenv('LILY_FORCE_DTYPE') or os.getenv('LILY_CUDA_DTYPE') or 'bfloat16').lower()
                default_target = _torch.bfloat16
            desired_map = {
                'float32': _torch.float32,
                'fp32': _torch.float32,
                'bfloat16': _torch.bfloat16,
                'bf16': _torch.bfloat16,
                'float16': _torch.float16,
                'fp16': _torch.float16,
            }
            # Unrecognized env values fall back to the device default.
            target_dtype = desired_map.get(desired, default_target)
            if hasattr(model, 'dtype') and model.dtype != target_dtype:
                logger.info(f"🔧 [SPEED][startup] dtype 적용: {model.dtype} -> {target_dtype}")
                model = model.to(target_dtype)
        except Exception as _dtype_e:
            # Best-effort: a dtype failure must not abort the whole load.
            logger.warning(f"⚠️ [startup] dtype 적용 실패: {_dtype_e}")
        # 🔄 Publish the model in the module globals (consumed by the LoRA path).
        current_model = model
        # Extract the tokenizer from the processor and store it globally.
        if hasattr(processor, 'tokenizer'):
            tokenizer = processor.tokenizer
        else:
            # The processor itself may also act as the tokenizer.
            tokenizer = processor
        logger.info(f"✅ '{current_profile.display_name}' 모델 로딩 완료!")
        # 🔄 Auto-load the default LoRA base model (via the shared helper).
        try:
            from lily_llm_core.lora_manager import get_lora_manager, lora_manager
            if lora_manager:
                from ..utils.lora_utils import setup_lora_for_model
                setup_lora_for_model(current_profile, lora_manager)
        except ImportError:
            # LoRA support is optional; continue without it.
            logger.warning("⚠️ LoRA 관리자 import 실패")
        model_loaded = True
    except Exception as e:
        logger.error(f"❌ load_model_sync 실패: {e}")
        import traceback
        logger.error(f"🔍 전체 에러: {traceback.format_exc()}")
        model_loaded = False
        raise
def shutdown_executor():
    """Shut down the module-level thread pool, blocking until pending work finishes."""
    global executor
    executor.shutdown(wait=True)