| | """
|
| | Unified utilities for handling transformers, tokenizers, and embeddings.
|
| | """
|
| | import os
|
| | import logging
|
| | import torch
|
| | from typing import Dict, Any, Optional, Union, List
|
| | from sentence_transformers import SentenceTransformer
|
| | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoModel
|
| |
|
| |
|
| | from utils.sentence_transformer_utils import get_sentence_transformer as load_sentence_transformer
|
| |
|
# Module-level logger; functions below log load progress and failures through it.
logger = logging.getLogger(__name__)


# Default models used when a caller does not specify one.
DEFAULT_SENTENCE_TRANSFORMER = "sentence-transformers/Wildnerve-tlm01-0.05Bx12"
DEFAULT_TOKENIZER = "bert-base-uncased"
# NOTE(review): DEFAULT_TOKENIZER and FALLBACK_TOKENIZERS are not referenced in
# this chunk — presumably consumed elsewhere in the project; verify before removal.
FALLBACK_TOKENIZERS = ["bert-base-uncased", "gpt2", "roberta-base"]


# In-memory caches, emptied by clear_cache().
# NOTE(review): the loader functions visible here do not populate these caches —
# confirm whether callers elsewhere do, or wire the loaders up to them.
_model_cache = {}
_tokenizer_cache = {}
_sentence_transformer_cache = {}
|
| |
|
def get_sentence_transformer(model_name: str = DEFAULT_SENTENCE_TRANSFORMER) -> SentenceTransformer:
    """Load a sentence-transformer model, falling back to the default on failure.

    Args:
        model_name: Model id accepted by ``SentenceTransformer`` (HF hub name
            or local path). Defaults to ``DEFAULT_SENTENCE_TRANSFORMER``.

    Returns:
        The requested model, or the ``DEFAULT_SENTENCE_TRANSFORMER`` model if
        the requested one cannot be loaded.

    Raises:
        Exception: propagated from ``SentenceTransformer`` if even the
            fallback model fails to load.
    """
    # Serve repeated requests from the module cache; clear_cache() empties it.
    if model_name in _sentence_transformer_cache:
        return _sentence_transformer_cache[model_name]

    try:
        # SentenceTransformer is already imported at module level; no re-import needed.
        model = SentenceTransformer(model_name)
    except Exception as e:
        # Use the module logger (was the root logger) so records carry __name__.
        logger.error(f"Failed to load sentence transformer {model_name}: {e}")
        # Fall back to the module-wide default instead of a hardcoded name so
        # the two stay in sync.
        logger.warning(f"Falling back to default model: {DEFAULT_SENTENCE_TRANSFORMER}")
        model = SentenceTransformer(DEFAULT_SENTENCE_TRANSFORMER)

    # Cache under the requested name (the fallback instance is cached for it too,
    # so repeated failing requests don't retry the broken download every call).
    _sentence_transformer_cache[model_name] = model
    return model
