gbrabbit's picture
Auto commit at 23-2025-08 10:05:36
84635f1
"""
LoRA utilities for Lily LLM API
"""
import logging
# Module logger; its name is taken from this module's dotted import path.
logger = logging.getLogger(name=__name__)
# Checkpoints known to exist on local disk, keyed by the profile's ``model_id``.
# Each entry is (local checkpoint directory, model type used for LoRA setup).
_LOCAL_MODEL_REGISTRY = {
    "polyglot-ko-1.3b-chat": ("./lily_llm_core/models/polyglot_ko_1_3b_chat", "causal_lm"),
    "kanana-1.5-v-3b-instruct": ("./lily_llm_core/models/kanana_1_5_v_3b_instruct", "vision2seq"),
    "polyglot-ko-5.8b-chat": ("./lily_llm_core/models/polyglot_ko_5_8b_chat", "causal_lm"),
}

# Model type assumed for any model id not listed in the registry.
_DEFAULT_MODEL_TYPE = "causal_lm"

# Kanana: LoRA targets its Llama-style language model, first decoder layer only.
_KANANA_TARGET_MODULES = (
    "language_model.model.layers.0.self_attn.q_proj",
    "language_model.model.layers.0.self_attn.k_proj",
    "language_model.model.layers.0.self_attn.v_proj",
    "language_model.model.layers.0.self_attn.o_proj",
    "language_model.model.layers.0.mlp.gate_proj",
    "language_model.model.layers.0.mlp.up_proj",
    "language_model.model.layers.0.mlp.down_proj",
)

# All other (GPTNeoX-based) models.
_GPTNEOX_TARGET_MODULES = ("query_key_value", "mlp.dense_h_to_4h", "mlp.dense_4h_to_h")


def _model_type_for(model_id):
    """Return the registered model type for ``model_id``, or the default type."""
    entry = _LOCAL_MODEL_REGISTRY.get(model_id)
    return entry[1] if entry else _DEFAULT_MODEL_TYPE


