Phase-Technologies committed on
Commit cf0b8ab · verified · 1 parent: e69f504

Upload folder using huggingface_hub

Files changed (5):
  1. README.md +202 -133
  2. adapter_config.json +5 -5
  3. adapter_model.safetensors +1 -1
  4. inference.py +12 -6
  5. vocab.json +0 -0
README.md CHANGED
@@ -1,134 +1,203 @@
 
 
 
 
 
- # Contrastive Zero-Shot Shakespeare Classifier
-
- This project implements a lightweight Contrastive Transformer model for zero-shot text classification, specifically designed to operate efficiently within memory-constrained environments using LoRA (Low-Rank Adaptation) for fine-tuning.
-
- ## Motivation and Memory Optimization
-
- Initially, training larger transformer models led to `OutOfMemoryError` on the available GPU. To address this, a two-pronged approach was taken:
- 1. **Base Model Reduction**: The core `ContrastiveTransformer` architecture was significantly scaled down to `dim=64`, `depth=2`, and `heads=2`. This drastically reduced the base model's memory footprint, allowing it to be loaded onto the GPU.
- 2. **LoRA (Low-Rank Adaptation)**: To enable efficient fine-tuning without requiring extensive memory for gradients and optimizer states, LoRA was applied. This technique adds small, trainable low-rank matrices alongside the existing linear layers, so only a small percentage of the model's parameters needs to be trained. In this implementation, only **1.7430%** of the total parameters were trainable, making the training process highly memory-efficient.
-
- ## Model Architecture
-
- The model is a custom `ContrastiveTransformer` built from scratch, composed of:
- * **Token and Positional Embeddings**: Map input tokens and their positions to dense vector representations.
- * **Transformer Blocks**: Multiple layers, each containing a Multi-Head Attention mechanism and a SwiGLU-activated Feed-Forward Network.
- * **SwiGLU Activation**: A modern activation function for the Feed-Forward Network, providing improved performance.
- * **Projection Layer**: Maps the final pooled embeddings to the output dimension.
-
- LoRA layers were injected into the following modules to allow for efficient adaptation:
- * `MultiheadAttention` linear layers (query, key, value, output projections)
- * Feed-Forward Network's linear layers (`ff.0`, `ff.3`)
- * Final projection layer (`proj`)
-
- ## Training
-
- The model was trained for contrastive zero-shot classification using several datasets:
- * `Xerv-AI/Conversational-2K-SimpleEnglish`
- * `Xerv-AI/Savage-Responses-2K`
- * `Xerv-AI/GRAD`
- * `tiny_shakespeare` (a text dataset extracted from a raw text file)
-
- These datasets were combined and tokenized using a custom vocabulary. The model was trained for `10` epochs with an AdamW optimizer and a Cosine Annealing learning rate scheduler.
-
- ## Inference Usage
-
- To use the trained model for zero-shot classification, follow these steps:
-
- 1. **Download the model artifacts**: The model weights and configuration are available on Hugging Face:
-    `https://huggingface.co/Phase-Technologies/contrastive-zeroshot-shakespeare`
-
- 2. **Load the model and vocabulary**: You can use the `inference.py` script provided in the repository.
-
- ```python
- import torch, json, re, torch.nn as nn, torch.nn.functional as F
- from peft import PeftModel, LoraConfig, get_peft_model, TaskType
- from pathlib import Path
-
- # Define the custom model architecture (same as in training)
- class SwiGLU(nn.Module):
-     def forward(self, x):
-         x, gate = x.chunk(2, dim=-1)
-         return F.silu(gate) * x
-
- class TransformerBlock(nn.Module):
-     def __init__(self, dim, heads=2, dropout=0.1):
-         super().__init__()
-         self.norm1 = nn.LayerNorm(dim)
-         self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
-         self.norm2 = nn.LayerNorm(dim)
-         self.ff = nn.Sequential(
-             nn.Linear(dim, dim * 4 * 2),
-             SwiGLU(),
-             nn.Linear(dim * 4, dim)
-         )
-
-     def forward(self, x, mask=None):
-         attn_out, _ = self.attn(x, x, x, key_padding_mask=mask)
-         x = self.norm1(x + attn_out)
-         ff_out = self.ff(x)
-         return self.norm2(x + ff_out)
-
- class ContrastiveTransformer(nn.Module):
-     def __init__(self, vocab, dim=64, depth=2, heads=2, max_seq=256):
-         super().__init__()
-         self.token_emb = nn.Embedding(len(vocab), dim)
-         self.pos_emb = nn.Embedding(max_seq, dim)
-         self.blocks = nn.ModuleList([TransformerBlock(dim, heads) for _ in range(depth)])
-         self.ln_f = nn.LayerNorm(dim)
-         self.proj = nn.Linear(dim, dim, bias=False)
-
-     def forward(self, input_ids, attention_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None):
-         if inputs_embeds is not None:
-             x = inputs_embeds
-         else:
-             x = self.token_emb(input_ids)
-
-         t = x.shape[1]
-         pos_emb = self.pos_emb(torch.arange(t, device=x.device))
-         x = x + pos_emb
-         mask = (input_ids == 0)
-         for block in self.blocks:
-             x = block(x, mask)
-         x = self.ln_f(x)
-         pooled = x.mean(dim=1)
-         return self.proj(pooled)
-
- def load_model(vocab_path, ckpt_path):
-     vocab = json.load(open(vocab_path))
-     config = json.load(open(Path(ckpt_path).parent / "config.json"))
-     base_model = ContrastiveTransformer(vocab, dim=config["dim"], depth=config["depth"], heads=config["heads"])
-     model = PeftModel.from_pretrained(base_model, Path(ckpt_path).parent)
-     model.eval()
-     return model, vocab
-
- def predict(text, candidate_labels, model, vocab):
-     device = next(model.parameters()).device
-     def enc(t):
-         tokens = [vocab.get(w, vocab.get("<UNK>", 1)) for w in re.findall(r"\w+", t.lower())]
-         if not tokens:
-             tokens = [vocab.get("<UNK>", 1)]
-         input_ids = torch.tensor([tokens[:256] + [0]*(256-len(tokens[:256]))], device=device)
-         with torch.no_grad():
-             return F.normalize(model.forward(input_ids=input_ids), dim=-1)
-     text_emb = enc(text)
-     label_embs = torch.cat([enc(lab) for lab in candidate_labels])
-     sims = F.cosine_similarity(text_emb, label_embs)
-     best = sims.argmax().item()
-     return candidate_labels[best], float(sims[best])
-
- # Example Usage:
- # Assuming model_dir is the directory where you saved the model (e.g., '/content/contrastive-zeroshot-shakespeare')
- model_dir = Path("/content/contrastive-zeroshot-shakespeare")  # Or download and specify your path
- model, vocab = load_model(model_dir / "vocab.json", model_dir / "adapter_model.safetensors")
-
- test_text = "to be or not to be that is the question"
- candidate_labels = ["shakespeare play", "math proof", "cooking recipe", "gaming victory", "climate news"]
- pred, conf = predict(test_text, candidate_labels, model, vocab)
- print(f"Test prediction: {pred} ({conf:.3f})")
- ```
-
- ### Hugging Face Repository
- The model and related files are available on Hugging Face:
- [Phase-Technologies/contrastive-zeroshot-shakespeare](https://huggingface.co/Phase-Technologies/contrastive-zeroshot-shakespeare)
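The decision rule inside the removed README's `predict()` — embed the text and each candidate label, then pick the label with the highest cosine similarity — can be sketched standalone without the model. The vectors below are toy values standing in for the transformer's pooled embeddings; only the selection logic matches the code above.

```python
import math

# Toy sketch of the zero-shot decision rule from predict():
# cosine similarity between a text embedding and each label embedding,
# returning the best-scoring label. Embedding values here are illustrative.

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def zero_shot_pick(text_emb, label_embs, labels):
    sims = [cosine(text_emb, e) for e in label_embs]
    best = max(range(len(sims)), key=sims.__getitem__)
    return labels[best], sims[best]

label, score = zero_shot_pick(
    [1.0, 0.0],                      # toy text embedding
    [[0.9, 0.1], [0.0, 1.0]],        # toy label embeddings
    ["shakespeare play", "cooking recipe"],
)
print(label)  # the label whose embedding points the same way as the text
```

In the real model the embeddings are already L2-normalized by `F.normalize`, so the cosine reduces to a dot product; the toy version normalizes explicitly.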
+ ---
+ library_name: peft
+ tags:
+ - lora
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+
+ ### Framework versions
+
+ - PEFT 0.18.0
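The old README's claim that only 1.7430% of parameters were trainable can be sanity-checked with back-of-the-envelope LoRA accounting over the adapted modules (q/k/v/out projections, `ff.0`, `ff.3`, `proj`). The LoRA rank below is an assumed value — the diff does not state it — and the exact fraction also depends on the vocabulary size, so this sketch only computes the adapter parameter count, not the quoted percentage.

```python
# Rough LoRA parameter accounting for the dim=64, depth=2 model described
# in the old README. The rank r=8 is an ASSUMPTION for illustration.

def lora_extra_params(d_in, d_out, r):
    # LoRA adds two low-rank matrices next to a frozen Linear(d_in, d_out):
    # A with shape (r, d_in) and B with shape (d_out, r).
    return r * (d_in + d_out)

dim, depth, rank = 64, 2, 8  # rank is assumed, not taken from the diff

per_block = (
    3 * lora_extra_params(dim, dim, rank)        # q_proj, k_proj, v_proj
    + lora_extra_params(dim, dim, rank)          # out_proj
    + lora_extra_params(dim, dim * 4 * 2, rank)  # ff.0 (SwiGLU doubles the width)
    + lora_extra_params(dim * 4, dim, rank)      # ff.3
)
trainable = depth * per_block + lora_extra_params(dim, dim, rank)  # + final proj
print(trainable)  # adapter parameters under these assumptions
```

Dividing this count by the base model's total parameters (dominated by the token embedding, hence by vocabulary size) would recover a trainable fraction in the low single-digit percent range, consistent with the README's figure.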
adapter_config.json CHANGED
@@ -29,13 +29,13 @@
   "rank_pattern": {},
   "revision": null,
   "target_modules": [
-    "proj",
-    "q_proj",
+    "v_proj",
     "ff.0",
+    "q_proj",
     "ff.3",
-    "v_proj",
-    "out_proj",
-    "k_proj"
+    "proj",
+    "k_proj",
+    "out_proj"
   ],
   "target_parameters": null,
   "task_type": "FEATURE_EXTRACTION",
adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e6340f16eb6af5984692d1ccdb5b664963ba5e2c9c1b4edc3c9a2fa3b8878f75
3
  size 71184
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6766d4647220efc90394eb44ee30e6c07eaa8b15cedc835586753a1ac589a775
3
  size 71184
inference.py CHANGED
@@ -1,6 +1,8 @@
 
 import torch, json, re, torch.nn as nn, torch.nn.functional as F
 from peft import PeftModel, LoraConfig, get_peft_model, TaskType
+from huggingface_hub import hf_hub_download  # Import for downloading files from Hugging Face Hub
+from pathlib import Path
 
 class SwiGLU(nn.Module):
     def forward(self, x):
@@ -55,14 +57,18 @@ class ContrastiveTransformer(nn.Module):
         # The training code normalises in the training loop.
         return self.proj(pooled)
 
-def load_model(vocab_path, ckpt_path):
+def load_model(repo_id):
+    # Download vocab.json and config.json from the Hugging Face Hub
+    vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json")
+    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
+
     vocab = json.load(open(vocab_path))
-    # Ensure model parameters match the saved config
-    config = json.load(open(Path(ckpt_path).parent / "config.json"))
-    base_model = ContrastiveTransformer(vocab, dim=config["dim"], depth=config["depth"], heads=config["heads"])
+    config = json.load(open(config_path))
+
+    base_model = ContrastiveTransformer(vocab, dim=config["dim"], depth=config["depth"], heads=config["heads"])  # Pass vocab object instead of its length
 
-    # Load the PEFT model from the directory
-    model = PeftModel.from_pretrained(base_model, Path(ckpt_path).parent)
+    # Load the PEFT model directly from the Hugging Face Hub
+    model = PeftModel.from_pretrained(base_model, repo_id)  # Pass repo_id directly
     model.eval()
     return model, vocab
 
vocab.json CHANGED
The diff for this file is too large to render. See raw diff
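The `vocab.json` change pairs with the encoding step in `predict()`: lowercase `\w+` tokenization, an `<UNK>` fallback (id 1), truncation to 256 tokens, and zero-padding. That logic can be exercised standalone; the small vocabulary below is a toy stand-in for the real `vocab.json`, whose diff is too large to show here.

```python
import re

# Standalone sketch of the fixed-length encoding used by enc() inside
# predict(): word tokenization, <UNK> fallback, truncate to max_seq,
# pad with id 0. The vocab dict below is illustrative, not the real one.

MAX_SEQ = 256

def encode(text, vocab, max_seq=MAX_SEQ):
    tokens = [vocab.get(w, vocab.get("<UNK>", 1)) for w in re.findall(r"\w+", text.lower())]
    if not tokens:
        tokens = [vocab.get("<UNK>", 1)]  # empty input still yields one token
    tokens = tokens[:max_seq]
    return tokens + [0] * (max_seq - len(tokens))

vocab = {"<PAD>": 0, "<UNK>": 1, "to": 2, "be": 3, "or": 4, "not": 5}
ids = encode("To be, or not to be!", vocab)
print(ids[:8])  # → [2, 3, 4, 5, 2, 3, 0, 0]
```

Note that id 0 doubles as the padding token and the attention mask in the model (`mask = (input_ids == 0)`), which is why the vocabulary reserves it for `<PAD>`.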