Contrastive Zero-Shot Shakespeare Classifier
This project implements a lightweight Contrastive Transformer model for zero-shot text classification, specifically designed to operate efficiently within memory-constrained environments using LoRA (Low-Rank Adaptation) for fine-tuning.
Motivation and Memory Optimization
Initially, training larger transformer models led to OutOfMemoryError on the available GPU. To address this, a two-pronged approach was taken:
- Base Model Reduction: The core `ContrastiveTransformer` architecture was significantly scaled down to `dim=64`, `depth=2`, and `heads=2`. This drastically reduced the base model's memory footprint, allowing it to be loaded onto the GPU.
- LoRA (Low-Rank Adaptation): To enable efficient fine-tuning without requiring extensive memory for gradients and optimizer states, LoRA was applied. This technique adds small, trainable low-rank matrices to the existing linear layers, allowing us to train only a small percentage of the model's parameters. In this implementation, only 1.7430% of the total parameters were trainable (see the sketch after this list), making the training process highly memory-efficient.
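The trainable-parameter percentage quoted above can be reproduced with a small helper that compares trainable and total parameter counts. This is a generic sketch; `lora_model` is a placeholder name for the LoRA-wrapped model.

```python
import torch.nn as nn

def trainable_ratio(model: nn.Module) -> float:
    # Percentage of parameters that require gradients, i.e. what the 1.7430% figure refers to.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return 100.0 * trainable / total

# print(f"{trainable_ratio(lora_model):.4f}% trainable")  # `lora_model`: the LoRA-wrapped model
```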
Model Architecture
The model is a custom ContrastiveTransformer built from scratch, composed of:
- Token and Positional Embeddings: Map input tokens and their positions to dense vector representations.
- Transformer Blocks: Multiple layers, each containing a Multi-Head Attention mechanism and a SwiGLU-activated Feed-Forward Network.
- SwiGLU Activation: A gated activation (a SiLU-gated linear unit) used inside the Feed-Forward Network, widely reported to improve transformer quality over plain ReLU/GELU feed-forward layers.
- Projection Layer: Maps the final pooled embeddings to the output dimension.
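The SwiGLU feed-forward block is small enough to illustrate standalone (the same definition appears in the inference script below). The first linear layer produces twice the expanded width so that SwiGLU can split its input into a value half and a SiLU gate half.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def forward(self, x):
        # Split the last dimension in half: one half is the value, the other the gate.
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# The first linear doubles the 4x expansion (dim * 4 * 2) so SwiGLU can halve it back to dim * 4.
ff = nn.Sequential(nn.Linear(64, 64 * 4 * 2), SwiGLU(), nn.Linear(64 * 4, 64))
print(ff(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```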
LoRA layers were injected into the following modules to allow for efficient adaptation:
- `MultiheadAttention` linear layers (query, key, value, output projections)
- Feed-Forward Network's linear layers (`ff.0`, `ff.3`)
- Final projection layer (`proj`)
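The README does not publish the exact LoRA hyperparameters, but the wrapping step would look roughly like the hypothetical sketch below: rank, alpha, and dropout are assumptions, and the module names follow the `ContrastiveTransformer` class from the inference script further down (where the second feed-forward linear sits at `ff.2` rather than the `ff.3` named above). Note that the query/key/value projections of `nn.MultiheadAttention` are fused into a single parameter rather than separate `nn.Linear` modules, so this sketch only adapts the attention output projection.

```python
from peft import LoraConfig, get_peft_model

# Hypothetical LoRA configuration; r / lora_alpha / lora_dropout are assumptions, not the released settings.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    # Module names match the inference-script architecture below.
    target_modules=["out_proj", "ff.0", "ff.2", "proj"],
)

vocab = {"<PAD>": 0, "<UNK>": 1}  # dummy vocab, only to instantiate the model for this sketch
base_model = ContrastiveTransformer(vocab, dim=64, depth=2, heads=2)  # class defined in the inference script
lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()  # only the injected low-rank matrices require gradients
```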
Training
The model was trained for contrastive zero-shot classification using several datasets:
- `Xerv-AI/Conversational-2K-SimpleEnglish`
- `Xerv-AI/Savage-Responses-2K`
- `Xerv-AI/GRAD`
- `tiny_shakespeare` (a text dataset extracted from a raw text file)
These datasets were combined and tokenized using a custom vocabulary. The model was trained for 10 epochs with an AdamW optimizer and a Cosine Annealing learning rate scheduler.
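The contrastive objective itself is not spelled out in this README, so the following is only a minimal sketch of what the loop could look like with an in-batch InfoNCE-style loss. `lora_model` (the LoRA-wrapped `ContrastiveTransformer`) and `loader` (yielding padded text/label token-id batches) are assumed to exist, and the learning rate and temperature are placeholders.

```python
import torch
import torch.nn.functional as F

# Hypothetical training loop: AdamW + cosine annealing over 10 epochs, as described above.
optimizer = torch.optim.AdamW(
    (p for p in lora_model.parameters() if p.requires_grad),
    lr=3e-4,  # placeholder learning rate
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    for text_ids, label_ids in loader:
        text_emb = F.normalize(lora_model(input_ids=text_ids), dim=-1)
        label_emb = F.normalize(lora_model(input_ids=label_ids), dim=-1)
        logits = text_emb @ label_emb.t() / 0.07           # in-batch similarity matrix (temperature assumed)
        targets = torch.arange(logits.size(0), device=logits.device)
        loss = F.cross_entropy(logits, targets)            # InfoNCE-style contrastive objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()                                       # anneal the learning rate once per epoch
```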
Inference Usage
To use the trained model for zero-shot classification, follow these steps:
1. Install the necessary libraries: make sure you have `peft`, `huggingface_hub`, and `torch` installed: `pip install peft huggingface_hub torch`
2. Load the model and vocabulary: the model and its configuration are available on Hugging Face. You can use the `inference.py` script provided in the repository, or copy the code below.

```python
import torch, json, re, torch.nn as nn, torch.nn.functional as F
from peft import PeftModel
from huggingface_hub import hf_hub_download

# Define the custom model architecture (same as in training)
class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads=2, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4 * 2),
            SwiGLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, key_padding_mask=mask)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        return self.norm2(x + ff_out)

class ContrastiveTransformer(nn.Module):
    def __init__(self, vocab, dim=64, depth=2, heads=2, max_seq=256):
        super().__init__()
        self.token_emb = nn.Embedding(len(vocab), dim)
        self.pos_emb = nn.Embedding(max_seq, dim)
        self.blocks = nn.ModuleList([TransformerBlock(dim, heads) for _ in range(depth)])
        self.ln_f = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, input_ids, attention_mask=None, inputs_embeds=None,
                output_attentions=None, output_hidden_states=None, return_dict=None):
        x = inputs_embeds if inputs_embeds is not None else self.token_emb(input_ids)
        t = x.shape[1]
        x = x + self.pos_emb(torch.arange(t, device=x.device))
        mask = (input_ids == 0)  # padding token id is 0
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        pooled = x.mean(dim=1)  # mean-pool over the sequence
        return self.proj(pooled)

def load_model(repo_id):
    # Download vocab.json and config.json from the Hugging Face Hub
    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    vocab = json.load(open(vocab_path))
    config = json.load(open(config_path))
    # Build the base model from the stored configuration (the vocab dict itself is passed, not its length)
    base_model = ContrastiveTransformer(vocab, dim=config["dim"], depth=config["depth"], heads=config["heads"])
    # Load the LoRA adapter weights on top of the base model, directly from the Hub
    model = PeftModel.from_pretrained(base_model, repo_id)
    model.eval()
    return model, vocab

def predict(text, candidate_labels, model, vocab):
    device = next(model.parameters()).device

    def enc(t):
        tokens = [vocab.get(w, vocab.get("<UNK>", 1)) for w in re.findall(r"\w+", t.lower())]
        if not tokens:
            tokens = [vocab.get("<UNK>", 1)]
        # Truncate/pad to the 256-token context and embed
        input_ids = torch.tensor([tokens[:256] + [0] * (256 - len(tokens[:256]))], device=device)
        with torch.no_grad():
            return F.normalize(model(input_ids=input_ids), dim=-1)

    text_emb = enc(text)
    label_embs = torch.cat([enc(lab) for lab in candidate_labels])
    sims = F.cosine_similarity(text_emb, label_embs)
    best = sims.argmax().item()
    return candidate_labels[best], float(sims[best])

# Example usage:
model_repo_id = "Phase-Technologies/contrastive-zeroshot-shakespeare"  # Hugging Face model ID
model, vocab = load_model(model_repo_id)

test_text = "to be or not to be that is the question"
candidate_labels = ["shakespeare play", "math proof", "cooking recipe", "gaming victory", "climate news"]
pred, conf = predict(test_text, candidate_labels, model, vocab)
print(f"Test prediction: {pred} ({conf:.3f})")
```
Hugging Face Repository
The model and related files are available on Hugging Face: Phase-Technologies/contrastive-zeroshot-shakespeare