import json
import logging
import math
import os
import sys
import traceback
from types import SimpleNamespace
from typing import Optional, List, Dict, Union, Tuple

import numpy as np
import torch
import torch.nn as nn
import transformers
from codecarbon import EmissionsTracker

import transformer_patches  # imported for its side effects (applies runtime patches)

from service_registry import registry, MODEL, TOKENIZER
from utils.transformer_utils import get_tokenizer
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
from base_interfaces.common_types import *
from base_interfaces.model_interface import AbstractModel
from config import app_config

try:
    from transformers.modeling_outputs import ModelOutput
except ImportError:

    class ModelOutput:
        """Minimal placeholder for the transformers ModelOutput class."""

        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)


from transformers import PretrainedConfig


class WildnerveConfig(PretrainedConfig):
    model_type = "wildnerve_tlm01"

    def __init__(
        self,
        vocab_size: int = 50257,
        embedding_dim: int = 768,
        num_heads: int = 12,
        hidden_dim: int = 768,
        num_layers: int = 12,
        output_size: int = 50257,
        dropout: float = 0.1,
        max_seq_length: int = 767,
        pooling_mode: str = "last",
        model_name: str = "gpt2",
        specialization: str = "general",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout
        self.max_seq_length = max_seq_length
        self.pooling_mode = pooling_mode
        self.model_name = model_name
        self.specialization = specialization
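

# A minimal usage sketch for WildnerveConfig (illustration only: the smaller
# hyperparameters below are arbitrary, not values used elsewhere in this
# project). PretrainedConfig already provides the dict/JSON round-trip.
def _example_config_roundtrip() -> WildnerveConfig:
    cfg = WildnerveConfig(embedding_dim=256, num_heads=4, num_layers=2)
    restored = WildnerveConfig.from_dict(json.loads(cfg.to_json_string()))
    assert restored.embedding_dim == 256
    return restored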


logger = logging.getLogger(__name__)

# Report which CodeCarbonCallback implementation is active (the patched one
# from `carbon_tracking` or the stock transformers integration).
if hasattr(transformers, 'integrations') and hasattr(transformers.integrations, 'CodeCarbonCallback'):
    logger.info("transformers.integrations.CodeCarbonCallback is available")
    callback_module = transformers.integrations.CodeCarbonCallback.__module__
    if callback_module == 'carbon_tracking':
        logger.info("Using our clean architecture implementation for CodeCarbonCallback")
    else:
        logger.info(f"Using original implementation for CodeCarbonCallback from {callback_module}")

# Ensure the data/model directories exist before anything tries to write to them.
for d in (app_config.DATA_DIR, app_config.MODEL_DIR):
    try:
        os.makedirs(d, exist_ok=True)
    except Exception as _e:
        logger.warning(f"Could not create directory {d}: {_e}")

# Quiet TensorFlow side effects pulled in by transitive imports.
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first inputs (batch, seq, dim)."""

    def __init__(self, d_model: int, max_len: Optional[int] = None):
        super().__init__()

        if max_len is None:
            cfg = app_config.TRANSFORMER_CONFIG
            if isinstance(cfg, dict):
                max_len = cfg.get("MAX_SEQ_LENGTH", 512)
            else:
                max_len = getattr(cfg, "MAX_SEQ_LENGTH", 512)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape (1, max_len, d_model) so the buffer broadcasts over batch-first
        # inputs, matching the batch_first=True encoder/decoder layers.
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]
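

# A minimal sanity-check sketch for PositionalEncoding: verifies that the
# encoding broadcasts over a batch-first input without changing its shape.
# The dimensions below are arbitrary illustrations.
def _example_positional_encoding() -> torch.Tensor:
    enc = PositionalEncoding(d_model=16, max_len=32)
    x = torch.zeros(2, 10, 16)  # (batch, seq_len, d_model)
    out = enc(x)
    assert out.shape == x.shape
    return out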


class Wildnerve_tlm01(nn.Module, AbstractModel):
    """A Transformer-based Tiny Language Model that uses:

    - A custom-built encoder & decoder (embedding, positional encoding, and TransformerEncoder)
    - An adapter and classifier for post-processing
    - The AutoTokenizer for consistent tokenization and decoding
    - SmartHybridAttention for better context handling
    """

    VALID_SPECIALIZATIONS = [
        "python",
        "rust",
        "solidity",
        "computer",
        "cpp",
        "go",
        "java",
        "javascript",
        "mathematics",
        "nim",
        "other_information",
        "physics",
        "general",
    ]

    def __init__(
        self,
        vocab_size: int = 50257,
        specialization: str = "general",
        dataset_path: Optional[str] = None,
        model_name: str = "gpt2",
        embedding_dim: int = 768,
        num_heads: int = 12,
        hidden_dim: int = 768,
        num_layers: int = 12,
        output_size: int = 50257,
        dropout: float = 0.1,
        max_seq_length: int = 767,
        pooling_mode: str = "last",
        tokenizer=None,
        max_length: Optional[int] = None  # accepted for API compatibility; max_seq_length governs sizing
    ) -> None:
        super().__init__()

        if specialization not in self.VALID_SPECIALIZATIONS:
            logger.warning(f"Unknown specialization '{specialization}'. Valid options are: {', '.join(self.VALID_SPECIALIZATIONS)}")
            logger.warning("Defaulting to 'general' specialization")
            specialization = "general"

        # Bypass nn.Module.__setattr__ bookkeeping for the plain device attribute.
        object.__setattr__(self, "device", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.specialization = specialization
        self.dataset_path = dataset_path
        self.model_name = model_name
        self.pooling_mode = pooling_mode
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_size = output_size
        self.dropout = dropout

        self.model_last_used = {}

        # Tokenizer resolution order: explicit argument > service registry >
        # fresh GPT-2 tokenizer (with exponential-backoff retries) > simplified
        # wrapper fallback.
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif registry.has(TOKENIZER):
            self.tokenizer = registry.get(TOKENIZER)
        else:
            try:
                from transformers import GPT2Tokenizer
                self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
                logger.info("Initialized GPT2Tokenizer")

                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            except Exception as e:
                logger.warning(f"Failed to load GPT-2 tokenizer: {e}")
                logger.warning(f"Error details: {traceback.format_exc()}")

                retry_count = 0
                max_retries = 5
                success = False

                while not success and retry_count < max_retries:
                    retry_count += 1
                    delay = 2 ** retry_count
                    logger.info(f"Retrying tokenizer initialization (attempt {retry_count}/{max_retries}) after {delay}s delay")

                    try:
                        import time
                        time.sleep(delay)
                        from transformers import GPT2Tokenizer
                        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
                        if self.tokenizer.pad_token is None:
                            self.tokenizer.pad_token = self.tokenizer.eos_token
                            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                        success = True
                        logger.info(f"Successfully loaded GPT-2 tokenizer on retry {retry_count}")
                    except Exception as retry_e:
                        logger.warning(f"Retry {retry_count} failed: {retry_e}")

                if not success:
                    logger.error("All tokenizer initialization attempts failed")
                    from utils.transformer_utils import get_tokenizer
                    self.tokenizer = get_tokenizer(model_name="gpt2")
                    logger.warning("Using simplified tokenizer wrapper as fallback")

        registry.register(TOKENIZER, self.tokenizer, overwrite=True)

        # Register this instance per specialization; the "general" model also
        # becomes the default MODEL entry.
        model_registry_key = f"model_{specialization}"
        registry.register(model_registry_key, self)

        if specialization == "general":
            registry.register(MODEL, self)

        # Source-side embedding + positional encoding.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)

        # Target-side embedding + positional encoding for the decoder.
        self.tgt_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_decoder = PositionalEncoding(embedding_dim, max_len=max_seq_length)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Hybrid attention for long inputs; the window scales with the sequence budget.
        attention_config = get_hybrid_attention_config()
        attention_config['NUM_HEADS'] = num_heads
        attention_config['WINDOW_SIZE'] = max(256, max_seq_length // 4)
        self.hybrid_attention = SmartHybridAttention(attention_config)

        # Post-processing heads.
        self.adapter = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.classifier = nn.Linear(embedding_dim, self.vocab_size)
        self.dropout_layer = nn.Dropout(dropout)

        # The encoder/decoder output feature size is embedding_dim (their d_model),
        # so the LM head must project from embedding_dim, not hidden_dim.
        self.final_layer = nn.Linear(embedding_dim, vocab_size)

        self.init_weights()

        self.config = WildnerveConfig(
            vocab_size=self.vocab_size,
            embedding_dim=self.embedding_dim,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim,
            num_layers=self.num_layers,
            output_size=self.output_size,
            dropout=self.dropout,
            max_seq_length=self.max_seq_length,
            pooling_mode=self.pooling_mode,
            model_name=self.model_name,
            specialization=self.specialization,
        )

    def init_weights(self) -> None:
        initrange = 0.1
        with torch.no_grad():
            self.embedding.weight.uniform_(-initrange, initrange)
            self.tgt_embedding.weight.uniform_(-initrange, initrange)
            self.classifier.weight.uniform_(-initrange, initrange)
            self.classifier.bias.zero_()
            for layer in self.adapter:
                if isinstance(layer, nn.Linear):
                    layer.weight.uniform_(-initrange, initrange)
                    if layer.bias is not None:
                        layer.bias.zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        src=None,
        tgt=None,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        return_sequence=False,
        **kwargs
    ):
        try:
            logger.debug(f"Input shapes - src: {src.shape if src is not None else None}, tgt: {tgt.shape if tgt is not None else None}")

            # Accept HF-style arguments: input_ids aliases src, and a 1/0
            # attention_mask becomes a boolean key-padding mask (True = pad).
            if input_ids is not None:
                src = input_ids
            if attention_mask is not None and src_key_padding_mask is None:
                src_key_padding_mask = attention_mask == 0

            src_embeddings = self.embedding(src)
            src_embeddings = self.pos_encoder(src_embeddings)

            memory = self.transformer_encoder(src_embeddings,
                                              src_key_padding_mask=src_key_padding_mask)

            if src.size(1) > 256 and hasattr(self, 'hybrid_attention'):
                # SmartHybridAttention is assumed to expect (seq, batch, dim),
                # hence the transpose from the batch-first embeddings.
                query = src_embeddings.transpose(0, 1)
                key = query
                value = query

                if src_mask is None and src is not None:
                    src_seq_len = src.size(1)
                    src_mask = torch.zeros((src_seq_len, src_seq_len), device=src.device, dtype=torch.bool)

                hybrid_outputs = self.hybrid_attention(
                    query=query,
                    key=key,
                    value=value,
                    key_padding_mask=src_key_padding_mask,
                    attn_mask=src_mask,
                    prompt_length=src.size(1),
                    prompt_complexity=0.5,
                )

                # NOTE: the hybrid attention output is computed for long inputs
                # but is not currently fed into the decoder path; `memory` from
                # the standard encoder is used below.
                encoded_src = hybrid_outputs

            if tgt is not None:
                tgt_embeddings = self.tgt_embedding(tgt)
                tgt_embeddings = self.pos_decoder(tgt_embeddings)
                output = self.transformer_decoder(tgt_embeddings, memory,
                                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                                  memory_key_padding_mask=memory_key_padding_mask)
            else:
                output = memory

            if output.dim() == 2:
                output = output.unsqueeze(1)

            logits = self.final_layer(output)

            if logits.dim() == 2:
                batch_size, vocab_size = logits.shape
                logger.debug(f"2D tensor: batch_size={batch_size}, vocab_size={vocab_size}")
                logits = logits.unsqueeze(1)
                logger.debug(f"Reshaped 2D output to 3D tensor: {logits.shape}")

            logger.debug(f"Output shape: {logits.shape}, dimensions: {logits.dim()}")

            loss = None
            if labels is not None:
                if labels.dim() > 1:
                    labels = labels.reshape(-1)
                    logger.debug(f"Reshaped labels to {labels.shape}")

                batch_size, seq_length, vocab_size = logits.shape
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.reshape(-1, vocab_size), labels)
                logger.debug(f"Returning loss tensor: {loss.item()}")

            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
            return CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=logits,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
                cross_attentions=None,
            )

        except Exception as e:
            logger.error(f"Error in forward pass: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            logger.error(f"Input shapes - src: {src.shape if src is not None else None}, input_ids: {input_ids.shape if input_ids is not None else None}")

            # Return a NaN-loss dummy output instead of crashing the caller.
            # The import is repeated here because the one in the try block may
            # not have executed before the exception was raised.
            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

            dummy_batch = 1
            if src is not None:
                dummy_batch = src.shape[0]
            elif input_ids is not None:
                dummy_batch = input_ids.shape[0]

            dummy_output = torch.zeros((dummy_batch, 1, self.vocab_size), device=next(self.parameters()).device)
            dummy_loss = torch.tensor(float('nan'), device=next(self.parameters()).device)
            return CausalLMOutputWithCrossAttentions(
                loss=dummy_loss,
                logits=dummy_output,
                past_key_values=None,
                hidden_states=None,
                attentions=None,
                cross_attentions=None,
            )

    def encode_sentences(self, sentences, batch_size=32, normalize_embeddings=True):
        """Encode sentences into vectors (sentence-transformer-style functionality)."""
        self.eval()
        from torch.utils.data import DataLoader, Dataset

        if isinstance(sentences, str):
            sentences = [sentences]

        class SentencesDataset(Dataset):
            def __init__(self, sentences, tokenizer, max_length):
                self.sentences = sentences
                self.tokenizer = tokenizer
                self.max_length = max_length

            def __len__(self):
                return len(self.sentences)

            def __getitem__(self, idx):
                return self.tokenizer(self.sentences[idx],
                                      padding='max_length',
                                      truncation=True,
                                      max_length=self.max_length,
                                      return_tensors='pt')

        dataset = SentencesDataset(sentences, self.tokenizer, self.max_seq_length)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        all_embeddings = []
        device = next(self.parameters()).device

        with torch.no_grad():
            for batch in dataloader:
                inputs = {k: v.squeeze(1).to(device) for k, v in batch.items()}
                outputs = self(inputs['input_ids'], attention_mask=inputs.get('attention_mask'))
                # forward returns a ModelOutput; mean-pool the logits over the
                # sequence dimension to get one fixed-size vector per sentence.
                hidden = outputs.logits if hasattr(outputs, 'logits') else outputs
                embeddings = hidden.mean(dim=1)

                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

                all_embeddings.append(embeddings.cpu().numpy())

        return np.vstack(all_embeddings)

    def similarity(self, sentence1: str, sentence2: str) -> float:
        """Compute cosine similarity between two sentences."""
        embeddings = self.encode_sentences([sentence1, sentence2])
        return np.dot(embeddings[0], embeddings[1]) / (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
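
    # Usage sketch (hedged): `encode_sentences` mean-pools the LM logits, so
    # `similarity` is only a rough semantic proxy, e.g.:
    #     score = model.similarity("sort a list in python", "order array elements")
    # The returned cosine score lies in [-1, 1]; it is not a calibrated probability.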

    def generate(
        self,
        prompt=None,
        input_ids=None,
        max_length: Optional[int] = None,
        device=None,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """Generate text using the model, from either a prompt string or input_ids."""
        # Delegate to a registered adapter layer when one provides generate().
        adapter_layer = registry.get("adapter_layer") if registry.has("adapter_layer") else None
        if adapter_layer is not None and hasattr(adapter_layer, "generate"):
            if prompt:
                return adapter_layer.generate(prompt, max_length=max_length, temperature=temperature, **kwargs)
            elif input_ids is not None and self.tokenizer:
                decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
                return adapter_layer.generate(decoded_prompt, max_length=max_length, temperature=temperature, **kwargs)

        logger.info(f"Generate called with: prompt={type(prompt).__name__ if prompt is not None else None}, input_ids={type(input_ids).__name__ if input_ids is not None else None}")

        if max_length is None:
            max_length = getattr(self, 'max_seq_length', 512)

        if device is None:
            device = next(self.parameters()).device

        if isinstance(prompt, str) and prompt:
            if not self.tokenizer:
                raise ValueError("Tokenizer not available but prompt is a string")

            if callable(self.tokenizer):
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                )
                input_ids = inputs.input_ids.to(device)
                logger.debug(f"Tokenized prompt '{prompt[:20]}...' to tensor of shape {input_ids.shape}")

        if input_ids is None:
            raise ValueError("Either prompt or input_ids must be provided")

        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        gen_kwargs = {}
        gen_kwargs.update(kwargs)

        # Prefer max_new_tokens when the prompt already uses most of the budget.
        if 'max_length' not in gen_kwargs and 'max_new_tokens' not in gen_kwargs:
            if input_ids.shape[1] > max_length - 50:
                gen_kwargs['max_new_tokens'] = 100
            else:
                gen_kwargs['max_length'] = max_length

        if 'temperature' not in gen_kwargs:
            gen_kwargs['temperature'] = temperature

        if 'max_length' in gen_kwargs and input_ids.shape[1] > (gen_kwargs['max_length'] - 50):
            logger.info(f"Input length {input_ids.shape[1]} close to max_length, switching to max_new_tokens")
            gen_kwargs['max_new_tokens'] = 100
            del gen_kwargs['max_length']

        try:
            output_ids = self.generate_tokens(input_ids, **gen_kwargs)

            if self.tokenizer:
                # Decode only the newly generated continuation when possible.
                input_length = input_ids.shape[1]
                if hasattr(output_ids, 'shape') and len(output_ids.shape) > 1 and output_ids.shape[1] > input_length:
                    response_ids = output_ids[0][input_length:]
                    generated_text = self.tokenizer.decode(response_ids, skip_special_tokens=True)
                else:
                    generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

                return generated_text
            else:
                return f"Generated token IDs: {output_ids}"

        except Exception as e:
            logger.error(f"Error in generate: {e}", exc_info=True)
            return f"Error generating response: {str(e)}"

    def generate_tokens(self, input_ids, max_length=None, max_new_tokens=None, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0, **kwargs):
        """Generate tokens autoregressively.

        Note: top_p and repetition_penalty are accepted for API compatibility
        but are not applied by this simple sampling loop.
        """
        logger.info(f"generate_tokens called with tensor of shape: {input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}")

        try:
            if not isinstance(input_ids, torch.Tensor):
                input_ids = torch.tensor(input_ids, dtype=torch.long)

            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)

            # Resolve the total-length budget; max_new_tokens takes precedence.
            if max_new_tokens is not None:
                max_length = input_ids.shape[1] + max_new_tokens
            elif max_length is None:
                max_length = getattr(self, 'max_seq_length', 1024)
            max_length = min(max_length, 1024)

            generated_sequences = input_ids.clone()

            # NOTE: no EOS-based early stopping; the loop always runs to max_length.
            for step in range(max_length - input_ids.shape[1]):
                with torch.no_grad():
                    outputs = self(generated_sequences)

                # forward returns a ModelOutput; pull the logits tensor out.
                logits_out = outputs.logits if hasattr(outputs, 'logits') else outputs

                if logits_out.dim() == 3:
                    next_token_logits = logits_out[:, -1, :]
                else:
                    next_token_logits = logits_out

                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                if top_k > 0:
                    top_k = min(top_k, next_token_logits.size(-1))
                    top_k_values, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float("-inf"))
                    next_token_logits.scatter_(1, top_k_indices, top_k_values)

                probs = torch.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

                generated_sequences = torch.cat([generated_sequences, next_tokens.unsqueeze(-1)], dim=1)

            return generated_sequences

        except Exception as e:
            logger.error(f"Error in generate_tokens: {e}")
            return input_ids

    def generate_with_decoding(self, input_ids=None, prompt=None, **kwargs):
        """
        Generate text from either input_ids or a text prompt.

        This is a helper method that handles both tokenization and decoding.
        """
        try:
            if prompt is not None and input_ids is None:
                if not hasattr(self, 'tokenizer') or self.tokenizer is None:
                    logger.error("No tokenizer available for text prompt")
                    return "Error: No tokenizer available for processing text prompt"

                inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
                input_ids = inputs.input_ids

            if input_ids is None:
                logger.error("Neither prompt nor input_ids provided")
                return "Error: No input provided"

            output_ids = self.generate_tokens(input_ids, **kwargs)

            if not hasattr(self, 'tokenizer') or self.tokenizer is None:
                return f"Generated sequence (no tokenizer): {output_ids.tolist()}"

            return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        except Exception as e:
            logger.error(f"Error in generate_with_decoding: {e}", exc_info=True)
            return f"Error generating text: {str(e)}"

    def forward_with_custom_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Forward pass that accepts pre-computed embeddings, shaped (batch, seq, dim), to bypass shape errors."""
        device = next(self.parameters()).device
        try:
            embeddings = embeddings.to(device)

            # The encoder layers and PositionalEncoding are batch-first, so the
            # input is used as-is; callers must supply (batch, seq, dim).
            if hasattr(self, 'pos_encoder'):
                embeddings = self.pos_encoder(embeddings)

            encoded = self.transformer_encoder(embeddings)

            if hasattr(self, 'adapter'):
                encoded = self.adapter(encoded)

            # Pool over the sequence dimension according to pooling_mode.
            if self.pooling_mode == "mean":
                pooled = encoded.mean(dim=1)
            elif self.pooling_mode == "max":
                pooled = torch.max(encoded, dim=1)[0]
            elif self.pooling_mode == "cls":
                pooled = encoded[:, 0]
            else:
                pooled = encoded.mean(dim=1)

            pooled = self.dropout_layer(pooled)
            output = self.classifier(pooled)

            return output
        except Exception as e:
            logger.error(f"Error in custom embeddings forward pass: {e}")
            return torch.zeros(1, self.output_size, device=device)

    def forward_with_error_handling(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass with enhanced error handling for shape mismatches."""
        try:
            return self.forward(
                src=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                **kwargs
            )
        except RuntimeError as e:
            if "shape" in str(e):
                logger.warning(f"Shape mismatch detected: {e}")
                if input_ids.dim() == 3 and input_ids.size(0) > input_ids.size(1):
                    input_ids = input_ids.transpose(0, 1)

                # Fallback: run a simplified encode-pool-classify pipeline.
                try:
                    embedded = self.embedding(input_ids)
                    if hasattr(self, 'pos_encoder'):
                        embedded = self.pos_encoder(embedded)
                    encoder_out = self.transformer_encoder(embedded)
                    pooled = encoder_out.mean(dim=1)
                    pooled = self.dropout_layer(pooled)
                    return self.classifier(pooled)
                except Exception as inner_e:
                    logger.error(f"Adaptation failed: {inner_e}")
                    batch_size = input_ids.size(0) if input_ids is not None else 1
                    return torch.zeros((batch_size, self.output_size), device=self.device)

            raise
        except Exception as e:
            logger.error(f"Unhandled error in forward_with_error_handling: {e}")
            batch_size = input_ids.size(0) if input_ids is not None else 1
            return torch.zeros((batch_size, self.output_size), device=self.device)

    def train_with_emissions_tracking(self, dataloader, optimizer, criterion, num_epochs=1):
        """
        Train the model while tracking carbon emissions using CodeCarbon.
        """
        tracker = EmissionsTracker()
        tracker.start()

        self.train()
        for epoch in range(num_epochs):
            for batch in dataloader:
                inputs, labels = batch
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = self(inputs)
                # forward returns a ModelOutput; the criterion needs the raw
                # logits, flattened to (batch * seq, vocab) for token targets.
                logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
                loss.backward()
                optimizer.step()

            logging.info(f"Epoch {epoch + 1}/{num_epochs} completed.")

        emissions = tracker.stop()
        logging.info(f"Training completed. Carbon emissions: {emissions:.4f} kg CO2")

    def infer_with_emissions_tracking(self, input_ids):
        """
        Perform inference while tracking carbon emissions using CodeCarbon.
        """
        tracker = EmissionsTracker()
        tracker.start()

        self.eval()
        with torch.no_grad():
            outputs = self(input_ids)

        emissions = tracker.stop()
        logging.info(f"Inference completed. Carbon emissions: {emissions:.4f} kg CO2")
        return outputs

    def __call__(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Forward pass with HF-style parameters.

        Note: overriding __call__ bypasses nn.Module's pre/post forward hooks.
        """
        try:
            return self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                **kwargs
            )
        except Exception as e:
            logger.error(f"Error in __call__: {e}")

            batch_size = input_ids.shape[0] if hasattr(input_ids, 'shape') else 1
            vocab_size = self.vocab_size if hasattr(self, 'vocab_size') else 50257
            device = input_ids.device if hasattr(input_ids, 'device') else 'cpu'

            if labels is not None:
                logger.error(f"Input shapes - input_ids: {input_ids.shape if hasattr(input_ids, 'shape') else 'unknown'}, "
                             f"labels: {labels.shape if hasattr(labels, 'shape') else 'unknown'}")

            # Return a zero-logits stand-in so callers can degrade gracefully.
            dummy_output = torch.zeros((batch_size, vocab_size), device=device)

            class SimpleOutput:
                def __init__(self, logits):
                    self.logits = logits

            return SimpleOutput(dummy_output)

    def save_pretrained(self, save_directory: str):
        """Save model weights (and config) in HF format."""
        os.makedirs(save_directory, exist_ok=True)
        pt_file = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), pt_file)
        # Persist the config alongside the weights so the folder is self-describing.
        self.config.save_pretrained(save_directory)
        logger.info(f"Saved model weights to {pt_file}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """
        Instantiate the model, then load weights.

        Accepts either a folder (containing pytorch_model.bin)
        or a direct path to a .bin file.
        """
        model = cls(*args, **kwargs)
        if os.path.isdir(pretrained_model_name_or_path):
            weight_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        else:
            weight_path = pretrained_model_name_or_path
        state = torch.load(weight_path, map_location=model.device)
        model.load_state_dict(state, strict=False)
        logger.info(f"Loaded weights from {weight_path}")
        return model
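

# A minimal end-to-end sketch (hedged): build a small model, generate once, and
# round-trip the weights through save_pretrained/from_pretrained. The tiny
# hyperparameters and the /tmp path are arbitrary illustrations, and the GPT-2
# tokenizer download may require network access.
def _example_model_roundtrip(tmp_dir: str = "/tmp/wildnerve_demo") -> str:
    dims = dict(embedding_dim=64, num_heads=4, hidden_dim=64, num_layers=2,
                max_seq_length=64)
    model = Wildnerve_tlm01(**dims)
    text = model.generate(prompt="hello world", max_new_tokens=8)
    model.save_pretrained(tmp_dir)
    reloaded = Wildnerve_tlm01.from_pretrained(tmp_dir, **dims)
    assert isinstance(reloaded, Wildnerve_tlm01)
    return text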


registry.register("model_class_custom", Wildnerve_tlm01)


def initialize_tokenizer():
    """
    Fallback function to initialize the tokenizer.

    Tries up to 5 times and logs a debug message on each attempt.
    """
    from transformers import GPT2Tokenizer, AutoTokenizer
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            # Prefer a tokenizer that is already registered.
            from service_registry import registry, TOKENIZER
            if registry.has(TOKENIZER):
                tokenizer = registry.get(TOKENIZER)
                if tokenizer is not None:
                    logger.debug(f"Attempt {attempt}: Successfully retrieved tokenizer from registry.")
                    return tokenizer

            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            logger.debug(f"Attempt {attempt}: Successfully loaded GPT-2 tokenizer.")

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
                logger.debug("Set pad_token to eos_token for GPT-2 tokenizer")

            registry.register(TOKENIZER, tokenizer)
            return tokenizer
        except Exception as e:
            logger.debug(f"Attempt {attempt}: Failed to initialize tokenizer due to: {e}")

    logger.error(f"Tokenizer initialization failed after {max_attempts} attempts. Using fallback GPT2Tokenizer.")
    try:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        return tokenizer
    except Exception as e:
        logger.error(f"Default tokenizer initialization failed: {e}")
        return None
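

if __name__ == "__main__":
    # Hedged smoke test: exercises the tokenizer fallback path. Fetching GPT-2
    # tokenizer files may require network access on first run.
    logging.basicConfig(level=logging.INFO)
    tok = initialize_tokenizer()
    print(f"Tokenizer ready: {tok is not None}")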