"""
|
|
|
CRANE AI - Temel MicroModule Sınıfı
|
|
|
"""
|
|
|
|
|
|

import asyncio
import logging
import os
from abc import ABC, abstractmethod
from threading import Lock
from typing import Any, Dict, List, Optional

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)


class BaseMicroModule(ABC):
    """Base class for all MicroModules."""

    def __init__(self, model_id: str, config: Dict[str, Any]):
        self.model_id = model_id
        self.config = config
        self.device = config.get("device", "cpu")
        self.max_tokens = config.get("max_tokens", 1024)
        self.temperature = config.get("temperature", 0.7)
        self.priority = config.get("priority", 1)

        # Model state: loaded lazily on first use; load_lock guards against
        # concurrent double-loading.
        self.model = None
        self.tokenizer = None
        self.is_loaded = False
        self.load_lock = Lock()

        # Usage statistics.
        self.request_count = 0
        self.total_tokens = 0
        self.avg_response_time = 0.0

    async def load_model(self):
        """Loads the model and tokenizer; safe to call repeatedly."""
        if self.is_loaded:
            return

        with self.load_lock:
            # Double-checked locking: another caller may have finished the
            # load while we were waiting for the lock.
            if self.is_loaded:
                return

            try:
                logger.info(f"Loading model: {self.model_id}")

                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_id,
                    trust_remote_code=True,
                    token=self.config.get("hf_token")
                )

                # Half precision and automatic device placement on
                # accelerators; full precision on CPU.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
                    device_map="auto" if self.device != "cpu" else None,
                    token=self.config.get("hf_token")
                )

                # If a locally fine-tuned LoRA adapter exists for this model,
                # merge it into the base weights.
                adapter_dir = os.path.join("model_cache", self.model_id.replace("/", "_"), "adapter")
                if os.path.isdir(adapter_dir):
                    try:
                        self.model = PeftModel.from_pretrained(self.model, adapter_dir, is_trainable=False)
                        self.model = self.model.merge_and_unload()
                        logger.info(f"LoRA adapter loaded: {adapter_dir}")
                    except Exception as adp_err:
                        logger.warning(f"Could not load adapter ({adapter_dir}): {adp_err}")

                # Some tokenizers ship without a pad token; fall back to EOS.
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                self.is_loaded = True
                logger.info(f"Model loaded successfully: {self.model_id}")

            except Exception as e:
                logger.error(f"Error loading model {self.model_id}: {str(e)}")
                raise
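
    # NOTE: the from_pretrained calls above are blocking, so load_model will
    # stall the event loop while weights load even though it is a coroutine.
    # A possible refinement (sketch only, assuming Python 3.9+ for
    # asyncio.to_thread) is to push the blocking work onto a worker thread:
    #
    #     self.tokenizer = await asyncio.to_thread(
    #         AutoTokenizer.from_pretrained,
    #         self.model_id,
    #         trust_remote_code=True,
    #         token=self.config.get("hf_token"),
    #     )
    #
    # which would also call for an asyncio.Lock in place of threading.Lock.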

    @abstractmethod
    def can_handle(self, query: str, context: Dict[str, Any]) -> float:
        """Scores how well this module can handle the query, from 0.0
        (cannot handle) to 1.0 (ideal match)."""
        pass

    @abstractmethod
    async def process(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Main processing entry point for a query."""
        pass

    async def generate_response(self, prompt: str, **kwargs) -> str:
        """Generates text for the given prompt."""
        if not self.is_loaded:
            await self.load_model()

        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=self.max_tokens,
                truncation=True,
                padding=True
            )

            if self.device != "cpu":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            generation_config = {
                "max_new_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
                "do_sample": True,
                "top_p": 0.9,
                "top_k": 50,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "no_repeat_ngram_size": 3
            }

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **generation_config
                )

            # Decode only the newly generated tokens, skipping the prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )

            # Update usage counters (outputs[0] includes prompt + completion).
            self.request_count += 1
            self.total_tokens += len(outputs[0])

            return response.strip()

        except Exception as e:
            logger.error(f"Generation error in {self.model_id}: {str(e)}")
            raise

    def get_stats(self) -> Dict[str, Any]:
        """Returns usage statistics for this module."""
        return {
            "model_id": self.model_id,
            "is_loaded": self.is_loaded,
            "request_count": self.request_count,
            "total_tokens": self.total_tokens,
            "avg_response_time": self.avg_response_time,
            "priority": self.priority
        }

    def unload_model(self):
        """Releases the model and tokenizer from memory."""
        if self.model:
            del self.model
            self.model = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        self.is_loaded = False

        # Return any cached GPU memory to the allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"Model unloaded: {self.model_id}")