Upload model
Browse files- config.json +26 -0
- configuration_compression.py +21 -0
- model.safetensors +3 -0
- modeling_compression.py +132 -0
config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"CompressionModel"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_compression.CompressionConfig",
|
| 7 |
+
"AutoModel": "modeling_compression.CompressionModel"
|
| 8 |
+
},
|
| 9 |
+
"compression_sizes": [
|
| 10 |
+
512,
|
| 11 |
+
256,
|
| 12 |
+
128,
|
| 13 |
+
64,
|
| 14 |
+
32
|
| 15 |
+
],
|
| 16 |
+
"dropout": 0.1,
|
| 17 |
+
"input_size": 768,
|
| 18 |
+
"loss_k_vals": [
|
| 19 |
+
10,
|
| 20 |
+
100,
|
| 21 |
+
256
|
| 22 |
+
],
|
| 23 |
+
"model_type": "compression_head",
|
| 24 |
+
"torch_dtype": "float32",
|
| 25 |
+
"transformers_version": "4.38.2"
|
| 26 |
+
}
|
configuration_compression.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional

from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
class CompressionConfig(PretrainedConfig):
    """Configuration for a CompressionModel.

    Parameters
    ----------
    input_size:
        Dimensionality of the incoming base embedding (default 768).
    compression_sizes:
        Output sizes, one per compression head (default [512, 256, 128, 64, 32]).
    dropout:
        Dropout probability applied inside each head (default 0.1).
    loss_k_vals:
        Top-k neighbourhood sizes for the additional loss terms (default none).
    """

    model_type = "compression_head"

    def __init__(self,
                 input_size: int = 768,
                 compression_sizes: Optional[List[int]] = None,
                 dropout: float = 0.1,
                 loss_k_vals: Optional[List[int]] = None,
                 **kwargs
                 ):
        self.input_size = input_size
        # BUG FIX: the original used mutable list defaults ([...] and []),
        # which are shared across every instance created with the default.
        # Use None sentinels and materialize a fresh list per instance.
        self.compression_sizes = ([512, 256, 128, 64, 32]
                                  if compression_sizes is None
                                  else compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = [] if loss_k_vals is None else loss_k_vals

        super().__init__(**kwargs)
|
| 21 |
+
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d62042eb9cbd70af97e6e2abbcfe3fa25972b969d17c421ad173348fee8b4ba
|
| 3 |
+
size 10557544
|
modeling_compression.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from typing import Tuple, Optional, List
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
+
|
| 11 |
+
from .configuration_compression import CompressionConfig
|
| 12 |
+
|
| 13 |
+
def cosine_pairwise(embeddings):
    """Return the (n, n) matrix of cosine similarities between all row pairs."""
    lhs = embeddings.unsqueeze(1)
    rhs = embeddings.unsqueeze(0)
    return F.cosine_similarity(lhs, rhs, dim=2)
|
| 15 |
+
|
| 16 |
+
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix, mirroring np.cov.

    When ``rowvar`` is True, rows are variables and columns are observations.
    ``bias=False`` applies the sample (N-1) normalization; ``bias=True`` uses N.
    """
    if not rowvar:
        tensor = tensor.transpose(-1, -2)
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    n_obs = centered.shape[-1]
    denom = n_obs if bias else n_obs - 1
    return (1 / denom) * centered @ centered.transpose(-1, -2).conj()
|
| 22 |
+
|
| 23 |
+
def remove_diag(x):
    """Drop the diagonal from a square (n, n) matrix, yielding shape (n, n-1)."""
    n = x.shape[0]
    off_diag_mask = ~torch.eye(n, dtype=bool, device=x.device)
    return x.masked_select(off_diag_mask).view(n, n - 1)
|
| 26 |
+
|
| 27 |
+
def corrcoef(tensor, rowvar=True):
    """Get Pearson product-moment correlation coefficients (np.corrcoef)."""
    c = cov(tensor, rowvar=rowvar)
    # Variances sit on the diagonal of the covariance matrix.
    var = c.diagonal(0, -1, -2)
    if var.is_complex():
        var = var.real
    std = var.sqrt()
    # Normalize rows then columns by the standard deviations (in place).
    c /= std.unsqueeze(-1)
    c /= std.unsqueeze(-2)
    # Clamp floating-point drift so coefficients stay within [-1, 1].
    if c.is_complex():
        c.real.clip_(-1, 1)
        c.imag.clip_(-1, 1)
    else:
        c.clip_(-1, 1)
    return c
|
| 42 |
+
|
| 43 |
+
def compute_correlation(base_sims, compressed_sims, rm_diag=True):
    """Mean of (1 - Pearson r) between matching rows of two similarity matrices.

    With ``rm_diag=True`` the diagonal (self-similarity) entries are removed
    from both matrices before correlating.
    """
    if rm_diag:
        base_sims = remove_diag(base_sims)
        compressed_sims = remove_diag(compressed_sims)

    paired = torch.stack([base_sims, compressed_sims], dim=1)
    # corrcoef(paired) has shape (n, 2, 2); [:, 0, 1] is r(base_row, comp_row).
    per_row_r = corrcoef(paired)[:, 0, 1]
    return (1 - per_row_r).mean()
|
| 51 |
+
|
| 52 |
+
def loss_function(base_sims, compressed_sims, k_vals):
    """Correlation loss between base and compressed similarity matrices.

    Returns a (1, 1 + len(k_vals)) tensor: the global off-diagonal correlation
    loss, followed by one loss per ``k`` restricted to each row's top-k
    neighbours under the base similarities.
    """
    outputs = [compute_correlation(base_sims, compressed_sims)]

    if k_vals:
        # Rank neighbours by base similarity; drop column 0 (self-similarity).
        base_ranks = base_sims.argsort(-1, descending=True)[:, 1:]
        # (Removed unused local `n = base_ranks.shape[1]` from the original.)
        for k in k_vals:
            base_sims_k = torch.gather(base_sims, 1, base_ranks[:, :k])
            compressed_sims_k = torch.gather(compressed_sims, 1, base_ranks[:, :k])
            # Gathered rows already exclude the diagonal, so keep every entry.
            outputs.append(compute_correlation(base_sims_k, compressed_sims_k, rm_diag=False))

    return torch.stack(outputs).unsqueeze(0)
|
| 64 |
+
|
| 65 |
+
class FeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: d_in -> d_out.

    Attribute names fc1/fc2 are preserved so existing checkpoints still load.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # fc1 emits both the gate and value halves in a single projection.
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        gate, value = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(F.silu(gate) * value)
|
| 76 |
+
|
| 77 |
+
class CompressionHead(nn.Module):
    """Compresses an embedding from d_in to d_out: gated MLP plus a linear skip path.

    Attribute names ff/skip/dropout are preserved so existing checkpoints still load.
    """

    def __init__(self, d_in, d_out, dropout=0.1):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        dropped = self.dropout(x)
        # Residual-style combination of the gated MLP and the linear projection.
        return self.ff(dropped) + self.skip(dropped)
|
| 88 |
+
|
| 89 |
+
@dataclass
class CompressionModelOutput(ModelOutput):
    """Output container returned by CompressionModel.forward (return_dict=True)."""

    # Total loss: sum of torch.cat over the per-head losses; None when loss
    # computation is skipped.
    loss: Optional[torch.FloatTensor] = None
    # Per-head loss tensors, one per compression head; None when loss
    # computation is skipped.
    losses: Optional[List[torch.FloatTensor]] = None
    # The uncompressed input embedding, passed through unchanged.
    base_embedding: Optional[torch.FloatTensor] = None
    # One compressed embedding per compression head, in head order.
    compressed_embeddings: Optional[List[torch.FloatTensor]] = None
|
| 95 |
+
|
| 96 |
+
class CompressionModel(PreTrainedModel):
    """Applies one CompressionHead per configured size to a base embedding.

    Optionally computes, per head, a correlation loss between the pairwise
    cosine similarities of the base batch and of the compressed batch.
    """

    config_class = CompressionConfig

    def __init__(self, config):
        super().__init__(config)
        self.heads = nn.ModuleList([CompressionHead(config.input_size, i, config.dropout)
                                    for i in config.compression_sizes])

    def forward(self, embedding, compute_loss=True, return_dict=True):
        """Compress `embedding` with every head.

        Args:
            embedding: batch of base embeddings (presumably (batch, input_size)
                — confirm against callers).
            compute_loss: when True, also compute the per-head correlation losses.
            return_dict: return a CompressionModelOutput instead of a tuple.
        """
        outputs = []
        losses = None
        loss = None

        if compute_loss:
            losses = []
            emb_sims = cosine_pairwise(embedding)

        for head in self.heads:
            compressed_embedding = head(embedding)
            outputs.append(compressed_embedding)

            if compute_loss:
                comp_sims = cosine_pairwise(compressed_embedding)
                losses.append(loss_function(emb_sims, comp_sims, self.config.loss_k_vals))

        # BUG FIX: the original called torch.cat(losses) unconditionally, which
        # raises TypeError when compute_loss=False (losses is None). Only
        # aggregate when losses were actually collected.
        if compute_loss:
            loss = torch.cat(losses).sum()

        if not return_dict:
            return (loss, losses, embedding, outputs)

        return CompressionModelOutput(loss=loss,
                                      losses=losses,
                                      base_embedding=embedding,
                                      compressed_embeddings=outputs)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|