lily_fast_api/lily_llm_utils/model_utils.py
#!/usr/bin/env python3
"""
Lily LLM ๋ชจ๋ธ ์œ ํ‹ธ๋ฆฌํ‹ฐ
๋ชจ๋ธ ๋กœ๋”ฉ, ์ถ”๋ก , ์ตœ์ ํ™” ๊ด€๋ จ ํ•จ์ˆ˜๋“ค
"""
import torch
import logging
from typing import Optional, Dict, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import time
logger = logging.getLogger(__name__)
class LilyModelManager:
    """Lily LLM model manager: handles loading, inference, and unloading."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        self.lora_path = "hearth_llm_model"
    def load_model(self, device: str = "cpu") -> bool:
        """Load the model and tokenizer."""
        try:
            logger.info("Starting model load...")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Load base model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32,
                device_map=device,
                low_cpu_mem_usage=True
            )
            # Load the LoRA adapter (fine-tuned weights); fall back to the
            # base model if the adapter is unavailable
            try:
                self.model = PeftModel.from_pretrained(self.model, self.lora_path)
                logger.info("LoRA adapter loaded successfully")
            except Exception as e:
                logger.warning(f"Failed to load LoRA adapter, using base model: {e}")
            self.model_loaded = True
            logger.info("✅ Model loading complete!")
            return True
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            self.model_loaded = False
            return False
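    # --- Optional: quantized loading sketch --------------------------------
    # Not part of the original module: a minimal sketch of a 4-bit loading
    # path for GPU deployments, assuming `bitsandbytes` is installed and a
    # CUDA device is available. The quantization settings are illustrative.
    def load_model_4bit(self) -> bool:
        """Load the base model in 4-bit precision (sketch, see above)."""
        from transformers import BitsAndBytesConfig
        try:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, use_fast=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=quant_config,
                device_map="auto",
            )
            self.model_loaded = True
            return True
        except Exception as e:
            logger.error(f"4-bit model loading failed: {e}")
            self.model_loaded = False
            return False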
    def generate_text(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> Dict[str, Any]:
        """Generate text from a prompt.

        Note: max_length is the number of *new* tokens to generate
        (it is passed to generate() as max_new_tokens).
        """
        if not self.model_loaded or self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded")
        start_time = time.time()
        try:
            # Tokenize the input and move it to the model's device
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # Generate text (attention_mask is passed along via **inputs)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens; slicing off the prompt
            # tokens is more robust than string-replacing the prompt out of
            # the decoded output
            prompt_length = inputs["input_ids"].shape[1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            processing_time = time.time() - start_time
            return {
                "generated_text": generated_text,
                "processing_time": processing_time,
                "model_name": "Lily LLM (Mistral-7B)"
            }
        except Exception as e:
            logger.error(f"Text generation error: {e}")
            raise RuntimeError(f"Text generation failed: {str(e)}") from e
    def format_prompt(self, instruction: str, input_text: str = "") -> str:
        """Format a prompt in the Alpaca style."""
        if input_text:
            return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            return f"### Instruction:\n{instruction}\n\n### Response:\n"
    def get_model_info(self) -> Dict[str, Any]:
        """Return model information."""
        return {
            "model_name": "Lily LLM",
            "base_model": self.model_name,
            "fine_tuned": True,
            "loaded": self.model_loaded,
            "device": str(next(self.model.parameters()).device) if self.model is not None else None
        }
    def unload_model(self):
        """Unload the model and free memory."""
        # Dropping the references lets the garbage collector reclaim the weights
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        # Release cached GPU memory, if any
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model unloaded")
def create_model_manager() -> LilyModelManager:
    """Create a model manager."""
    return LilyModelManager()
def test_model_generation(model_manager: LilyModelManager) -> bool:
    """Run a few generation smoke tests."""
    try:
        test_prompts = [
            "Please introduce yourself briefly",
            "I'm feeling down today",
            "Please explain programming"
        ]
        for prompt in test_prompts:
            formatted_prompt = model_manager.format_prompt(prompt)
            result = model_manager.generate_text(formatted_prompt, max_length=50)
            logger.info(f"Test prompt: {prompt}")
            logger.info(f"Generated text: {result['generated_text']}")
            logger.info(f"Processing time: {result['processing_time']:.2f}s")
            logger.info("-" * 50)
        return True
    except Exception as e:
        logger.error(f"Model test failed: {e}")
        return False
if __name__ == "__main__":
    # Run a quick smoke test
    logging.basicConfig(level=logging.INFO)
    manager = create_model_manager()
    if manager.load_model():
        test_model_generation(manager)
    else:
        logger.error("Model loading failed")