""" CRANE AI - Temel MicroModule Sınıfı """ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional import torch from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel, PeftConfig import os import logging import asyncio from threading import Lock logger = logging.getLogger(__name__) class BaseMicroModule(ABC): """Tüm MicroModule'lar için temel sınıf""" def __init__(self, model_id: str, config: Dict[str, Any]): self.model_id = model_id self.config = config self.device = config.get("device", "cpu") self.max_tokens = config.get("max_tokens", 1024) self.temperature = config.get("temperature", 0.7) self.priority = config.get("priority", 1) # Model ve tokenizer self.model = None self.tokenizer = None self.is_loaded = False self.load_lock = Lock() # İstatistikler self.request_count = 0 self.total_tokens = 0 self.avg_response_time = 0 async def load_model(self): """Modeli yükler""" if self.is_loaded: return with self.load_lock: if self.is_loaded: return try: logger.info(f"Loading model: {self.model_id}") # Tokenizer yükleme self.tokenizer = AutoTokenizer.from_pretrained( self.model_id, trust_remote_code=True, token=self.config.get("hf_token") ) # Model yükleme self.model = AutoModelForCausalLM.from_pretrained( self.model_id, trust_remote_code=True, torch_dtype=torch.float16 if self.device != "cpu" else torch.float32, device_map="auto" if self.device != "cpu" else None, token=self.config.get("hf_token") ) # LoRA adaptörü kontrolü adapter_dir = os.path.join("model_cache", self.model_id.replace("/", "_"), "adapter") if os.path.isdir(adapter_dir): try: self.model = PeftModel.from_pretrained(self.model, adapter_dir, is_trainable=False) self.model = self.model.merge_and_unload() logger.info(f"LoRA adaptörü yüklendi: {adapter_dir}") except Exception as adp_err: logger.warning(f"Adaptör yüklenemedi ({adapter_dir}): {adp_err}") # Pad token ayarı if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.is_loaded = True logger.info(f"Model loaded successfully: {self.model_id}") except Exception as e: logger.error(f"Error loading model {self.model_id}: {str(e)}") raise @abstractmethod def can_handle(self, query: str, context: Dict[str, Any]) -> float: """Bu modülün sorguyu ne kadar iyi işleyebileceğini belirler (0-1)""" pass @abstractmethod async def process(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]: """Ana işleme fonksiyonu""" pass async def generate_response(self, prompt: str, **kwargs) -> str: """Metin üretimi""" if not self.is_loaded: await self.load_model() try: # Tokenlara çevir inputs = self.tokenizer( prompt, return_tensors="pt", max_length=self.max_tokens, truncation=True, padding=True ) # Tenzile cihaz aktarımı if self.device != "cpu": inputs = {k: v.to(self.device) for k, v in inputs.items()} # Üretim parametreleri generation_config = { "max_new_tokens": kwargs.get("max_tokens", self.max_tokens), "temperature": kwargs.get("temperature", self.temperature), "do_sample": True, "top_p": 0.9, "top_k": 50, "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, "no_repeat_ngram_size": 3 } # Üretim with torch.no_grad(): outputs = self.model.generate( **inputs, **generation_config ) # Metne çevir response = self.tokenizer.decode( outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True ) # İstatistikleri güncelle self.request_count += 1 self.total_tokens += len(outputs[0]) return response.strip() except Exception as e: logger.error(f"Generation error 
    def get_stats(self) -> Dict[str, Any]:
        """Returns module statistics."""
        return {
            "model_id": self.model_id,
            "is_loaded": self.is_loaded,
            "request_count": self.request_count,
            "total_tokens": self.total_tokens,
            "avg_response_time": self.avg_response_time,
            "priority": self.priority
        }

    def unload_model(self):
        """Removes the model from memory."""
        if self.model:
            del self.model
            self.model = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        self.is_loaded = False

        # Release cached GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"Model unloaded: {self.model_id}")
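

# ------------------------------------------------------------------------------
# Illustrative sketch only, not part of CRANE AI: a minimal concrete subclass
# showing how can_handle() and process() might be implemented. "KeywordModule",
# its keyword heuristic, the "gpt2" model id, and the config values below are
# all hypothetical assumptions for demonstration.
class KeywordModule(BaseMicroModule):
    """Example module that scores queries by keyword overlap."""

    def __init__(self, model_id: str, config: Dict[str, Any], keywords: List[str]):
        super().__init__(model_id, config)
        self.keywords = [k.lower() for k in keywords]

    def can_handle(self, query: str, context: Dict[str, Any]) -> float:
        # Fraction of this module's keywords present in the query (0-1)
        if not self.keywords:
            return 0.0
        hits = sum(1 for k in self.keywords if k in query.lower())
        return hits / len(self.keywords)

    async def process(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        response = await self.generate_response(query)
        return {"response": response, "stats": self.get_stats()}


if __name__ == "__main__":
    # Hypothetical usage; "gpt2" is only a small placeholder model.
    module = KeywordModule("gpt2", {"device": "cpu"}, keywords=["code", "python"])
    result = asyncio.run(module.process("Write python code to sort a list", {}))
    print(result["response"])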