ducanhdinh/jepa_proof_vicreg_replace

BERT encoder pretrained from scratch với VICReg + Lexical Substitution augmentation.

Augmentation strategy

Thay vì span masking, model này dùng lexical substitution để tạo view 2:

Mô tả
View 1 Câu gốc (không thay đổi)
View 2 15–20% token ngẫu nhiên được thay bằng token ngữ nghĩa gần, dự đoán bởi BERT MLM (top-5, loại trừ token gốc)

Quá trình tạo view 2 được thực hiện offline 1 lần trước khi train, không gọi BERT thêm lần nào trong vòng lặp training.

Kiến trúc VICReg

Text → BERT (mean-pool) → z ∈ R^768 → Expander MLP → z' ∈ R^3072
                                                       ↑ VICReg loss áp dụng tại đây
Loss term Hệ số Mô tả
Invariance 25.0 MSE giữa z1 và z2 (căn chỉnh hai views)
Variance 25.0 Giữ std của mỗi chiều ≥ 1 (chống collapse)
Covariance 1.0 Decorrelate các chiều embedding

Expander gồm 3 lớp Linear-BatchNorm-ReLU (dim = 3072).

Thông số huấn luyện

Tham số Giá trị
Max sequence length 256
Batch size 256
Epochs 10
Learning rate 0.0001
Expander dim 3072
Mask ratio (lexsubst) 0.15–0.2
Top-k candidates 5
sim_coeff 25.0
std_coeff 25.0
cov_coeff 1.0

Cách dùng — BERT encoder (feature extraction)

from transformers import BertModel, BertTokenizerFast
import torch

tokenizer = BertTokenizerFast.from_pretrained("ducanhdinh/jepa_proof_vicreg_replace")
bert      = BertModel.from_pretrained("ducanhdinh/jepa_proof_vicreg_replace/encoder")

encoded = tokenizer(
    ["Hello world!", "VICReg with lexical substitution."],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
with torch.no_grad():
    out    = bert(**encoded)
    hidden = out.last_hidden_state          # (B, T, 768)
    mask   = encoded["attention_mask"].unsqueeze(-1).float()
    emb    = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)  # mean-pool → (B, 768)

Cách dùng — Full model (encoder + expander)

import torch
from text_vicreg_replace import TextVICRegReplace, VICRegReplacePretrainConfig

cfg   = VICRegReplacePretrainConfig()
model = TextVICRegReplace(cfg)
state = torch.load(
    hf_hub_download("ducanhdinh/jepa_proof_vicreg_replace", "pytorch_model.bin"),
    map_location="cpu",
)
model.load_state_dict(state)
model.eval()
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support