Contrastive Zero-Shot Shakespeare Classifier

This project implements a lightweight Contrastive Transformer model for zero-shot text classification, specifically designed to operate efficiently within memory-constrained environments using LoRA (Low-Rank Adaptation) for fine-tuning.

Motivation and Memory Optimization

Initially, training larger transformer models led to an OutOfMemoryError on the available GPU. To address this, a two-pronged approach was taken:

  1. Base Model Reduction: The core ContrastiveTransformer architecture was significantly scaled down to dim=64, depth=2, and heads=2. This drastically reduced the base model's memory footprint, allowing it to be loaded onto the GPU.
  2. LoRA (Low-Rank Adaptation): To enable efficient fine-tuning without requiring extensive memory for gradients and optimizer states, LoRA was applied. This technique adds small, trainable low-rank matrices alongside the existing linear layers, so only a small fraction of the model's parameters needs to be trained. In this implementation, only 1.7430% of the total parameters were trainable, making the training process highly memory-efficient (a sketch of how that percentage is computed follows the list).
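
As a rough sanity check, the trainable-parameter percentage quoted above can be reproduced by counting parameters with requires_grad=True on the LoRA-wrapped model. This is a minimal sketch; peft_model is a placeholder for the PEFT-wrapped model (see the configuration sketch in the Model Architecture section).

    import torch.nn as nn

    def trainable_percentage(model: nn.Module) -> float:
        # LoRA-injected parameters keep requires_grad=True; the frozen base weights do not.
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return 100.0 * trainable / total

    # Example (hypothetical): print(f"{trainable_percentage(peft_model):.4f}% of parameters are trainable")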

Model Architecture

The model is a custom ContrastiveTransformer built from scratch; a brief shape walkthrough follows the list. It is composed of:

  • Token and Positional Embeddings: Map input tokens and their positions to dense vector representations.
  • Transformer Blocks: Multiple layers, each containing a Multi-Head Attention mechanism and a SwiGLU-activated Feed-Forward Network.
  • SwiGLU Activation: A modern activation function for the Feed-Forward Network, providing improved performance.
  • Projection Layer: Maps the final pooled embeddings to the output dimension.
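
For the configuration used here (dim=64, depth=2, heads=2, max_seq=256, matching the inference code below), the tensor shapes flow roughly as follows:

    # Shape walkthrough for a batch of size B and sequence length T <= 256
    # input_ids                         : (B, T)
    # token + positional embeddings     : (B, T, 64)
    # each of the 2 transformer blocks:
    #   MultiheadAttention (2 heads), residual + LayerNorm : (B, T, 64)
    #   FF: Linear 64 -> 512, SwiGLU chunks into 2 x 256 and gates -> 256, Linear 256 -> 64
    # mean pooling over the sequence    : (B, 64)
    # proj (64 -> 64, no bias)          : (B, 64) final sentence embedding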

LoRA layers were injected into the following modules to allow for efficient adaptation (a configuration sketch follows the list):

  • MultiheadAttention linear layers (query, key, value, output projections)
  • Feed-Forward Network's linear layers (ff.0, ff.3)
  • Final projection layer (proj)
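
A minimal sketch of such an adapter configuration with peft is shown below. The rank, alpha, and dropout values are illustrative assumptions rather than the values used for the published adapter (those live in its adapter_config.json), and the snippet only targets the plain nn.Linear modules; the original training script may handle the MultiheadAttention projections differently.

    from peft import LoraConfig, get_peft_model

    # Illustrative hyperparameters; not necessarily those of the published adapter.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["ff.0", "ff.3", "proj"],  # nn.Linear modules from the list above
    )

    # base_model: a ContrastiveTransformer instance (class defined in the Inference Usage section)
    peft_model = get_peft_model(base_model, lora_config)
    peft_model.print_trainable_parameters()  # prints trainable vs. total parameter counts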

Training

The model was trained for contrastive zero-shot classification using several datasets:

  • Xerv-AI/Conversational-2K-SimpleEnglish
  • Xerv-AI/Savage-Responses-2K
  • Xerv-AI/GRAD
  • tiny_shakespeare (a text dataset extracted from a raw text file)

These datasets were combined and tokenized using a custom word-level vocabulary. The model was trained for 10 epochs with the AdamW optimizer and a cosine annealing learning rate scheduler; a sketch of this setup is shown below.
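
The original training script is not reproduced here; the following is a minimal sketch of the optimizer and scheduler setup described above, with an illustrative InfoNCE-style contrastive objective. peft_model, train_loader, the learning rate, the temperature, and the way positive pairs are formed are all assumptions of this sketch, not details taken from the repository.

    import torch
    import torch.nn.functional as F
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR

    EPOCHS = 10  # as stated above
    optimizer = AdamW(peft_model.parameters(), lr=3e-4)        # learning rate is an assumption
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    for epoch in range(EPOCHS):
        for input_ids_a, input_ids_b in train_loader:          # paired positive examples (pairing scheme assumed)
            emb_a = F.normalize(peft_model(input_ids=input_ids_a), dim=-1)
            emb_b = F.normalize(peft_model(input_ids=input_ids_b), dim=-1)
            logits = emb_a @ emb_b.t() / 0.07                  # temperature is an assumption
            labels = torch.arange(logits.size(0), device=logits.device)
            loss = F.cross_entropy(logits, labels)             # InfoNCE-style contrastive loss (illustrative)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()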

Inference Usage

To use the trained model for zero-shot classification, follow these steps:

  1. Install the necessary libraries: make sure torch, peft, and huggingface_hub are installed:

    pip install peft huggingface_hub torch

  2. Load the model and vocabulary: The model and its configuration are available on Hugging Face. You can use the inference.py script provided in the repository (or copy the code below).

    import torch, json, re, torch.nn as nn, torch.nn.functional as F
    from peft import PeftModel, LoraConfig, get_peft_model, TaskType
    from huggingface_hub import hf_hub_download
    from pathlib import Path
    
    # Define the custom model architecture (same as in training)
    class SwiGLU(nn.Module):
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return F.silu(gate) * x
    
    class TransformerBlock(nn.Module):
        def __init__(self, dim, heads=2, dropout=0.1):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
            self.norm2 = nn.LayerNorm(dim)
            self.ff = nn.Sequential(
                nn.Linear(dim, dim * 4 * 2),  # ff.0 (LoRA target)
                SwiGLU(),
                nn.Dropout(dropout),          # keeps the second linear at index 3 (ff.3), matching the LoRA targets listed above
                nn.Linear(dim * 4, dim)       # ff.3 (LoRA target)
            )
    
        def forward(self, x, mask=None):
            attn_out, _ = self.attn(x, x, x, key_padding_mask=mask)
            x = self.norm1(x + attn_out)
            ff_out = self.ff(x)
            return self.norm2(x + ff_out)
    
    class ContrastiveTransformer(nn.Module):
        def __init__(self, vocab, dim=64, depth=2, heads=2, max_seq=256):
            super().__init__()
            self.token_emb = nn.Embedding(len(vocab), dim)
            self.pos_emb = nn.Embedding(max_seq, dim)
            self.blocks = nn.ModuleList([TransformerBlock(dim, heads) for _ in range(depth)])
            self.ln_f = nn.LayerNorm(dim)
            self.proj = nn.Linear(dim, dim, bias=False)
    
        def forward(self, input_ids, attention_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None):
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.token_emb(input_ids)
    
            t = x.shape[1]
            pos_emb = self.pos_emb(torch.arange(t, device=x.device))
            x = x + pos_emb
            mask = (input_ids == 0) if input_ids is not None else None  # 0 is the padding token id
            for block in self.blocks:
                x = block(x, mask)
            x = self.ln_f(x)
            pooled = x.mean(dim=1)
            return self.proj(pooled)
    
    def load_model(repo_id):
        # Download vocab.json and config.json from the Hugging Face Hub
        vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    
        vocab = json.load(open(vocab_path))
        config = json.load(open(config_path))
    
        # Rebuild the base architecture with the hyperparameters stored in config.json
        base_model = ContrastiveTransformer(vocab, dim=config["dim"], depth=config["depth"], heads=config["heads"])
        
        # Load the LoRA adapter weights from the Hugging Face Hub on top of the base model
        model = PeftModel.from_pretrained(base_model, repo_id)
        model.eval()
        return model, vocab
    
    def predict(text, candidate_labels, model, vocab):
        device = next(model.parameters()).device
        def enc(t):
            # Word-level tokenization against the saved vocabulary; unknown words fall back to <UNK>
            tokens = [vocab.get(w, vocab.get("<UNK>", 1)) for w in re.findall(r"\w+", t.lower())]
            if not tokens:
                tokens = [vocab.get("<UNK>", 1)]
            # Truncate/pad to the model's maximum sequence length (256); 0 is the padding token id
            input_ids = torch.tensor([tokens[:256] + [0]*(256-len(tokens[:256]))], device=device)
            with torch.no_grad():
                return F.normalize(model(input_ids=input_ids), dim=-1)
        text_emb = enc(text)
        label_embs = torch.cat([enc(lab) for lab in candidate_labels])
        sims = F.cosine_similarity(text_emb, label_embs)
        best = sims.argmax().item()
        return candidate_labels[best], float(sims[best])
    
    # Example Usage:
    model_repo_id = "Phase-Technologies/contrastive-zeroshot-shakespeare" # Your Hugging Face model ID
    model, vocab = load_model(model_repo_id)
    
    test_text = "to be or not to be that is the question"
    candidate_labels = ["shakespeare play", "math proof", "cooking recipe", "gaming victory", "climate news"]
    pred, conf = predict(test_text, candidate_labels, model, vocab)
    print(f"Test prediction: cooking recipe (0.999)")
    

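If you want a score for every candidate label rather than only the best one, a small variation of predict can return the full ranking. This helper is illustrative and not part of the repository's inference.py; it reuses the same encoding logic shown above.

    def predict_all(text, candidate_labels, model, vocab):
        # Returns (label, cosine similarity) pairs sorted from most to least similar
        device = next(model.parameters()).device
        def enc(t):
            tokens = [vocab.get(w, vocab.get("<UNK>", 1)) for w in re.findall(r"\w+", t.lower())] or [vocab.get("<UNK>", 1)]
            input_ids = torch.tensor([tokens[:256] + [0]*(256-len(tokens[:256]))], device=device)
            with torch.no_grad():
                return F.normalize(model(input_ids=input_ids), dim=-1)
        text_emb = enc(text)
        label_embs = torch.cat([enc(lab) for lab in candidate_labels])
        sims = F.cosine_similarity(text_emb, label_embs)
        return sorted(zip(candidate_labels, sims.tolist()), key=lambda pair: pair[1], reverse=True)

    # for label, score in predict_all(test_text, candidate_labels, model, vocab):
    #     print(f"{label}: {score:.3f}")
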
Hugging Face Repository

The model and related files are available on Hugging Face: Phase-Technologies/contrastive-zeroshot-shakespeare
