""" Unified utilities for handling transformers, tokenizers, and embeddings. """ import os import logging import torch from typing import Dict, Any, Optional, Union, List from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoModel # Import the new sentence transformer utilities from utils.sentence_transformer_utils import get_sentence_transformer as load_sentence_transformer logger = logging.getLogger(__name__) # Constants DEFAULT_SENTENCE_TRANSFORMER = "sentence-transformers/Wildnerve-tlm01-0.05Bx12" DEFAULT_TOKENIZER = "bert-base-uncased" FALLBACK_TOKENIZERS = ["bert-base-uncased", "gpt2", "roberta-base"] # Cache for loaded models to avoid reloading _model_cache = {} _tokenizer_cache = {} _sentence_transformer_cache = {} def get_sentence_transformer(model_name): try: from sentence_transformers import SentenceTransformer return SentenceTransformer(model_name) except Exception as e: logging.error(f"Failed to load sentence transformer {model_name}: {e}") logging.warning("Falling back to default model: Wildnerve-tlm01-0.05Bx12") from sentence_transformers import SentenceTransformer return SentenceTransformer("Wildnerve-tlm01-0.05Bx12") def get_tokenizer(model_name: str = "bert-base-uncased"): """Get a tokenizer with proper error handling""" try: from transformers import AutoTokenizer logger.info(f"Loading tokenizer: {model_name}") return AutoTokenizer.from_pretrained(model_name) except Exception as e: logger.error(f"Failed to load tokenizer {model_name}: {e}") # Return a minimal dummy tokenizer that won't break everything logger.warning("Using dummy tokenizer as fallback") class DummyTokenizer: def __init__(self): self.vocab_size = 30522 # BERT vocab size self.pad_token_id = 0 self.eos_token_id = 102 self.bos_token_id = 101 def __call__(self, text, **kwargs): """Convert text to a dict with dummy tensors""" import torch # Handle batch vs single input is_batch = isinstance(text, list) texts = text if is_batch else [text] # Create random but deterministic IDs based on text length input_ids = [] attention_mask = [] for t in texts: # Use text length to create deterministic pseudo-random sequence import hashlib hash_obj = hashlib.md5(t.encode()) seed = int(hash_obj.hexdigest(), 16) % 10000 import random random.seed(seed) # Get length or use max_length if provided max_length = kwargs.get("max_length", 128) length = min(len(t.split()), max_length) # Generate ids and mask ids = [self.bos_token_id] + [random.randint(1000, 30000) for _ in range(length-2)] + [self.eos_token_id] mask = [1] * len(ids) # Pad if needed if "padding" in kwargs: pad_length = max_length - len(ids) if pad_length > 0: ids.extend([self.pad_token_id] * pad_length) mask.extend([0] * pad_length) input_ids.append(torch.tensor(ids)) attention_mask.append(torch.tensor(mask)) # Stack tensors if "return_tensors" in kwargs and kwargs["return_tensors"] == "pt": if is_batch or len(texts) > 1: return { "input_ids": torch.stack(input_ids), "attention_mask": torch.stack(attention_mask) } else: return { "input_ids": input_ids[0].unsqueeze(0), "attention_mask": attention_mask[0].unsqueeze(0) } else: return { "input_ids": input_ids[0] if not is_batch and len(texts) == 1 else input_ids, "attention_mask": attention_mask[0] if not is_batch and len(texts) == 1 else attention_mask } def decode(self, token_ids, skip_special_tokens=True, **kwargs): """Convert token IDs back to text""" if isinstance(token_ids, (list, tuple)) and len(token_ids) > 0: return f"Decoded 
text from {len(token_ids)} tokens" return "Decoded text" return DummyTokenizer() def get_hybrid_attention_config(): """Get configuration for smart hybrid attention mechanism""" from utils.smartHybridAttention import get_hybrid_attention_config return get_hybrid_attention_config() def load_transformer_model(model_name: str, device: Optional[torch.device] = None) -> AutoModel: """ Load a transformer model. Args: model_name: Name of the model to load device: Optional device to load the model on Returns: Loaded transformer model """ try: logger.info(f"Loading transformer model: {model_name}") model = AutoModel.from_pretrained(model_name) if device: model = model.to(device) logger.info(f"Successfully loaded model: {model_name}") return model except Exception as e: logger.error(f"Error loading model {model_name}: {e}") raise def clear_cache(): """Clear all model and tokenizer caches to free memory.""" global _model_cache, _tokenizer_cache, _sentence_transformer_cache _model_cache.clear() _tokenizer_cache.clear() _sentence_transformer_cache.clear() logger.info("Cleared transformer model and tokenizer caches") def get_embedding(text: str, model: Optional[SentenceTransformer] = None) -> torch.Tensor: """Get embedding for a text string using a sentence transformer model.""" if model is None: model = get_sentence_transformer(DEFAULT_SENTENCE_TRANSFORMER) return model.encode(text, convert_to_tensor=True)
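

# Minimal smoke-test sketch of the intended call pattern, illustrative only:
# it assumes the model names above resolve on the Hugging Face hub (or that
# the fallback paths kick in) and that the surrounding utils package exists.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Tokenize a small batch; this works with both real and dummy tokenizers
    tokenizer = get_tokenizer(DEFAULT_TOKENIZER)
    batch = tokenizer(
        ["hello world", "transformer utilities"],
        padding=True,
        max_length=32,
        return_tensors="pt",
    )
    print("input_ids shape:", tuple(batch["input_ids"].shape))

    # Embed a sentence with the default sentence transformer
    embedding = get_embedding("hello world")
    print("embedding shape:", tuple(embedding.shape))

    clear_cache()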