"""CRANE AI - Token Capsule Layer.

Wraps tokenizer output in ``TokenCapsule`` objects and provides a per-text
token cache, token-level rewrites (compress / expand / filter), capsule
merging and hidden-state embedding extraction.
"""
from typing import Dict, Any, List, Optional, Tuple
import torch
import numpy as np
from transformers import AutoTokenizer
import logging
import dataclasses
from dataclasses import dataclass, field
import asyncio

logger = logging.getLogger(__name__)


@dataclass
class TokenCapsule:
    """Token capsule data structure.

    ``tokens``, ``attention_mask`` and (when non-empty) ``token_type_ids``
    are expected to be parallel per-token lists, as produced by
    ``TokenCapsuleLayer.process_input``.
    """

    tokens: List[int]
    attention_mask: List[int]
    token_type_ids: List[int]
    embeddings: Optional[torch.Tensor] = None
    # Free-form metadata. Defaults to a fresh dict (not None) so that
    # ``{**capsule.metadata}`` and ``capsule.metadata.get(...)`` are safe.
    metadata: Dict[str, Any] = field(default_factory=dict)


class TokenCapsuleLayer:
    """Token processing and optimization layer.

    Recognized ``config`` keys (defaults in parentheses):
    ``max_length`` (2048), ``device`` ("cpu"), ``cache_size`` (1000),
    ``hf_token`` (None), ``trust_remote_code`` (True), ``pad_token_id`` (0),
    ``cls_token_id`` (1), ``sep_token_id`` (2),
    ``special_token_ids`` ((0, 1, 2, 3)) and ``compress_window`` (5).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.max_length = config.get("max_length", 2048)
        self.device = config.get("device", "cpu")

        # Token cache: (model_id, text) -> TokenCapsule, least recently used first.
        self.token_cache = {}
        self.cache_size = config.get("cache_size", 1000)

        # Special token ids are model specific. The defaults reproduce the
        # values that used to be hard-coded. Those values were not even
        # mutually consistent: compression treated 2/3 as CLS/SEP, while
        # expansion inserts 1/2. Configure them from the real tokenizer.
        self.pad_token_id = config.get("pad_token_id", 0)
        self.cls_token_id = config.get("cls_token_id", 1)
        self.sep_token_id = config.get("sep_token_id", 2)
        self.special_token_ids = frozenset(config.get("special_token_ids", (0, 1, 2, 3)))
        # _compress_tokens drops a non-special token that already occurs among
        # this many most recently kept tokens (<= 0 disables the check).
        self.compress_window = config.get("compress_window", 5)

        # Token statistics. total_processed / avg_token_length count cache misses only.
        self.token_stats = {
            "total_processed": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "avg_token_length": 0,
        }

        # Tokenizer pool: model_id -> loaded tokenizer.
        self.tokenizer_pool = {}

    async def process_input(
        self, text: str, model_id: str, context: Optional[Dict[str, Any]] = None
    ) -> TokenCapsule:
        """Convert input text into a token capsule.

        Results are cached per ``(model_id, text)``. A cache hit returns a
        shallow copy of the cached capsule whose ``metadata["context"]`` is
        the ``context`` passed to *this* call, not the first caller's.

        Raises:
            Whatever tokenizer loading or tokenization raises (logged, then re-raised).
        """
        try:
            # Key on the text itself: hash(text) alone can collide and would
            # then silently return another text's tokens.
            cache_key = (model_id, text)
            cached = self.token_cache.pop(cache_key, None)
            if cached is not None:
                # Re-insert so the entry becomes the most recently used.
                self.token_cache[cache_key] = cached
                self.token_stats["cache_hits"] += 1
                return dataclasses.replace(
                    cached, metadata={**(cached.metadata or {}), "context": context}
                )

            tokenizer = await self._get_tokenizer(model_id)

            # A single sequence never needs padding. Requesting it only makes
            # HF raise when the tokenizer has no pad token.
            encoding = tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                return_tensors="pt",
            )

            # Take batch row 0 rather than squeeze(). squeeze() turns a
            # one-token sequence into a 0-d tensor whose tolist() is an int.
            tokens = encoding["input_ids"][0].tolist()
            type_ids = encoding.get("token_type_ids")
            capsule = TokenCapsule(
                tokens=tokens,
                attention_mask=encoding["attention_mask"][0].tolist(),
                token_type_ids=type_ids[0].tolist() if type_ids is not None else [],
                metadata={
                    "model_id": model_id,
                    "original_text": text[:100],  # first 100 characters only
                    "token_count": len(tokens),
                    "context": context,
                },
            )

            self._add_to_cache(cache_key, capsule)

            self.token_stats["total_processed"] += 1
            self.token_stats["cache_misses"] += 1
            self._update_avg_length(len(capsule.tokens))

            return capsule

        except Exception as e:
            logger.exception("Token processing error: %s", e)
            raise

    async def optimize_tokens(self, capsule: TokenCapsule,
                              optimization_type: str = "standard") -> TokenCapsule:
        """Apply a token optimization.

        ``optimization_type`` may be "compress", "expand" or "filter". Any
        other value returns ``capsule`` unchanged. This is best-effort: on
        error the failure is logged and the original capsule is returned.
        """
        try:
            if optimization_type == "compress":
                return await self._compress_tokens(capsule)
            elif optimization_type == "expand":
                return await self._expand_tokens(capsule)
            elif optimization_type == "filter":
                return await self._filter_tokens(capsule)
            else:
                return capsule
        except Exception as e:
            logger.exception("Token optimization error: %s", e)
            return capsule

    async def merge_capsules(self, capsules: List[TokenCapsule],
                             strategy: str = "concat") -> Optional[TokenCapsule]:
        """Merge several capsules into one.

        ``strategy`` may be "concat", "interleave" or "priority". For an
        unknown strategy, or on error (logged), the first capsule is
        returned. Returns None for an empty list.
        """
        try:
            if strategy == "concat":
                return await self._concat_capsules(capsules)
            elif strategy == "interleave":
                return await self._interleave_capsules(capsules)
            elif strategy == "priority":
                return await self._priority_merge_capsules(capsules)
            else:
                return capsules[0] if capsules else None
        except Exception as e:
            logger.exception("Capsule merging error: %s", e)
            return capsules[0] if capsules else None

    async def extract_embeddings(self, capsule: TokenCapsule, model: Any) -> TokenCapsule:
        """Fill ``capsule.embeddings`` with the model's last hidden state.

        Mutates and returns ``capsule``. Does nothing if embeddings are
        already set. On error the failure is logged and ``capsule`` is
        returned unchanged.
        """
        try:
            if capsule.embeddings is not None:
                return capsule

            input_ids = torch.tensor([capsule.tokens]).to(self.device)
            attention_mask = torch.tensor([capsule.attention_mask]).to(self.device)

            with torch.no_grad():
                # Assumes a HF-style model that returns ``hidden_states`` when
                # output_hidden_states=True. TODO confirm for all callers.
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )

            # Last layer, batch row 0. Indexing (not squeeze()) keeps the
            # sequence axis even for one-token inputs.
            capsule.embeddings = outputs.hidden_states[-1][0]

            return capsule

        except Exception as e:
            logger.exception("Embedding extraction error: %s", e)
            return capsule

    async def _get_tokenizer(self, model_id: str) -> AutoTokenizer:
        """Return the pooled tokenizer for ``model_id``, loading it on first use."""
        tokenizer = self.tokenizer_pool.get(model_id)
        if tokenizer is not None:
            return tokenizer

        def _load():
            loaded = AutoTokenizer.from_pretrained(
                model_id,
                # SECURITY: trust_remote_code=True executes Python code shipped
                # with the model repository. Disable it via config for
                # untrusted model ids.
                trust_remote_code=self.config.get("trust_remote_code", True),
                token=self.config.get("hf_token"),
            )
            # Fall back to EOS as the pad token when the model defines none.
            if loaded.pad_token is None:
                loaded.pad_token = loaded.eos_token
            return loaded

        try:
            # from_pretrained does blocking disk/network I/O. Run it in the
            # default executor so the event loop is not stalled.
            loop = asyncio.get_running_loop()
            tokenizer = await loop.run_in_executor(None, _load)
        except Exception as e:
            logger.error("Tokenizer loading error for %s: %s", model_id, e)
            raise

        # If a concurrent call already pooled one, keep that instance.
        return self.tokenizer_pool.setdefault(model_id, tokenizer)

    async def _compress_tokens(self, capsule: TokenCapsule) -> TokenCapsule:
        """Compress tokens.

        Drops every non-special token that already occurs among the last
        ``compress_window`` kept tokens. Special tokens are always kept.
        """
        window = self.compress_window
        kept_tokens: List[int] = []
        keep: List[int] = []
        for i, token in enumerate(capsule.tokens):
            recent = kept_tokens[-window:] if window > 0 else ()
            if token in self.special_token_ids or token not in recent:
                kept_tokens.append(token)
                keep.append(i)
        return self._subset_capsule(capsule, keep, "compressed")

    async def _expand_tokens(self, capsule: TokenCapsule) -> TokenCapsule:
        """Wrap the tokens as ``[cls_token_id] + tokens + [sep_token_id]``."""
        type_ids = capsule.token_type_ids
        if type_ids and len(type_ids) == len(capsule.tokens):
            # Keep type ids aligned: the leading token joins the first
            # segment, the trailing token joins the last segment.
            new_type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        else:
            new_type_ids = type_ids

        return TokenCapsule(
            tokens=[self.cls_token_id] + capsule.tokens + [self.sep_token_id],
            attention_mask=[1] + capsule.attention_mask + [1],
            token_type_ids=new_type_ids,
            # The old per-token embeddings no longer line up with the token
            # positions. Drop them so extract_embeddings() recomputes instead
            # of returning misaligned vectors.
            embeddings=None,
            metadata={**(capsule.metadata or {}), "expanded": True},
        )

    async def _filter_tokens(self, capsule: TokenCapsule) -> TokenCapsule:
        """Remove every token equal to ``pad_token_id``."""
        keep = [i for i, token in enumerate(capsule.tokens) if token != self.pad_token_id]
        return self._subset_capsule(capsule, keep, "filtered")

    def _subset_capsule(self, capsule: TokenCapsule, keep: List[int],
                        flag: str) -> TokenCapsule:
        """Build a capsule holding only positions ``keep`` of ``capsule``.

        The new capsule's metadata is tagged with ``flag: True``.
        """
        num_tokens = len(capsule.tokens)
        type_ids = capsule.token_type_ids
        if len(type_ids) == num_tokens:
            new_type_ids = [type_ids[i] for i in keep]
        else:
            # Not aligned per token (e.g. empty): fall back to plain truncation.
            new_type_ids = type_ids[:len(keep)]

        return TokenCapsule(
            tokens=[capsule.tokens[i] for i in keep],
            attention_mask=[capsule.attention_mask[i] for i in keep],
            token_type_ids=new_type_ids,
            embeddings=self._select_embedding_rows(capsule.embeddings, keep, num_tokens),
            metadata={**(capsule.metadata or {}), flag: True},
        )

    @staticmethod
    def _select_embedding_rows(embeddings: Optional[torch.Tensor], keep: List[int],
                               num_tokens: int) -> Optional[torch.Tensor]:
        """Select embedding rows ``keep`` when dim 0 matches the token count.

        Returns None when there are no embeddings or they cannot be aligned
        with the tokens.
        """
        if embeddings is None:
            return None
        if embeddings.dim() >= 1 and embeddings.shape[0] == num_tokens:
            index = torch.tensor(keep, dtype=torch.long, device=embeddings.device)
            return embeddings.index_select(0, index)
        return None

    async def _concat_capsules(self, capsules: List[TokenCapsule]) -> Optional[TokenCapsule]:
        """Concatenate capsules in order, truncating to ``max_length``.

        Each source metadata key maps to the list of its values across the
        sources. The reserved keys "merged" and "source_count" always
        describe this merge.
        """
        if not capsules:
            return None

        merged_tokens: List[int] = []
        merged_attention: List[int] = []
        collected: Dict[str, List[Any]] = {}

        for capsule in capsules:
            merged_tokens.extend(capsule.tokens)
            merged_attention.extend(capsule.attention_mask)
            for key, value in (capsule.metadata or {}).items():
                collected.setdefault(key, []).append(value)

        merged_metadata: Dict[str, Any] = {"merged": True, "source_count": len(capsules)}
        for key, values in collected.items():
            # A source's own "merged"/"source_count" (e.g. when re-merging a
            # merged capsule) must not clobber these. Appending to them used
            # to raise, e.g. True.append(...).
            merged_metadata.setdefault(key, values)

        return TokenCapsule(
            tokens=merged_tokens[:self.max_length],
            attention_mask=merged_attention[:self.max_length],
            token_type_ids=[],
            metadata=merged_metadata,
        )

    async def _interleave_capsules(self, capsules: List[TokenCapsule]) -> Optional[TokenCapsule]:
        """Interleave capsules.

        TODO: interleaving is not implemented. This currently behaves
        exactly like concatenation.
        """
        return await self._concat_capsules(capsules)

    async def _priority_merge_capsules(self, capsules: List[TokenCapsule]) -> Optional[TokenCapsule]:
        """Concatenate capsules ordered by ``metadata["priority"]``, highest first.

        Missing priority counts as 0. The sort is stable, so ties keep their input order.
        """
        sorted_capsules = sorted(
            capsules,
            key=lambda c: (c.metadata or {}).get("priority", 0),
            reverse=True,
        )
        return await self._concat_capsules(sorted_capsules)

    def _add_to_cache(self, key: Any, capsule: TokenCapsule):
        """Insert into the cache, evicting least recently used entries when full.

        A ``cache_size`` <= 0 disables caching.
        """
        if self.cache_size <= 0:
            # Previously next(iter({})) raised StopIteration here.
            return
        self.token_cache.pop(key, None)
        while len(self.token_cache) >= self.cache_size:
            # Dicts keep insertion order: the first key is the least recently used.
            del self.token_cache[next(iter(self.token_cache))]
        self.token_cache[key] = capsule

    def _update_avg_length(self, length: int):
        """Fold ``length`` into the running average token length.

        Expects ``total_processed`` to already count this sample.
        """
        total = self.token_stats["total_processed"]
        if total <= 0:
            # Guard against division by zero if called before any sample is counted.
            self.token_stats["avg_token_length"] = length
            return
        current_avg = self.token_stats["avg_token_length"]
        self.token_stats["avg_token_length"] = ((current_avg * (total - 1)) + length) / total

    def get_stats(self) -> Dict[str, Any]:
        """Return token layer statistics."""
        hits = self.token_stats["cache_hits"]
        lookups = hits + self.token_stats["cache_misses"]
        # total_processed only counts misses, so dividing by it could yield a
        # "hit rate" above 1. Use all lookups as the denominator instead.
        cache_hit_rate = hits / max(lookups, 1)
        return {
            "total_processed": self.token_stats["total_processed"],
            "cache_hit_rate": cache_hit_rate,
            "cache_size": len(self.token_cache),
            "avg_token_length": self.token_stats["avg_token_length"],
            "tokenizer_pool_size": len(self.tokenizer_pool),
        }

    def clear_cache(self):
        """Remove every cached capsule."""
        self.token_cache.clear()
        logger.info("Token cache cleared")