Spaces:
Sleeping
Sleeping
| """ | |
| LoRA utilities for Lily LLM API | |
| """ | |
import logging

logger = logging.getLogger(__name__)

# Known local model checkouts, keyed by model id. An id missing from this
# table resolves to no path, which makes setup_lora_for_model() bail out.
_LOCAL_MODEL_PATHS = {
    "polyglot-ko-1.3b-chat": "./lily_llm_core/models/polyglot_ko_1_3b_chat",
    "kanana-1.5-v-3b-instruct": "./lily_llm_core/models/kanana_1_5_v_3b_instruct",
    "polyglot-ko-5.8b-chat": "./lily_llm_core/models/polyglot_ko_5_8b_chat",
}

# Kanana wraps a Llama-style language model; LoRA targets only the FIRST
# decoder layer's attention and MLP projections (as in the original code).
_KANANA_TARGET_MODULES = [
    "language_model.model.layers.0.self_attn.q_proj",
    "language_model.model.layers.0.self_attn.k_proj",
    "language_model.model.layers.0.self_attn.v_proj",
    "language_model.model.layers.0.self_attn.o_proj",
    "language_model.model.layers.0.mlp.gate_proj",
    "language_model.model.layers.0.mlp.up_proj",
    "language_model.model.layers.0.mlp.down_proj",
]

# GPTNeoX-based models (the polyglot-ko family).
_GPTNEOX_TARGET_MODULES = [
    "query_key_value",
    "mlp.dense_h_to_4h",
    "mlp.dense_4h_to_h",
]


def _model_type_for(model_id):
    """Return the model type for *model_id*: "vision2seq" only for kanana,
    "causal_lm" for everything else (including ``None``)."""
    if model_id == "kanana-1.5-v-3b-instruct":
        return "vision2seq"
    return "causal_lm"


def _attach_base_model(lora_manager):
    """Point *lora_manager* at the already-loaded main model if it has none.

    Returns:
        bool: False when a base model is required but cannot be found,
        True otherwise (including when the manager already has one).
    """
    if hasattr(lora_manager, 'base_model') and lora_manager.base_model is None:
        # Late import — presumably avoids a circular import at module load
        # time; TODO(review) confirm against the services package layout.
        from ..services.model_service import get_current_model
        current_model = get_current_model()
        if current_model is None:
            logger.warning("โ ๏ธ ๋ฉ์ธ ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ด LoRA ์ค์ ๊ฑด๋๋")
            return False
        lora_manager.base_model = current_model
        logger.info("โ ๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ LoRA ๊ด๋ฆฌ์์ ์ค์ ์๋ฃ")
    return True


def setup_lora_for_model(profile, lora_manager):
    """Configure LoRA for the model described by *profile* (shared helper).

    Resolves the local model path and model type from the profile, attaches
    the already-loaded main model to *lora_manager* (avoiding a duplicate
    model load), builds a LoRA config, and applies an adapter named
    ``"auto_adapter"``.

    Args:
        profile: Model profile object; may expose ``local_path``,
            ``model_id`` and ``is_local`` attributes.
        lora_manager: Manager exposing ``create_lora_config()`` and
            ``apply_lora_to_model()``; a falsy value disables LoRA.

    Returns:
        bool: True when the adapter was applied, False on any failure.
        All errors are logged and swallowed; this function never raises.
    """
    if not lora_manager:
        logger.warning("โ ๏ธ LoRA๊ฐ ์ฌ์ฉ ๋ถ๊ฐ๋ฅํ์ฌ ์๋ ์ค์ ๊ฑด๋๋")
        return False
    try:
        logger.info("๐ง LoRA ์๋ ์ค์ ์์...")

        current_model_path = None
        model_type = "causal_lm"  # default

        if getattr(profile, 'local_path', None):
            # Local environment: use the path stored on the profile.
            current_model_path = profile.local_path
            model_type = _model_type_for(getattr(profile, 'model_id', None))
            logger.info(f"๐ ๋ชจ๋ธ ํ๋กํ์์ ๋ก์ปฌ ๊ฒฝ๋ก ์ฌ์ฉ: {current_model_path}")
            logger.info(f"๐ ๊ฒฐ์ ๋ ๋ชจ๋ธ ํ์ : {model_type}")
        elif getattr(profile, 'model_id', None):
            # Derive the checkout path from the model id.
            model_id = profile.model_id
            logger.info(f"๐ ๋ชจ๋ธ ID ๊ธฐ๋ฐ ๊ฒฝ๋ก ๊ฒฐ์ : {model_id}")
            if getattr(profile, 'is_local', False):
                current_model_path = _LOCAL_MODEL_PATHS.get(model_id)
                if current_model_path:
                    model_type = _model_type_for(model_id)
            else:
                # Deployment environment uses the HF model name directly;
                # LoRA is only wired up for local checkouts.
                logger.info("๐ ๋ฐฐํฌ ํ๊ฒฝ: LoRA ์ค์ ๊ฑด๋๋ (HF ๋ชจ๋ธ)")
                return False

        logger.info(f"๐ ๊ฒฐ์ ๋ ๋ชจ๋ธ ๊ฒฝ๋ก: {current_model_path}")
        logger.info(f"๐ ๊ฒฐ์ ๋ ๋ชจ๋ธ ํ์ : {model_type}")

        if not current_model_path:
            logger.warning("โ ๏ธ ํ์ฌ ๋ชจ๋ธ์ ๊ฒฝ๋ก๋ฅผ ์ฐพ์ ์ ์์ด LoRA ์๋ ๋ก๋ ๊ฑด๋๋")
            return False

        logger.info(f"๐ LoRA ๋ชจ๋ธ ๊ฒฝ๋ก: {current_model_path}")
        logger.info(f"๐ LoRA ๋ชจ๋ธ ํ์ : {model_type}")

        # Apply LoRA directly to the already-loaded main model rather than
        # loading a second copy.
        logger.info("๐ง ๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ LoRA ์ง์ ์ ์ฉ ์์...")
        if not _attach_base_model(lora_manager):
            return False

        # Build the LoRA configuration. NOTE(review): the returned config is
        # not used here — presumably create_lora_config() stores it on the
        # manager for apply_lora_to_model() to pick up; confirm.
        logger.info("๐ง LoRA ์ค์ ์์ฑ ์์...")
        if model_type == "vision2seq" and "kanana" in profile.model_id:
            target_modules = list(_KANANA_TARGET_MODULES)
        else:
            target_modules = list(_GPTNEOX_TARGET_MODULES)
        lora_manager.create_lora_config(
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM" if model_type == "causal_lm" else "VISION_2_SEQ",
            target_modules=target_modules,
        )
        logger.info("โ LoRA ์ค์ ์์ฑ ์๋ฃ")

        # Apply the adapter to the existing main model.
        logger.info("๐ง LoRA ์ด๋ํฐ ์ ์ฉ ์์...")
        if lora_manager.apply_lora_to_model("auto_adapter"):
            logger.info("โ LoRA ์ด๋ํฐ ์ ์ฉ ์๋ฃ: auto_adapter")
            logger.info("๐ LoRA ์๋ ์ค์ ์๋ฃ!")
            return True
        logger.error("โ LoRA ์ด๋ํฐ ์ ์ฉ ์คํจ")
        return False
    except Exception as e:
        logger.error(f"โ LoRA ์๋ ์ค์ ์ค ์ค๋ฅ: {e}")
        return False