Spaces:
Sleeping
Sleeping
File size: 5,956 Bytes
84635f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
"""
LoRA utilities for Lily LLM API
"""
import logging
logger = logging.getLogger(__name__)
def setup_lora_for_model(profile, lora_manager):
"""๋ชจ๋ธ ํ๋กํ์ ๋ฐ๋ฅธ LoRA ์ค์ (๊ณตํต ํจ์)"""
if not lora_manager:
logger.warning("โ ๏ธ LoRA๊ฐ ์ฌ์ฉ ๋ถ๊ฐ๋ฅํ์ฌ ์๋ ์ค์ ๊ฑด๋๋")
return False
try:
logger.info("๐ง LoRA ์๋ ์ค์ ์์...")
# ๐ ๋ชจ๋ธ ํ๋กํ์์ ๊ฒฝ๋ก ๋ฐ ํ์
์ ๋ณด ๊ฐ์ ธ์ค๊ธฐ
current_model_path = None
model_type = "causal_lm" # ๊ธฐ๋ณธ๊ฐ
# ๐ ๋ชจ๋ธ ํ๋กํ์์ ๊ฒฝ๋ก ๋ฐ ํ์
์ ๋ณด ๊ฐ์ ธ์ค๊ธฐ
if hasattr(profile, 'local_path') and profile.local_path:
# ๋ก์ปฌ ํ๊ฒฝ: ๋ก์ปฌ ๊ฒฝ๋ก ์ฌ์ฉ
current_model_path = profile.local_path
# ๐ local_path ์ฌ์ฉ ์์๋ model_type ์ค์ ํ์
if hasattr(profile, 'model_id') and profile.model_id:
model_id = profile.model_id
if model_id == "kanana-1.5-v-3b-instruct":
model_type = "vision2seq" # ๐ kanana๋ vision2seq ํ์
else:
model_type = "causal_lm" # ๊ธฐ๋ณธ๊ฐ
logger.info(f"๐ ๋ชจ๋ธ ํ๋กํ์์ ๋ก์ปฌ ๊ฒฝ๋ก ์ฌ์ฉ: {current_model_path}")
logger.info(f"๐ ๊ฒฐ์ ๋ ๋ชจ๋ธ ํ์
: {model_type}")
elif hasattr(profile, 'model_id') and profile.model_id:
# ๋ชจ๋ธ ID๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฒฝ๋ก ๊ฒฐ์
model_id = profile.model_id
logger.info(f"๐ ๋ชจ๋ธ ID ๊ธฐ๋ฐ ๊ฒฝ๋ก ๊ฒฐ์ : {model_id}")
# ๐ ํ๊ฒฝ์ ๋ฐ๋ฅธ ๊ฒฝ๋ก ๊ฒฐ์
if hasattr(profile, 'is_local') and profile.is_local:
# ๋ก์ปฌ ํ๊ฒฝ: ๋ก์ปฌ ๊ฒฝ๋ก ์ฌ์ฉ
if model_id == "polyglot-ko-1.3b-chat":
current_model_path = "./lily_llm_core/models/polyglot_ko_1_3b_chat"
model_type = "causal_lm"
elif model_id == "kanana-1.5-v-3b-instruct":
current_model_path = "./lily_llm_core/models/kanana_1_5_v_3b_instruct"
model_type = "vision2seq" # ๐ kanana๋ vision2seq ํ์
elif model_id == "polyglot-ko-5.8b-chat":
current_model_path = "./lily_llm_core/models/polyglot_ko_5_8b_chat"
model_type = "causal_lm"
else:
# ๋ฐฐํฌ ํ๊ฒฝ: HF ๋ชจ๋ธ๋ช
์ฌ์ฉ (๋ก์ปฌ ๊ฒฝ๋ก ์์)
current_model_path = None
logger.info(f"๐ ๋ฐฐํฌ ํ๊ฒฝ: LoRA ์ค์ ๊ฑด๋๋ (HF ๋ชจ๋ธ)")
return False
logger.info(f"๐ ๊ฒฐ์ ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {current_model_path}")
logger.info(f"๐ ๊ฒฐ์ ๋ ๋ชจ๋ธ ํ์
: {model_type}")
if not current_model_path:
logger.warning("โ ๏ธ ํ์ฌ ๋ชจ๋ธ์ ๊ฒฝ๋ก๋ฅผ ์ฐพ์ ์ ์์ด LoRA ์๋ ๋ก๋ ๊ฑด๋๋")
return False
logger.info(f"๐ LoRA ๋ชจ๋ธ ๊ฒฝ๋ก: {current_model_path}")
logger.info(f"๐ LoRA ๋ชจ๋ธ ํ์
: {model_type}")
# ๐ ์ด๋ฏธ ๋ก๋๋ ๋ฉ์ธ ๋ชจ๋ธ์ LoRA์ ์ง์ ์ ์ฉ (์ค๋ณต ๋ก๋ ๋ฐฉ์ง)
logger.info("๐ง ๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ LoRA ์ง์ ์ ์ฉ ์์...")
# ๐ lora_manager์ ๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ ์ค์
if hasattr(lora_manager, 'base_model') and lora_manager.base_model is None:
# ์ ์ญ ๋ณ์์์ ๋ฉ์ธ ๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ
from ..services.model_service import get_current_model
current_model = get_current_model()
if current_model is not None:
lora_manager.base_model = current_model
logger.info("โ
๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ LoRA ๊ด๋ฆฌ์์ ์ค์ ์๋ฃ")
else:
logger.warning("โ ๏ธ ๋ฉ์ธ ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ด LoRA ์ค์ ๊ฑด๋๋")
return False
# LoRA ์ค์ ์์ฑ
logger.info("๐ง LoRA ์ค์ ์์ฑ ์์...")
# ๐ ๋ชจ๋ธ๋ณ target modules ์ค์
if model_type == "vision2seq" and "kanana" in profile.model_id:
# Kanana ๋ชจ๋ธ: Llama ๊ธฐ๋ฐ language model ์ฌ์ฉ (์ฒซ ๋ฒ์งธ ๋ ์ด์ด๋ง ์ฌ์ฉ)
target_modules = [
"language_model.model.layers.0.self_attn.q_proj",
"language_model.model.layers.0.self_attn.k_proj",
"language_model.model.layers.0.self_attn.v_proj",
"language_model.model.layers.0.self_attn.o_proj",
"language_model.model.layers.0.mlp.gate_proj",
"language_model.model.layers.0.mlp.up_proj",
"language_model.model.layers.0.mlp.down_proj"
]
else:
# ๊ธฐ์กด ๋ชจ๋ธ๋ค: GPTNeoX ๊ธฐ๋ฐ
target_modules = ["query_key_value", "mlp.dense_h_to_4h", "mlp.dense_4h_to_h"]
lora_config = lora_manager.create_lora_config(
r=16,
lora_alpha=32,
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM" if model_type == "causal_lm" else "VISION_2_SEQ",
target_modules=target_modules
)
logger.info("โ
LoRA ์ค์ ์์ฑ ์๋ฃ")
# LoRA ์ด๋ํฐ ์ ์ฉ (๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ ์ง์ )
logger.info("๐ง LoRA ์ด๋ํฐ ์ ์ฉ ์์...")
adapter_success = lora_manager.apply_lora_to_model("auto_adapter")
if adapter_success:
logger.info("โ
LoRA ์ด๋ํฐ ์ ์ฉ ์๋ฃ: auto_adapter")
logger.info("๐ LoRA ์๋ ์ค์ ์๋ฃ!")
return True
else:
logger.error("โ LoRA ์ด๋ํฐ ์ ์ฉ ์คํจ")
return False
except Exception as e:
logger.error(f"โ LoRA ์๋ ์ค์ ์ค ์ค๋ฅ: {e}")
return False
|