File size: 6,360 Bytes
58c4fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""

CRANE AI - Temel MicroModule Sınıfı

"""

import asyncio
import logging
import os
import time
from abc import ABC, abstractmethod
from threading import Lock
from typing import Any, Dict, List, Optional

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)

class BaseMicroModule(ABC):
    """Base class for all MicroModules.

    Wraps a causal language model and its tokenizer, loads them lazily on
    first use, and keeps simple usage statistics. Subclasses must implement
    :meth:`can_handle` and :meth:`process`.
    """

    def __init__(self, model_id: str, config: Dict[str, Any]):
        """Store the configuration; no model or tokenizer is loaded here.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            config: Settings read with ``config.get``: ``device``
                (default ``"cpu"``), ``max_tokens`` (default 1024),
                ``temperature`` (default 0.7), ``priority`` (default 1) and
                ``hf_token`` (read later by :meth:`load_model`).
        """
        self.model_id = model_id
        self.config = config
        self.device = config.get("device", "cpu")
        self.max_tokens = config.get("max_tokens", 1024)
        self.temperature = config.get("temperature", 0.7)
        self.priority = config.get("priority", 1)

        # Model and tokenizer; populated by load_model().
        self.model: Optional[Any] = None
        self.tokenizer: Optional[Any] = None
        self.is_loaded = False
        self.load_lock = Lock()

        # Statistics, updated by generate_response().
        self.request_count = 0
        self.total_tokens = 0
        self.avg_response_time = 0.0  # running mean, in seconds

    async def load_model(self):
        """Load the tokenizer and model, plus a LoRA adapter if one exists.

        Returns immediately when the model is already loaded. The body
        contains no ``await``, so the calling event loop does not run other
        tasks while the ``from_pretrained`` calls execute.

        Raises:
            Exception: Any error raised while loading the tokenizer or model
                is logged and re-raised. Adapter loading errors are logged as
                warnings and do not abort the load.
        """
        if self.is_loaded:
            return

        with self.load_lock:
            # Re-check under the lock: another thread may have completed the
            # load while this one was waiting to acquire it.
            if self.is_loaded:
                return

            try:
                logger.info(f"Loading model: {self.model_id}")

                # NOTE(review): trust_remote_code=True allows code shipped in
                # the model repository to run locally - confirm every
                # configured model_id is trusted.
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_id,
                    trust_remote_code=True,
                    token=self.config.get("hf_token")
                )

                # Half precision and automatic placement on accelerators;
                # full precision and default placement on CPU.
                on_cpu = self.device == "cpu"
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    trust_remote_code=True,
                    torch_dtype=torch.float32 if on_cpu else torch.float16,
                    device_map=None if on_cpu else "auto",
                    token=self.config.get("hf_token")
                )

                # Optional LoRA adapter, looked up relative to the current
                # working directory under model_cache/<model_id>/adapter.
                adapter_dir = os.path.join("model_cache", self.model_id.replace("/", "_"), "adapter")
                if os.path.isdir(adapter_dir):
                    try:
                        self.model = PeftModel.from_pretrained(self.model, adapter_dir, is_trainable=False)
                        self.model = self.model.merge_and_unload()
                        logger.info(f"LoRA adaptörü yüklendi: {adapter_dir}")
                    except Exception as adp_err:
                        # Best effort: keep serving with whatever model
                        # object is currently attached.
                        logger.warning(f"Adaptör yüklenemedi ({adapter_dir}): {adp_err}")

                # Fall back to the EOS token when no pad token is defined.
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                self.is_loaded = True
                logger.info(f"Model loaded successfully: {self.model_id}")

            except Exception as e:
                # Do not keep a half-initialised tokenizer/model around while
                # is_loaded is still False.
                self.model = None
                self.tokenizer = None
                logger.error(f"Error loading model {self.model_id}: {str(e)}")
                raise

    @abstractmethod
    def can_handle(self, query: str, context: Dict[str, Any]) -> float:
        """Return how well this module can handle the query (0-1)."""
        pass

    @abstractmethod
    async def process(self, query: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Main processing entry point, implemented by subclasses."""
        pass

    async def generate_response(self, prompt: str, **kwargs) -> str:
        """Generate text for ``prompt`` and update the usage statistics.

        Loads the model first when it is not loaded yet.

        Args:
            prompt: Input text. It is truncated to ``self.max_tokens`` tokens
                by the tokenizer call.
            **kwargs: Optional overrides. ``max_tokens`` sets the number of
                new tokens to generate and ``temperature`` sets the sampling
                temperature; a missing or ``None`` value falls back to the
                module default. A temperature of 0 or less selects greedy
                decoding instead of sampling.

        Returns:
            The decoded continuation (prompt tokens removed), stripped of
            surrounding whitespace.

        Raises:
            Exception: Any error raised during tokenization, generation or
                decoding is logged and re-raised; statistics are not updated
                in that case.
        """
        if not self.is_loaded:
            await self.load_model()

        try:
            started = time.perf_counter()

            # Tokenize the prompt.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=self.max_tokens,
                truncation=True,
                padding=True
            )

            # Move the input tensors to the configured device.
            if self.device != "cpu":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Resolve per-call overrides; an explicit None means "default".
            max_new_tokens = kwargs.get("max_tokens")
            if max_new_tokens is None:
                max_new_tokens = self.max_tokens
            temperature = kwargs.get("temperature")
            if temperature is None:
                temperature = self.temperature

            generation_config = {
                "max_new_tokens": max_new_tokens,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "no_repeat_ngram_size": 3
            }
            if temperature is not None and temperature > 0:
                generation_config.update(
                    do_sample=True,
                    temperature=temperature,
                    top_p=0.9,
                    top_k=50,
                )
            else:
                # Sampling requires a strictly positive temperature, so a
                # non-positive value is served with greedy decoding and the
                # sampling-only parameters are left out.
                generation_config["do_sample"] = False

            # Generate without tracking gradients.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **generation_config
                )

            # Decode only the tokens that follow the prompt.
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

            # Update statistics; total_tokens counts the whole output
            # sequence (prompt plus generated tokens).
            elapsed = time.perf_counter() - started
            self.request_count += 1
            self.total_tokens += len(outputs[0])
            # Incremental mean avoids storing every individual timing.
            self.avg_response_time += (elapsed - self.avg_response_time) / self.request_count

            return response.strip()

        except Exception as e:
            logger.error(f"Generation error in {self.model_id}: {str(e)}")
            raise

    def get_stats(self) -> Dict[str, Any]:
        """Return the module's identity, load state and usage statistics."""
        return {
            "model_id": self.model_id,
            "is_loaded": self.is_loaded,
            "request_count": self.request_count,
            "total_tokens": self.total_tokens,
            "avg_response_time": self.avg_response_time,
            "priority": self.priority
        }

    def unload_model(self):
        """Drop the model and tokenizer references and clear the CUDA cache."""
        # Mark as unloaded first so the flag is never True while the
        # references below are already gone.
        self.is_loaded = False
        # Assign unconditionally instead of truth-testing the objects.
        self.model = None
        self.tokenizer = None

        # Release cached GPU memory.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"Model unloaded: {self.model_id}")