def setup_lora_for_model(profile, lora_manager):
    """Configure and apply an automatic LoRA adapter based on a model profile (shared helper).

    The model path and type are resolved from ``profile``:

    * If ``profile.local_path`` is truthy, it is used as the path; the type is
      looked up from ``profile.model_id``.
    * Otherwise, if ``profile.model_id`` is set and ``profile.is_local`` is
      truthy, the id is mapped to a known local checkpoint directory.
    * Otherwise (deployment profile, unknown id, or no id) LoRA setup is skipped.

    If ``lora_manager.base_model`` exists and is ``None``, the currently loaded
    main model is fetched via ``model_service.get_current_model`` and attached.
    The adapter is then applied to that already-loaded model rather than to a
    second copy.

    Args:
        profile: Model profile object. The attributes ``local_path``,
            ``model_id`` and ``is_local`` are read if present.
        lora_manager: LoRA manager providing ``create_lora_config`` and
            ``apply_lora_to_model``. A falsy value disables LoRA.

    Returns:
        True if the ``"auto_adapter"`` adapter was applied, False otherwise.
        Any exception is logged with its traceback and reported as False.
    """
    if not lora_manager:
        logger.warning("โš ๏ธ LoRA๊ฐ€ ์‚ฌ์šฉ ๋ถˆ๊ฐ€๋Šฅํ•˜์—ฌ ์ž๋™ ์„ค์ • ๊ฑด๋„ˆ๋œ€")
        return False
    try:
        logger.info("๐Ÿ”ง LoRA ์ž๋™ ์„ค์ • ์‹œ์ž‘...")

        # Resolve the model path and type from the profile.
        current_model_path = None
        model_type = _DEFAULT_MODEL_TYPE
        local_path = getattr(profile, 'local_path', None)
        model_id = getattr(profile, 'model_id', None)

        if local_path:
            # Explicit local path: use it directly; the type still depends on model_id.
            current_model_path = local_path
            model_type = _model_type_for(model_id)
            logger.info(f"๐Ÿ” ๋ชจ๋ธ ํ”„๋กœํ•„์—์„œ ๋กœ์ปฌ ๊ฒฝ๋กœ ์‚ฌ์šฉ: {current_model_path}")
            logger.info(f"๐Ÿ” ๊ฒฐ์ •๋œ ๋ชจ๋ธ ํƒ€์ž…: {model_type}")
        elif model_id:
            logger.info(f"๐Ÿ” ๋ชจ๋ธ ID ๊ธฐ๋ฐ˜ ๊ฒฝ๋กœ ๊ฒฐ์ •: {model_id}")
            if not getattr(profile, 'is_local', False):
                # Deployment environment: the model comes from the HF hub by name,
                # so there is no local path and LoRA setup is skipped.
                logger.info("๐Ÿ” ๋ฐฐํฌ ํ™˜๊ฒฝ: LoRA ์„ค์ • ๊ฑด๋„ˆ๋œ€ (HF ๋ชจ๋ธ)")
                return False
            # Local environment: map the model id to its checkpoint directory.
            # Unknown ids yield no path and are rejected below.
            current_model_path, model_type = _LOCAL_MODEL_REGISTRY.get(
                model_id, (None, _DEFAULT_MODEL_TYPE)
            )
            logger.info(f"๐Ÿ” ๊ฒฐ์ •๋œ ๋ชจ๋ธ ๊ฒฝ๋กœ: {current_model_path}")
            logger.info(f"๐Ÿ” ๊ฒฐ์ •๋œ ๋ชจ๋ธ ํƒ€์ž…: {model_type}")

        if not current_model_path:
            logger.warning("โš ๏ธ ํ˜„์žฌ ๋ชจ๋ธ์˜ ๊ฒฝ๋กœ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์–ด LoRA ์ž๋™ ๋กœ๋“œ ๊ฑด๋„ˆ๋œ€")
            return False

        logger.info(f"๐Ÿ” LoRA ๋ชจ๋ธ ๊ฒฝ๋กœ: {current_model_path}")
        logger.info(f"๐Ÿ” LoRA ๋ชจ๋ธ ํƒ€์ž…: {model_type}")

        # Apply LoRA to the already-loaded main model instead of loading it again.
        logger.info("๐Ÿ”ง ๊ธฐ์กด ๋ฉ”์ธ ๋ชจ๋ธ์— LoRA ์ง์ ‘ ์ ์šฉ ์‹œ์ž‘...")
        if hasattr(lora_manager, 'base_model') and lora_manager.base_model is None:
            # Imported lazily; presumably to avoid a circular import with the
            # service layer. TODO confirm.
            from ..services.model_service import get_current_model
            current_model = get_current_model()
            if current_model is None:
                logger.warning("โš ๏ธ ๋ฉ”์ธ ๋ชจ๋ธ์„ ์ฐพ์„ ์ˆ˜ ์—†์–ด LoRA ์„ค์ • ๊ฑด๋„ˆ๋œ€")
                return False
            lora_manager.base_model = current_model
            logger.info("โœ… ๊ธฐ์กด ๋ฉ”์ธ ๋ชจ๋ธ์„ LoRA ๊ด€๋ฆฌ์ž์— ์„ค์ • ์™„๋ฃŒ")

        logger.info("๐Ÿ”ง LoRA ์„ค์ • ์ƒ์„ฑ ์‹œ์ž‘...")
        if model_type == "vision2seq" and "kanana" in model_id:
            target_modules = list(_KANANA_TARGET_MODULES)
        else:
            target_modules = list(_GPTNEOX_TARGET_MODULES)
        # The returned config is not used here. Presumably the manager keeps it
        # for apply_lora_to_model; confirm against the LoRA manager.
        # NOTE(review): "VISION_2_SEQ" is not a member of PEFT's TaskType. Verify
        # how create_lora_config handles it.
        lora_manager.create_lora_config(
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM" if model_type == "causal_lm" else "VISION_2_SEQ",
            target_modules=target_modules,
        )
        logger.info("โœ… LoRA ์„ค์ • ์ƒ์„ฑ ์™„๋ฃŒ")

        # Apply the adapter to the existing main model.
        logger.info("๐Ÿ”ง LoRA ์–ด๋Œ‘ํ„ฐ ์ ์šฉ ์‹œ์ž‘...")
        if not lora_manager.apply_lora_to_model("auto_adapter"):
            logger.error("โŒ LoRA ์–ด๋Œ‘ํ„ฐ ์ ์šฉ ์‹คํŒจ")
            return False
        logger.info("โœ… LoRA ์–ด๋Œ‘ํ„ฐ ์ ์šฉ ์™„๋ฃŒ: auto_adapter")
        logger.info("๐ŸŽ‰ LoRA ์ž๋™ ์„ค์ • ์™„๋ฃŒ!")
        return True
    except Exception as e:
        # Best-effort setup: never propagate, but keep the traceback in the log.
        logger.exception(f"โŒ LoRA ์ž๋™ ์„ค์ • ์ค‘ ์˜ค๋ฅ˜: {e}")
        return False