|
| |
|
def get_tokenizer(model_name: str = "bert-base-uncased"):
    """Load a HuggingFace tokenizer, falling back to a deterministic dummy.

    Args:
        model_name: HuggingFace model id whose tokenizer to load.

    Returns:
        A ``transformers`` tokenizer on success; otherwise a ``DummyTokenizer``
        instance that produces deterministic pseudo-token tensors so downstream
        code can still run without the real tokenizer.
    """
    # Serve repeated requests from the module cache; clear_cache() empties it.
    if model_name in _tokenizer_cache:
        return _tokenizer_cache[model_name]

    try:
        from transformers import AutoTokenizer
        logger.info(f"Loading tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        _tokenizer_cache[model_name] = tokenizer
        return tokenizer
    except Exception as e:
        logger.error(f"Failed to load tokenizer {model_name}: {e}")

        logger.warning("Using dummy tokenizer as fallback")

        class DummyTokenizer:
            """Deterministic stand-in tokenizer using BERT-like special token ids."""

            def __init__(self):
                self.vocab_size = 30522   # BERT-base vocabulary size
                self.pad_token_id = 0
                self.eos_token_id = 102   # [SEP] in BERT vocabularies
                self.bos_token_id = 101   # [CLS] in BERT vocabularies

            def __call__(self, text, **kwargs):
                """Convert text (str or list[str]) to a dict with dummy tensors."""
                import hashlib
                import random

                import torch

                is_batch = isinstance(text, list)
                texts = text if is_batch else [text]

                # Hoisted out of the loop — same for every item.
                max_length = kwargs.get("max_length", 128)
                # BUGFIX: previously `"padding" in kwargs`, which padded even
                # when the caller passed padding=False.
                pad = bool(kwargs.get("padding"))

                input_ids = []
                attention_mask = []

                for t in texts:
                    # Seed from the text so identical inputs give identical ids.
                    seed = int(hashlib.md5(t.encode()).hexdigest(), 16) % 10000
                    # Private RNG instead of random.seed(): same MT sequence,
                    # but no side effect on the caller's global random state.
                    rng = random.Random(seed)

                    # Two slots are reserved for BOS/EOS; very short or empty
                    # text yields just those two tokens.
                    length = min(len(t.split()), max_length)
                    body = [rng.randint(1000, 30000) for _ in range(max(length - 2, 0))]

                    ids = [self.bos_token_id] + body + [self.eos_token_id]
                    mask = [1] * len(ids)

                    if pad:
                        pad_length = max_length - len(ids)
                        if pad_length > 0:
                            ids.extend([self.pad_token_id] * pad_length)
                            mask.extend([0] * pad_length)

                    input_ids.append(torch.tensor(ids))
                    attention_mask.append(torch.tensor(mask))

                if kwargs.get("return_tensors") == "pt":
                    # NOTE: torch.stack requires equal lengths — pass
                    # padding=True (plus max_length) for ragged batches.
                    if is_batch or len(texts) > 1:
                        return {
                            "input_ids": torch.stack(input_ids),
                            "attention_mask": torch.stack(attention_mask),
                        }
                    return {
                        "input_ids": input_ids[0].unsqueeze(0),
                        "attention_mask": attention_mask[0].unsqueeze(0),
                    }

                single = not is_batch and len(texts) == 1
                return {
                    "input_ids": input_ids[0] if single else input_ids,
                    "attention_mask": attention_mask[0] if single else attention_mask,
                }

            def decode(self, token_ids, skip_special_tokens=True, **kwargs):
                """Convert token IDs back to (placeholder) text."""
                if isinstance(token_ids, (list, tuple)) and len(token_ids) > 0:
                    return f"Decoded text from {len(token_ids)} tokens"
                return "Decoded text"

        return DummyTokenizer()
|
| |
|
def get_hybrid_attention_config():
    """Return the smart hybrid attention configuration.

    Thin delegating wrapper: the actual configuration lives in
    ``utils.smartHybridAttention``. Imported under an alias so the local
    binding cannot be confused with this wrapper itself.
    """
    from utils.smartHybridAttention import get_hybrid_attention_config as _impl
    return _impl()
|
| |
|
def load_transformer_model(model_name: str, device: Optional[torch.device] = None) -> AutoModel:
    """
    Load a transformer model, caching it in ``_model_cache``.

    Args:
        model_name: Name of the model to load
        device: Optional device to load the model on

    Returns:
        Loaded transformer model. Repeat calls with the same name/device
        return the SAME cached instance — callers must not mutate it in
        place (e.g. train it) without clearing the cache first.

    Raises:
        Exception: whatever ``AutoModel.from_pretrained`` raises on failure.
    """
    # _model_cache existed but was never populated, so every call re-downloaded
    # and re-instantiated the model; key on device too so the same checkpoint
    # can live on several devices simultaneously.
    cache_key = (model_name, str(device) if device is not None else None)
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    try:
        logger.info(f"Loading transformer model: {model_name}")
        model = AutoModel.from_pretrained(model_name)

        if device:
            model = model.to(device)

        _model_cache[cache_key] = model
        logger.info(f"Successfully loaded model: {model_name}")
        return model
    except Exception as e:
        logger.error(f"Error loading model {model_name}: {e}")
        raise
|
| |
|
def clear_cache():
    """Empty every module-level model/tokenizer cache to release memory.

    Mutates the existing dict objects in place (no rebinding), so any other
    module holding a reference to one of these caches sees it emptied too.
    """
    for cache in (_model_cache, _tokenizer_cache, _sentence_transformer_cache):
        cache.clear()
    logger.info("Cleared transformer model and tokenizer caches")
|
| |
|
def get_embedding(text: str, model: Optional[SentenceTransformer] = None) -> torch.Tensor:
    """Encode a text string into an embedding tensor.

    Args:
        text: The string to embed.
        model: Optional pre-loaded sentence transformer; when omitted, the
            module default model is loaded via ``get_sentence_transformer``.

    Returns:
        The embedding as a ``torch.Tensor`` (``convert_to_tensor=True``).
    """
    encoder = model if model is not None else get_sentence_transformer(DEFAULT_SENTENCE_TRANSFORMER)
    return encoder.encode(text, convert_to_tensor=True)
|
| |
|