# Wildnerve-tlm01_Hybrid_Model / utils / transformer_utils.py
# Upload metadata: WildnerveAI — "Upload 20 files" — commit 0861a59 (verified)
"""
Unified utilities for handling transformers, tokenizers, and embeddings.
"""
import os
import logging
import torch
from typing import Dict, Any, Optional, Union, List
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoModel
# Import the new sentence transformer utilities
from utils.sentence_transformer_utils import get_sentence_transformer as load_sentence_transformer
logger = logging.getLogger(__name__)
# Constants
# NOTE(review): the hard-coded fallback inside get_sentence_transformer uses
# the un-namespaced "Wildnerve-tlm01-0.05Bx12" rather than this constant —
# confirm both are meant to resolve to the same model.
DEFAULT_SENTENCE_TRANSFORMER = "sentence-transformers/Wildnerve-tlm01-0.05Bx12"
DEFAULT_TOKENIZER = "bert-base-uncased"
# NOTE(review): FALLBACK_TOKENIZERS is not referenced anywhere in this module —
# verify whether external callers use it or it is dead configuration.
FALLBACK_TOKENIZERS = ["bert-base-uncased", "gpt2", "roberta-base"]
# Cache for loaded models to avoid reloading
_model_cache = {}
_tokenizer_cache = {}
_sentence_transformer_cache = {}
def get_sentence_transformer(model_name):
    """Load a SentenceTransformer, reusing cached instances when possible.

    Args:
        model_name: Hugging Face model id or local path of the model.

    Returns:
        A ``SentenceTransformer`` instance. On load failure, the hard-coded
        fallback model "Wildnerve-tlm01-0.05Bx12" is loaded instead; if that
        second attempt also fails, its exception propagates to the caller.
    """
    # Reuse a previously loaded instance — loading is expensive, and the
    # module-level cache exists precisely for this (see _sentence_transformer_cache).
    if model_name in _sentence_transformer_cache:
        return _sentence_transformer_cache[model_name]
    try:
        # SentenceTransformer is already imported at module level; no need to
        # re-import it inside the function.
        model = SentenceTransformer(model_name)
        _sentence_transformer_cache[model_name] = model
        return model
    except Exception as e:
        # Use the module logger (not the root logger) so records carry __name__.
        logger.error(f"Failed to load sentence transformer {model_name}: {e}")
        logger.warning("Falling back to default model: Wildnerve-tlm01-0.05Bx12")
        return SentenceTransformer("Wildnerve-tlm01-0.05Bx12")
