import torch import torch.nn as nn from torch.utils.data import DataLoader from transformers import AutoModel, AutoTokenizer, AdamW, get_scheduler from datasets import load_dataset, Dataset from huggingface_hub import HfApi from tqdm.auto import tqdm import numpy as np from typing import Dict, List, Tuple, Any def load_embedding_model(model_name_or_path: str) -> Tuple[nn.Module, AutoTokenizer]: """ Load an embedding model from Hugging Face. Args: model_name_or_path: The model name or path on Hugging Face Returns: Tuple of (model, tokenizer) """ try: # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) # Add padding token if not present if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load model model = AutoModel.from_pretrained(model_name_or_path) return model, tokenizer except Exception as e: raise Exception(f"Failed to load model {model_name_or_path}: {str(e)}") def load_huggingface_dataset(dataset_name: str, split: str = "train") -> Dataset: """ Load a dataset from Hugging Face. Args: dataset_name: The dataset name on Hugging Face split: The dataset split to load Returns: The loaded dataset """ try: dataset = load_dataset(dataset_name, split=split) return dataset except Exception as e: raise Exception(f"Failed to load dataset {dataset_name}: {str(e)}") def prepare_dataset_for_training(dataset: Dataset, tokenizer: AutoTokenizer, text_column: str = None, max_length: int = 512) -> Dataset: """ Prepare dataset for embedding training by tokenizing texts. Args: dataset: The dataset to prepare tokenizer: The tokenizer to use text_column: The text column name (auto-detected if None) max_length: Maximum sequence length Returns: The prepared dataset """ # Auto-detect text column if not provided if text_column is None: columns = dataset.column_names text_column = next( (col for col in columns if 'text' in col.lower()), columns[0] ) def tokenize_function(examples): # Extract texts from examples if isinstance(examples, dict) and text_column in examples: texts = examples[text_column] else: texts = examples # Handle different text formats if isinstance(texts, str): texts = [texts] elif isinstance(texts, list): pass # Already a list else: texts = [str(texts)] # Tokenize return tokenizer( texts, truncation=True, padding="max_length", max_length=max_length, return_tensors="pt" ) # Apply tokenization tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=dataset.column_names ) # Set format for PyTorch tokenized_dataset.set_format("torch") return tokenized_dataset class EmbeddingTrainer(nn.Module): """ A trainer class for embedding models with contrastive learning. """ def __init__(self, base_model: nn.Module, temperature: float = 0.07): super().__init__() self.base_model = base_model self.temperature = temperature self.dropout = nn.Dropout(0.1) def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """Forward pass through the model.""" outputs = self.base_model( input_ids=input_ids, attention_mask=attention_mask ) # Get pooled output (use CLS token or mean pooling) if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None: pooled = outputs.pooler_output else: # Mean pooling last_hidden = outputs.last_hidden_state attention_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size()) pooled = torch.sum(last_hidden * attention_mask_expanded, 1) / torch.clamp(attention_mask_expanded.sum(1), min=1e-9) return self.dropout(pooled) def compute_contrastive_loss(self, embeddings: torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor: """ Compute contrastive loss for training embeddings. Args: embeddings: The embeddings to compute loss for labels: Optional labels for supervised contrastive learning Returns: The contrastive loss """ # Normalize embeddings embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) # Compute similarity matrix similarity_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature # Create labels (positive pairs are on the diagonal) batch_size = embeddings.size(0) labels = torch.arange(batch_size, device=embeddings.device) # Compute cross-entropy loss loss = torch.nn.functional.cross_entropy(similarity_matrix, labels) return loss def train_model_on_zero_gpu( model: nn.Module, tokenizer: AutoTokenizer, dataset: Dataset, epochs: int = 3, batch_size: int = 16, learning_rate: float = 2e-5, warmup_steps: int = 100, use_zero_gpu: bool = True ) -> Tuple[nn.Module, List[Dict[str, float]]]: """ Train the embedding model using Zero GPU if available. Args: model: The model to train tokenizer: The tokenizer dataset: The training dataset epochs: Number of training epochs batch_size: Training batch size learning_rate: Learning rate warmup_steps: Number of warmup steps use_zero_gpu: Whether to use Zero GPU Returns: Tuple of (trained_model, training_history) """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize trainer trainer = EmbeddingTrainer(model) trainer.to(device) # Create data loader def collate_fn(batch): return { 'input_ids': torch.stack([item['input_ids'] for item in batch]), 'attention_mask': torch.stack([item['attention_mask'] for item in batch]) } dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn ) # Setup optimizer and scheduler optimizer = AdamW(trainer.parameters(), lr=learning_rate, weight_decay=0.01) num_training_steps = epochs * len(dataloader) lr_scheduler = get_scheduler( name="linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps ) # Training history training_history = [] # Training loop trainer.train() for epoch in range(epochs): epoch_loss = 0.0 num_batches = 0 progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") for batch_idx, batch in enumerate(progress_bar): # Move batch to device batch = {k: v.to(device) for k, v in batch.items()} # Forward pass embeddings = trainer(**batch) loss = trainer.compute_contrastive_loss(embeddings) # Backward pass loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(trainer.parameters(), max_norm=1.0) optimizer.step() lr_scheduler.step() optimizer.zero_grad() # Update metrics epoch_loss += loss.item() num_batches += 1 # Update progress bar current_lr = optimizer.param_groups[0]['lr'] progress_bar.set_postfix({ 'loss': f'{loss.item():.4f}', 'lr': f'{current_lr:.2e}', 'avg_loss': f'{epoch_loss/num_batches:.4f}' }) # Log training history if batch_idx % 10 == 0: # Log every 10 batches training_history.append({ 'epoch': epoch + 1, 'batch': batch_idx + 1, 'loss': loss.item(), 'learning_rate': current_lr, 'avg_loss': epoch_loss / num_batches }) # Log epoch summary avg_epoch_loss = epoch_loss / num_batches training_history.append({ 'epoch': epoch + 1, 'batch': num_batches, 'loss': avg_epoch_loss, 'learning_rate': optimizer.param_groups[0]['lr'], 'avg_loss': avg_epoch_loss, 'epoch_end': True }) print(f"Epoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}") return model, training_history def save_model_to_hub( model: nn.Module, tokenizer: AutoTokenizer, repo_id: str, token: str, private: bool = False ) -> str: """ Save the trained model to Hugging Face Hub. Args: model: The trained model tokenizer: The tokenizer repo_id: Repository ID token: Hugging Face token private: Whether the repository should be private Returns: The repository URL """ try: # Push model to hub model.push_to_hub( repo_id=repo_id, token=token, private=private, commit_message="Upload trained embedding model" ) # Push tokenizer to hub tokenizer.push_to_hub( repo_id=repo_id, token=token, private=private, commit_message="Upload tokenizer" ) return f"https://huggingface.co/{repo_id}" except Exception as e: raise Exception(f"Failed to save model to Hub: {str(e)}") def create_repository(repo_name: str, token: str, private: bool = False) -> str: """ Create a new repository on Hugging Face Hub. Args: repo_name: Name of the repository token: Hugging Face token private: Whether the repository should be private Returns: The repository URL """ try: api = HfApi(token=token) repo_url = api.create_repo( repo_id=repo_name, token=token, private=private, repo_type="model", exist_ok=True ) return repo_url except Exception as e: raise Exception(f"Failed to create repository: {str(e)}")