LoRA-Finetuned DistilBERT for Sentiment Classification on the IMDB Dataset

  • This repository provides a LoRA adapter model fine-tuned for sentiment classification on the IMDB movie review dataset.
  • Fine-tuning was performed efficiently using LoRA (Low-Rank Adaptation) from the PEFT library, with training managed by the PyTorch Lightning framework.
  • The LoRA adapter weights must be combined with the base DistilBERT model weights for inference.
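For context, LoRA avoids updating a full weight matrix W during fine-tuning: it instead learns two small matrices A and B whose product, scaled by alpha/r, is added to the frozen W at inference time. A minimal sketch of that idea (the shapes below are illustrative for a DistilBERT-sized projection, not read from this checkpoint; r=8 and alpha=16 match the config used later in this card):

```python
import torch

# Illustrative LoRA update for a 768x768 attention projection.
d, r, alpha = 768, 8, 16           # hidden size, LoRA rank, LoRA alpha
W = torch.randn(d, d)              # frozen base weight
A = torch.randn(r, d) * 0.01      # trainable low-rank factor A (random init)
B = torch.zeros(d, r)              # trainable low-rank factor B (zero init, as in LoRA)

# Effective weight at inference time: W + (alpha / r) * B @ A.
# With B initialized to zero, the update starts out as exactly zero.
W_eff = W + (alpha / r) * (B @ A)
print(W_eff.shape)                 # torch.Size([768, 768])
```

Only A and B (2 * d * r parameters) are trained and shipped in the adapter, which is why the adapter file is tiny compared to the base model.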

You can use the following Python code to perform inference:

(The code is somewhat lengthy because the adapter's layer keys must be aligned with the keys of the PEFT-wrapped model.)

import torch
from safetensors.torch import load_file
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import get_peft_model, LoraConfig
from huggingface_hub import hf_hub_download
import os

# base model and adapter repository IDs
base_model_name = "distilbert-base-uncased"
hf_repo_id = "Qndhm/distilled-bert-imdb-lora-adapter"

# --- load base model and tokenizer ---
print(f"Loading base model: {base_model_name}")
model = AutoModelForSequenceClassification.from_pretrained(
    base_model_name,
    num_labels=2
)
tokenizer = AutoTokenizer.from_pretrained(hf_repo_id)

# --- recreate the LoRA config used for training ---
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_lin", "v_lin"],
    modules_to_save=["pre_classifier", "classifier"],  # the classifier head weights are also stored in the adapter
)

# wrap the base model with an (as yet untrained) LoRA adapter
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# load the adapter weights from the Hub
print(f"\nDownloading adapter weights from the Hub: {hf_repo_id}")
weights_path = hf_hub_download(repo_id=hf_repo_id, filename="adapter_model.safetensors")
adapter_weights = load_file(weights_path)

# compare the adapter's keys with the model's trainable keys
print("Adapter keys from the Hub:", list(adapter_weights.keys()))

# trainable parameter keys of the PEFT-wrapped model
model_trainable_keys = [k for k, v in model.named_parameters() if v.requires_grad]
print("Model trainable keys:", model_trainable_keys)
new_state_dict = {}
for k, v in adapter_weights.items():
    # adjust the keys to match the PEFT model's naming (adapter name "default")
    new_key = k.replace(".weight", ".default.weight")
    if "classifier" in new_key:
        # ...classifier.bias -> ...classifier.modules_to_save.default.bias
        if new_key.endswith(".bias"):
            new_key = new_key.replace(".bias", ".modules_to_save.default.bias")
        # ...classifier.default.weight -> ...classifier.modules_to_save.default.weight
        elif new_key.endswith(".weight"):
            new_key = new_key.replace(".default.weight", ".modules_to_save.default.weight")
    new_state_dict[new_key] = v

print("New keys:", list(new_state_dict.keys()))
print("\nLoading weights with remapped keys")
load_result = model.load_state_dict(new_state_dict, strict=False)
# strict=False: the frozen base weights are not in new_state_dict, so missing keys are expected;
# unexpected keys, however, indicate a remapping mistake
print("Unexpected keys:", load_result.unexpected_keys)

# --- test the model ---
text_neg = "I do not like this movie, it was bad!"
inputs_neg = tokenizer(text_neg, return_tensors="pt")
with torch.no_grad():
    outputs_neg = model(**inputs_neg)
predicted_class_id_neg = outputs_neg.logits.argmax().item()
print(f"negative: '{text_neg}' --> prediction: {predicted_class_id_neg}")
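The script above reports only the argmax class id. If you also want class probabilities, apply a softmax over the logits. A standalone sketch with made-up logit values (not output from this model):

```python
import torch
import torch.nn.functional as F

# Example logits for one input and two classes (values are illustrative)
logits = torch.tensor([[-1.2, 2.3]])

# Softmax over the class dimension turns logits into probabilities that sum to 1
probs = F.softmax(logits, dim=-1)
pred = probs.argmax(dim=-1).item()

print(probs)  # two probabilities summing to 1
print(pred)   # 1
```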