"""
CRANE AI - Token Capsule Layer

Tokenizes input text into TokenCapsule objects and provides caching,
optimization, and merging utilities for them.
"""

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


@dataclass
class TokenCapsule:
    """Token capsule data structure."""

    tokens: List[int]
    attention_mask: List[int]
    token_type_ids: List[int]
    embeddings: Optional[torch.Tensor] = None
    metadata: Optional[Dict[str, Any]] = None


class TokenCapsuleLayer:
    """Token processing and optimization layer."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.max_length = config.get("max_length", 2048)
        self.device = config.get("device", "cpu")

        # Simple FIFO cache keyed by model id + text hash.
        self.token_cache: Dict[str, TokenCapsule] = {}
        self.cache_size = config.get("cache_size", 1000)

        # Running statistics for processed inputs.
        self.token_stats = {
            "total_processed": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "avg_token_length": 0,
        }

        # One tokenizer instance per model id, loaded lazily.
        self.tokenizer_pool = {}

    async def process_input(self, text: str, model_id: str, context: Optional[Dict[str, Any]] = None) -> TokenCapsule:
        """Converts input text into a token capsule."""
        try:
            # Return a cached capsule if this exact text was tokenized before.
            cache_key = f"{model_id}:{hash(text)}"
            if cache_key in self.token_cache:
                self.token_stats["cache_hits"] += 1
                return self.token_cache[cache_key]

            tokenizer = await self._get_tokenizer(model_id)

            encoding = tokenizer(
                text,
                max_length=self.max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

            # squeeze(0) drops only the batch dimension, so a single-token
            # sequence is not accidentally collapsed to a scalar.
            tokens = encoding["input_ids"].squeeze(0).tolist()
            capsule = TokenCapsule(
                tokens=tokens,
                attention_mask=encoding["attention_mask"].squeeze(0).tolist(),
                # Not every tokenizer returns token_type_ids (e.g. GPT-style models).
                token_type_ids=(
                    encoding["token_type_ids"].squeeze(0).tolist()
                    if "token_type_ids" in encoding
                    else []
                ),
                metadata={
                    "model_id": model_id,
                    "original_text": text[:100],
                    "token_count": len(tokens),
                    "context": context,
                },
            )

            self._add_to_cache(cache_key, capsule)

            self.token_stats["total_processed"] += 1
            self.token_stats["cache_misses"] += 1
            self._update_avg_length(len(capsule.tokens))

            return capsule

        except Exception as e:
            logger.error(f"Token processing error: {str(e)}")
            raise

    async def optimize_tokens(self, capsule: TokenCapsule, optimization_type: str = "standard") -> TokenCapsule:
        """Applies the requested token optimization."""
        try:
            if optimization_type == "compress":
                return await self._compress_tokens(capsule)
            elif optimization_type == "expand":
                return await self._expand_tokens(capsule)
            elif optimization_type == "filter":
                return await self._filter_tokens(capsule)
            else:
                # "standard" (or any unknown type) leaves the capsule unchanged.
                return capsule

        except Exception as e:
            logger.error(f"Token optimization error: {str(e)}")
            return capsule

    async def merge_capsules(self, capsules: List[TokenCapsule], strategy: str = "concat") -> Optional[TokenCapsule]:
        """Merges multiple capsules into one."""
        try:
            if strategy == "concat":
                return await self._concat_capsules(capsules)
            elif strategy == "interleave":
                return await self._interleave_capsules(capsules)
            elif strategy == "priority":
                return await self._priority_merge_capsules(capsules)
            else:
                return capsules[0] if capsules else None

        except Exception as e:
            logger.error(f"Capsule merging error: {str(e)}")
            return capsules[0] if capsules else None

    async def extract_embeddings(self, capsule: TokenCapsule, model: Any) -> TokenCapsule:
        """Extracts token embeddings from the model's last hidden state."""
        try:
            if capsule.embeddings is not None:
                return capsule

            input_ids = torch.tensor([capsule.tokens]).to(self.device)
            attention_mask = torch.tensor([capsule.attention_mask]).to(self.device)

            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )

            # Use the last hidden layer as the token embeddings.
            embeddings = outputs.hidden_states[-1]
            capsule.embeddings = embeddings.squeeze(0)

            return capsule

        except Exception as e:
            logger.error(f"Embedding extraction error: {str(e)}")
            return capsule

    async def _get_tokenizer(self, model_id: str) -> AutoTokenizer:
        """Fetches (and pools) the tokenizer for a model."""
        if model_id not in self.tokenizer_pool:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    token=self.config.get("hf_token"),
                )

                # Models without a pad token (e.g. GPT-style) fall back to EOS.
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

                self.tokenizer_pool[model_id] = tokenizer

            except Exception as e:
                logger.error(f"Tokenizer loading error for {model_id}: {str(e)}")
                raise

        return self.tokenizer_pool[model_id]

    async def _compress_tokens(self, capsule: TokenCapsule) -> TokenCapsule:
        """Compresses tokens by dropping nearby repetitions."""
        important_tokens = []
        attention_mask = []

        for i, token in enumerate(capsule.tokens):
            # Always keep special tokens (IDs 0-3: pad/BOS/EOS/UNK in many
            # vocabularies; this is a convention, not universal).
            if token in [0, 1, 2, 3]:
                important_tokens.append(token)
                attention_mask.append(capsule.attention_mask[i])
            # Otherwise keep the token only if it does not repeat one of the
            # last five kept tokens.
            elif token not in important_tokens[-5:]:
                important_tokens.append(token)
                attention_mask.append(capsule.attention_mask[i])

        compressed_capsule = TokenCapsule(
            tokens=important_tokens,
            attention_mask=attention_mask,
            token_type_ids=capsule.token_type_ids[:len(important_tokens)],
            embeddings=capsule.embeddings,
            metadata={**(capsule.metadata or {}), "compressed": True},
        )

        return compressed_capsule

    async def _expand_tokens(self, capsule: TokenCapsule) -> TokenCapsule:
        """Expands tokens by wrapping them in BOS/EOS markers.

        Note: the IDs 1 and 2 are hardcoded BOS/EOS values and only match
        vocabularies that follow that convention.
        """
        expanded_tokens = [1] + capsule.tokens + [2]
        expanded_attention = [1] + capsule.attention_mask + [1]

        expanded_capsule = TokenCapsule(
            tokens=expanded_tokens,
            attention_mask=expanded_attention,
            token_type_ids=capsule.token_type_ids,
            embeddings=capsule.embeddings,
            metadata={**(capsule.metadata or {}), "expanded": True},
        )

        return expanded_capsule

    async def _filter_tokens(self, capsule: TokenCapsule) -> TokenCapsule:
        """Filters out padding tokens (ID 0)."""
        filtered_tokens = []
        filtered_attention = []

        for i, token in enumerate(capsule.tokens):
            if token != 0:
                filtered_tokens.append(token)
                filtered_attention.append(capsule.attention_mask[i])

        filtered_capsule = TokenCapsule(
            tokens=filtered_tokens,
            attention_mask=filtered_attention,
            token_type_ids=capsule.token_type_ids[:len(filtered_tokens)],
            embeddings=capsule.embeddings,
            metadata={**(capsule.metadata or {}), "filtered": True},
        )

        return filtered_capsule

    async def _concat_capsules(self, capsules: List[TokenCapsule]) -> Optional[TokenCapsule]:
        """Concatenates capsules back to back."""
        if not capsules:
            return None

        merged_tokens = []
        merged_attention = []
        merged_metadata = {"merged": True, "source_count": len(capsules)}

        for capsule in capsules:
            merged_tokens.extend(capsule.tokens)
            merged_attention.extend(capsule.attention_mask)

            # Collect per-source metadata values into lists, skipping the
            # bookkeeping keys so merging already-merged capsules does not
            # corrupt them.
            if capsule.metadata:
                for key, value in capsule.metadata.items():
                    if key in ("merged", "source_count"):
                        continue
                    merged_metadata.setdefault(key, []).append(value)

        # Truncate if the merged sequence exceeds the maximum length.
        if len(merged_tokens) > self.max_length:
            merged_tokens = merged_tokens[:self.max_length]
            merged_attention = merged_attention[:self.max_length]

        return TokenCapsule(
            tokens=merged_tokens,
            attention_mask=merged_attention,
            token_type_ids=[],
            metadata=merged_metadata,
        )

    async def _interleave_capsules(self, capsules: List[TokenCapsule]) -> Optional[TokenCapsule]:
        """Merges capsules by interleaving them.

        Not yet a true interleave; currently falls back to simple
        concatenation.
        """
        return await self._concat_capsules(capsules)

    async def _priority_merge_capsules(self, capsules: List[TokenCapsule]) -> Optional[TokenCapsule]:
        """Merges capsules in descending priority order."""
        # Capsules with no metadata or no "priority" entry default to 0.
        sorted_capsules = sorted(
            capsules,
            key=lambda x: (x.metadata or {}).get("priority", 0),
            reverse=True,
        )
        return await self._concat_capsules(sorted_capsules)

    def _add_to_cache(self, key: str, capsule: TokenCapsule):
        """Adds a capsule to the cache, evicting the oldest entry when full."""
        if len(self.token_cache) >= self.cache_size:
            # Dicts preserve insertion order, so the first key is the oldest.
            oldest_key = next(iter(self.token_cache))
            del self.token_cache[oldest_key]

        self.token_cache[key] = capsule

    def _update_avg_length(self, length: int):
        """Updates the running average token length."""
        current_avg = self.token_stats["avg_token_length"]
        total_processed = self.token_stats["total_processed"]

        # Incremental mean; assumes total_processed was already incremented
        # for this sample, so it is always >= 1 here.
        new_avg = ((current_avg * (total_processed - 1)) + length) / total_processed
        self.token_stats["avg_token_length"] = new_avg

    def get_stats(self) -> Dict[str, Any]:
        """Returns token layer statistics."""
        # total_processed only counts cache misses, so hits + misses is the
        # true number of lookups for the hit rate.
        total_lookups = self.token_stats["cache_hits"] + self.token_stats["cache_misses"]
        cache_hit_rate = self.token_stats["cache_hits"] / max(total_lookups, 1)

        return {
            "total_processed": self.token_stats["total_processed"],
            "cache_hit_rate": cache_hit_rate,
            "cache_size": len(self.token_cache),
            "avg_token_length": self.token_stats["avg_token_length"],
            "tokenizer_pool_size": len(self.tokenizer_pool),
        }

    def clear_cache(self):
        """Clears the token cache."""
        self.token_cache.clear()
        logger.info("Token cache cleared")