# easy-green-lab / utils.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, get_scheduler
from torch.optim import AdamW  # transformers' bundled AdamW is deprecated; use the PyTorch one
from datasets import load_dataset, Dataset
from huggingface_hub import HfApi
from tqdm.auto import tqdm
from typing import Dict, List, Optional, Tuple
def load_embedding_model(model_name_or_path: str) -> Tuple[nn.Module, AutoTokenizer]:
"""
Load an embedding model from Hugging Face.
Args:
model_name_or_path: The model name or path on Hugging Face
Returns:
Tuple of (model, tokenizer)
"""
try:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Add padding token if not present
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load model
model = AutoModel.from_pretrained(model_name_or_path)
return model, tokenizer
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name_or_path}: {e}") from e
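# Example usage (sketch; "sentence-transformers/all-MiniLM-L6-v2" is only an
# illustrative checkpoint, not one this app is tied to):
#
#   model, tokenizer = load_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
#   print(model.config.hidden_size)  # embedding width of the loaded backbone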
def load_huggingface_dataset(dataset_name: str, split: str = "train") -> Dataset:
"""
Load a dataset from Hugging Face.
Args:
dataset_name: The dataset name on Hugging Face
split: The dataset split to load
Returns:
The loaded dataset
"""
try:
dataset = load_dataset(dataset_name, split=split)
return dataset
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset {dataset_name}: {e}") from e
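# Example usage (sketch; the dataset name below is only illustrative):
#
#   dataset = load_huggingface_dataset("ag_news", split="train")
#   print(dataset.column_names)  # e.g. ['text', 'label']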
def prepare_dataset_for_training(dataset: Dataset, tokenizer: AutoTokenizer,
                                 text_column: Optional[str] = None, max_length: int = 512) -> Dataset:
"""
Prepare dataset for embedding training by tokenizing texts.
Args:
dataset: The dataset to prepare
tokenizer: The tokenizer to use
text_column: The text column name (auto-detected if None)
max_length: Maximum sequence length
Returns:
The prepared dataset
"""
# Auto-detect text column if not provided
if text_column is None:
columns = dataset.column_names
text_column = next(
(col for col in columns if 'text' in col.lower()),
columns[0]
)
    def tokenize_function(examples):
        # With batched=True, `examples` is a dict-like batch of columns
        texts = examples[text_column]
        # Normalize to a list of strings
        if isinstance(texts, str):
            texts = [texts]
        elif not isinstance(texts, list):
            texts = [str(texts)]
        # Tokenize; tensors are produced later by set_format("torch"),
        # so return_tensors is not needed here
        return tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length
        )
# Apply tokenization
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset.column_names
)
# Set format for PyTorch
tokenized_dataset.set_format("torch")
return tokenized_dataset
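# Example usage (sketch, continuing the hypothetical `dataset` and `tokenizer` from the
# examples above):
#
#   tokenized = prepare_dataset_for_training(dataset, tokenizer, max_length=128)
#   print(tokenized[0]["input_ids"].shape)  # torch.Size([128]) after set_format("torch")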
class EmbeddingTrainer(nn.Module):
"""
A trainer class for embedding models with contrastive learning.
"""
def __init__(self, base_model: nn.Module, temperature: float = 0.07):
super().__init__()
self.base_model = base_model
self.temperature = temperature
self.dropout = nn.Dropout(0.1)
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Forward pass through the model."""
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask
)
# Get pooled output (use CLS token or mean pooling)
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
pooled = outputs.pooler_output
else:
            # Mean pooling over valid tokens (mask cast to float before summing/dividing)
            last_hidden = outputs.last_hidden_state
            mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
            pooled = torch.sum(last_hidden * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
return self.dropout(pooled)
    def compute_contrastive_loss(self, embeddings: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute contrastive loss for training embeddings.
Args:
embeddings: The embeddings to compute loss for
labels: Optional labels for supervised contrastive learning
Returns:
The contrastive loss
"""
# Normalize embeddings
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
# Compute similarity matrix
similarity_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature
        # Default to in-batch identity labels (each embedding is its own positive)
        if labels is None:
            batch_size = embeddings.size(0)
            labels = torch.arange(batch_size, device=embeddings.device)
# Compute cross-entropy loss
loss = torch.nn.functional.cross_entropy(similarity_matrix, labels)
return loss
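# Note on compute_contrastive_loss: with the default identity labels it is an
# InfoNCE-style cross-entropy over the cosine-similarity matrix, i.e. for embedding i
#   loss_i = -log( exp(sim(i, i) / T) / sum_j exp(sim(i, j) / T) )
# where sim is cosine similarity of the normalized embeddings and T is `temperature`.
# Quick shape check (sketch; the 8x384 batch is a random stand-in):
#
#   trainer = EmbeddingTrainer(model)               # `model` from load_embedding_model above
#   emb = torch.randn(8, 384)
#   print(trainer.compute_contrastive_loss(emb))    # scalar loss tensor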
def train_model_on_zero_gpu(
model: nn.Module,
tokenizer: AutoTokenizer,
dataset: Dataset,
epochs: int = 3,
batch_size: int = 16,
learning_rate: float = 2e-5,
warmup_steps: int = 100,
use_zero_gpu: bool = True
) -> Tuple[nn.Module, List[Dict[str, float]]]:
"""
Train the embedding model using Zero GPU if available.
Args:
model: The model to train
tokenizer: The tokenizer
dataset: The training dataset
epochs: Number of training epochs
batch_size: Training batch size
learning_rate: Learning rate
warmup_steps: Number of warmup steps
use_zero_gpu: Whether to use Zero GPU
Returns:
Tuple of (trained_model, training_history)
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize trainer
trainer = EmbeddingTrainer(model)
trainer.to(device)
# Create data loader
def collate_fn(batch):
return {
'input_ids': torch.stack([item['input_ids'] for item in batch]),
'attention_mask': torch.stack([item['attention_mask'] for item in batch])
}
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn
)
# Setup optimizer and scheduler
optimizer = AdamW(trainer.parameters(), lr=learning_rate, weight_decay=0.01)
num_training_steps = epochs * len(dataloader)
lr_scheduler = get_scheduler(
name="linear",
optimizer=optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=num_training_steps
)
# Training history
training_history = []
# Training loop
trainer.train()
for epoch in range(epochs):
epoch_loss = 0.0
num_batches = 0
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
for batch_idx, batch in enumerate(progress_bar):
# Move batch to device
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
embeddings = trainer(**batch)
loss = trainer.compute_contrastive_loss(embeddings)
# Backward pass
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(trainer.parameters(), max_norm=1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Update metrics
epoch_loss += loss.item()
num_batches += 1
# Update progress bar
current_lr = optimizer.param_groups[0]['lr']
progress_bar.set_postfix({
'loss': f'{loss.item():.4f}',
'lr': f'{current_lr:.2e}',
'avg_loss': f'{epoch_loss/num_batches:.4f}'
})
# Log training history
if batch_idx % 10 == 0: # Log every 10 batches
training_history.append({
'epoch': epoch + 1,
'batch': batch_idx + 1,
'loss': loss.item(),
'learning_rate': current_lr,
'avg_loss': epoch_loss / num_batches
})
# Log epoch summary
avg_epoch_loss = epoch_loss / num_batches
training_history.append({
'epoch': epoch + 1,
'batch': num_batches,
'loss': avg_epoch_loss,
'learning_rate': optimizer.param_groups[0]['lr'],
'avg_loss': avg_epoch_loss,
'epoch_end': True
})
print(f"Epoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}")
return model, training_history
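# Example usage (sketch; assumes `model`, `tokenizer`, and `tokenized` from the earlier
# examples):
#
#   trained_model, history = train_model_on_zero_gpu(
#       model, tokenizer, tokenized,
#       epochs=1, batch_size=8, learning_rate=2e-5
#   )
#   print(history[-1]["avg_loss"])  # average loss of the final epoch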
def save_model_to_hub(
model: nn.Module,
tokenizer: AutoTokenizer,
repo_id: str,
token: str,
private: bool = False
) -> str:
"""
Save the trained model to Hugging Face Hub.
Args:
model: The trained model
tokenizer: The tokenizer
repo_id: Repository ID
token: Hugging Face token
private: Whether the repository should be private
Returns:
The repository URL
"""
try:
# Push model to hub
model.push_to_hub(
repo_id=repo_id,
token=token,
private=private,
commit_message="Upload trained embedding model"
)
# Push tokenizer to hub
tokenizer.push_to_hub(
repo_id=repo_id,
token=token,
private=private,
commit_message="Upload tokenizer"
)
return f"https://huggingface.co/{repo_id}"
    except Exception as e:
        raise RuntimeError(f"Failed to save model to Hub: {e}") from e
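# Example usage (sketch; the repo id and token are placeholders):
#
#   url = save_model_to_hub(trained_model, tokenizer,
#                           repo_id="your-username/my-embedding-model",
#                           token="hf_...", private=True)
#   print(url)  # https://huggingface.co/your-username/my-embedding-model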
def create_repository(repo_name: str, token: str, private: bool = False) -> str:
"""
Create a new repository on Hugging Face Hub.
Args:
repo_name: Name of the repository
token: Hugging Face token
private: Whether the repository should be private
Returns:
The repository URL
"""
try:
api = HfApi(token=token)
repo_url = api.create_repo(
repo_id=repo_name,
token=token,
private=private,
repo_type="model",
exist_ok=True
)
return repo_url
    except Exception as e:
        raise RuntimeError(f"Failed to create repository: {e}") from e
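# Example usage (sketch; repo name and token are placeholders):
#
#   create_repository("your-username/my-embedding-model", token="hf_...", private=True)
#
# With exist_ok=True the call is idempotent, so it can safely run before every push.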