#!/usr/bin/env python3
"""
Lily LLM model utilities.
Helpers for model loading, inference, and optimization.
"""
import logging
import time
from typing import Optional, Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

logger = logging.getLogger(__name__)


class LilyModelManager:
    """Manager for the Lily LLM model."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
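        # Assumption: "hearth_llm_model" is a local directory of PEFT/LoRA
        # adapter weights (adapter_config.json etc.) produced by fine-tuning;
        # adjust the path if the adapter lives elsewhere.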
        self.lora_path = "hearth_llm_model"

    def load_model(self, device: str = "cpu") -> bool:
        """Load the model and tokenizer."""
        try:
            logger.info("Starting model load...")
            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=True
            )
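            # Mistral's tokenizer ships without a pad token, so reuse eos for
            # padding (generate() below also passes eos as pad_token_id)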
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load the model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32,
                device_map=device,
                low_cpu_mem_usage=True
            )
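            # Note: float32 keeps CPU numerics simple; if a GPU is used
            # instead, torch.float16 or torch.bfloat16 would roughly halve
            # memory use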
            # Load the LoRA adapter (fine-tuned model)
            try:
                self.model = PeftModel.from_pretrained(self.model, self.lora_path)
                logger.info("LoRA adapter loaded")
            except Exception as e:
                logger.warning(f"Failed to load LoRA adapter, falling back to base model: {e}")
            self.model_loaded = True
            logger.info("✅ Model loaded!")
            return True
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            self.model_loaded = False
            return False

    def generate_text(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> Dict[str, Any]:
        """Generate text from a prompt."""
        if not self.model_loaded or self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")
        start_time = time.time()
        try:
            # Tokenize the input and move it to the model's device
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # Generate text
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.eos_token_id
                )
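            # temperature and top_p only apply when do_sample=True; with
            # do_sample=False generation is greedy and both are ignored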
            # Decode the result
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Strip the echoed prompt; split once so only the leading
            # occurrence is removed, not every repetition in the output
            if prompt in generated_text:
                generated_text = generated_text.split(prompt, 1)[-1].strip()
            processing_time = time.time() - start_time
            return {
                "generated_text": generated_text,
                "processing_time": processing_time,
                "model_name": "Lily LLM (Mistral-7B)"
            }
        except Exception as e:
            logger.error(f"Text generation error: {e}")
            raise RuntimeError(f"Text generation failed: {e}") from e

    def format_prompt(self, instruction: str, input_text: str = "") -> str:
        """Format a prompt (Alpaca style)."""
        if input_text:
            return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            return f"### Instruction:\n{instruction}\n\n### Response:\n"

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata."""
        return {
            "model_name": "Lily LLM",
            "base_model": self.model_name,
            "fine_tuned": True,
            "loaded": self.model_loaded,
            "device": str(next(self.model.parameters()).device) if self.model else None
        }

    def unload_model(self):
        """Unload the model and free memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        self.model_loaded = False
        # Free cached GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
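            # empty_cache() only returns cached blocks to the driver; the
            # del statements above must drop the last references first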
        logger.info("Model unloaded")


def create_model_manager() -> LilyModelManager:
    """Create a model manager instance."""
    return LilyModelManager()


def test_model_generation(model_manager: LilyModelManager) -> bool:
    """Smoke-test text generation."""
    try:
        test_prompts = [
            "Please give a brief self-introduction",
            "I am feeling down today",
            "Please explain programming"
        ]
        for prompt in test_prompts:
            formatted_prompt = model_manager.format_prompt(prompt)
            result = model_manager.generate_text(formatted_prompt, max_length=50)
            logger.info(f"Test prompt: {prompt}")
            logger.info(f"Generated text: {result['generated_text']}")
            logger.info(f"Processing time: {result['processing_time']:.2f}s")
            logger.info("-" * 50)
        return True
    except Exception as e:
        logger.error(f"Model test failed: {e}")
        return False


if __name__ == "__main__":
    # Test run
    logging.basicConfig(level=logging.INFO)
    manager = create_model_manager()
    if manager.load_model():
        test_model_generation(manager)
    else:
        logger.error("Model loading failed")