def get_tokenizer(model_name: str = "bert-base-uncased"):
    """Load a Hugging Face tokenizer, degrading to a deterministic dummy.

    Args:
        model_name: Hugging Face model id or local path of the tokenizer.

    Returns:
        A real ``AutoTokenizer`` on success; otherwise a ``DummyTokenizer``
        exposing ``__call__``/``decode`` that produces deterministic
        pseudo-random ids so downstream code keeps working offline.
    """
    # Resolve the module logger here so the function is self-contained;
    # logging.getLogger(__name__) returns the same object as the module-level
    # `logger`.
    logger = logging.getLogger(__name__)
    try:
        from transformers import AutoTokenizer
        logger.info(f"Loading tokenizer: {model_name}")
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        logger.error(f"Failed to load tokenizer {model_name}: {e}")
        # Return a minimal dummy tokenizer that won't break everything
        logger.warning("Using dummy tokenizer as fallback")

        class DummyTokenizer:
            """Minimal stand-in mimicking the tokenizer call/decode API."""

            def __init__(self):
                self.vocab_size = 30522  # BERT vocab size
                self.pad_token_id = 0
                self.eos_token_id = 102  # BERT [SEP]
                self.bos_token_id = 101  # BERT [CLS]

            def __call__(self, text, **kwargs):
                """Tokenize *text* (str or list of str) into dummy id tensors."""
                # Imports hoisted out of the per-text loop.
                import hashlib
                import random

                import torch

                # Handle batch vs single input.
                is_batch = isinstance(text, list)
                texts = text if is_batch else [text]
                max_length = kwargs.get("max_length", 128)
                # BUGFIX: honor the *value* of padding — the old code padded
                # whenever the key was present, even for padding=False.
                do_pad = bool(kwargs.get("padding"))

                input_ids = []
                attention_mask = []
                for t in texts:
                    # Seed from an md5 of the text: ids are deterministic per
                    # input but look pseudo-random.
                    seed = int(hashlib.md5(t.encode()).hexdigest(), 16) % 10000
                    # Local RNG instead of random.seed(): identical value
                    # stream, but the global RNG state is left untouched.
                    rng = random.Random(seed)
                    length = min(len(t.split()), max_length)
                    # [CLS] + (length-2) random word-piece ids + [SEP]
                    ids = (
                        [self.bos_token_id]
                        + [rng.randint(1000, 30000) for _ in range(length - 2)]
                        + [self.eos_token_id]
                    )
                    mask = [1] * len(ids)
                    if do_pad:
                        pad_length = max_length - len(ids)
                        if pad_length > 0:
                            ids.extend([self.pad_token_id] * pad_length)
                            mask.extend([0] * pad_length)
                    input_ids.append(torch.tensor(ids))
                    attention_mask.append(torch.tensor(mask))

                if kwargs.get("return_tensors") == "pt":
                    # "pt" output always carries a batch dimension.
                    if is_batch:
                        return {
                            "input_ids": torch.stack(input_ids),
                            "attention_mask": torch.stack(attention_mask),
                        }
                    return {
                        "input_ids": input_ids[0].unsqueeze(0),
                        "attention_mask": attention_mask[0].unsqueeze(0),
                    }
                # Without return_tensors: single input -> single tensor,
                # batch input -> list of tensors (matches prior behavior).
                if not is_batch:
                    return {
                        "input_ids": input_ids[0],
                        "attention_mask": attention_mask[0],
                    }
                return {"input_ids": input_ids, "attention_mask": attention_mask}

            def decode(self, token_ids, skip_special_tokens=True, **kwargs):
                """Return a placeholder string describing *token_ids*.

                Generalized to accept anything with a length (lists, tuples,
                and now tensors), falling back to a generic message.
                """
                try:
                    n = len(token_ids)
                except TypeError:
                    n = 0
                if n > 0:
                    return f"Decoded text from {n} tokens"
                return "Decoded text"

        return DummyTokenizer()
def get_hybrid_attention_config():
    """Return the smart hybrid attention configuration.

    Thin delegation to ``utils.smartHybridAttention``; the import is deferred
    to call time (presumably to keep module import light or avoid an import
    cycle — confirm).
    """
    from utils.smartHybridAttention import get_hybrid_attention_config as _impl
    return _impl()
def load_transformer_model(model_name: str, device: Optional[torch.device] = None) -> AutoModel:
    """
    Load a transformer model, reusing cached instances when possible.

    Args:
        model_name: Name of the model to load
        device: Optional device to load the model on

    Returns:
        Loaded transformer model. Instances are shared via the module-level
        ``_model_cache`` (declared for this purpose but previously unused),
        so callers must not mutate the returned model destructively.

    Raises:
        Exception: re-raised from ``AutoModel.from_pretrained`` / ``.to`` on
        failure, after logging.
    """
    # Key on the device too: Module.to(device) moves parameters in place, so
    # a model cached for one device must not be handed out for another.
    cache_key = (model_name, str(device))
    if cache_key in _model_cache:
        return _model_cache[cache_key]
    try:
        logger.info(f"Loading transformer model: {model_name}")
        model = AutoModel.from_pretrained(model_name)
        # Explicit None check rather than truthiness of a torch.device.
        if device is not None:
            model = model.to(device)
        logger.info(f"Successfully loaded model: {model_name}")
        _model_cache[cache_key] = model
        return model
    except Exception as e:
        logger.error(f"Error loading model {model_name}: {e}")
        raise
def clear_cache():
    """Empty every module-level model/tokenizer cache to release memory.

    ``dict.clear()`` mutates in place, so no ``global`` declaration is
    needed — the names are only read, never rebound.
    """
    for cache in (_model_cache, _tokenizer_cache, _sentence_transformer_cache):
        cache.clear()
    logger.info("Cleared transformer model and tokenizer caches")
def get_embedding(text: str, model: Optional[SentenceTransformer] = None) -> torch.Tensor:
    """Encode *text* into an embedding tensor with a sentence transformer.

    Falls back to loading ``DEFAULT_SENTENCE_TRANSFORMER`` when no model is
    supplied.
    """
    encoder = model if model is not None else get_sentence_transformer(DEFAULT_SENTENCE_TRANSFORMER)
    return encoder.encode(text, convert_to_tensor=True)