LoRA-Finetuned DistilBERT for Sentiment Classification on the IMDB Dataset
- This repository provides a LoRA adapter model fine-tuned for sentiment classification on the IMDB movie review dataset.
- Fine-tuning was performed efficiently using LoRA (Low-Rank Adaptation) from the PEFT library, with training managed by the PyTorch Lightning framework.
- The LoRA adapter weights must be combined with the base DistilBERT model weights for inference.
You can use the following Python code to perform inference:
(The code is somewhat lengthy because the adapter's layer keys must be remapped to match the PEFT-wrapped base model before loading.)
import torch
from safetensors.torch import load_file
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import get_peft_model, LoraConfig
from huggingface_hub import hf_hub_download
import os
# --- configuration: base model name and the adapter repo on the Hugging Face Hub ---
base_model_name = "distilbert-base-uncased"
hf_repo_id = "Qndhm/distilled-bert-imdb-lora-adapter"
# --- load the base model (classification head is randomly initialised here;
# its trained weights come from the adapter file below) ---
print(f"loading base model: {base_model_name}")
model = AutoModelForSequenceClassification.from_pretrained(
base_model_name,
num_labels=2
)
# tokenizer is fetched from the adapter repo so it matches the fine-tuning setup
tokenizer = AutoTokenizer.from_pretrained(hf_repo_id)
# --- recreate the LoRA config used at fine-tuning time; it must match the
# saved adapter exactly (rank, alpha, target modules) or keys will not align ---
peft_config = LoraConfig(
task_type="SEQ_CLS",
r=8,
lora_alpha=16,
lora_dropout=0.1,
target_modules=["q_lin", "v_lin"],
modules_to_save=["pre_classifier", "classifier"]  # fully-trained (non-LoRA) head layers, remapped later
)
# wrap the base model with an (as yet untrained) LoRA adapter structure
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# --- download and load the trained adapter weights from the hub ---
print(f"\n downloading weights from hub: {hf_repo_id}")
weights_path = hf_hub_download(repo_id=hf_repo_id, filename="adapter_model.safetensors")
adapter_weights = load_file(weights_path)
# compare the key names of the saved adapter vs. the PEFT-wrapped model;
# they differ (PEFT inserts ".default." / ".modules_to_save.default."), hence the remapping below
print("Hub keys of the adapter:", list(adapter_weights.keys()))
# trainable parameter keys of the wrapped model, for comparison
model_trainable_keys = [k for k, v in model.named_parameters() if v.requires_grad]
print("base model keys:", model_trainable_keys)
def remap_adapter_keys(adapter_weights):
    """Rename saved adapter keys to match the PEFT-wrapped model's state dict.

    PEFT inserts an adapter name ("default") into parameter paths, so:
      * "...lora_A.weight"   -> "...lora_A.default.weight"
      * "...classifier.weight" -> "...classifier.modules_to_save.default.weight"
      * "...classifier.bias"   -> "...classifier.modules_to_save.default.bias"

    Only suffixes are rewritten (the original ``str.replace`` would have
    touched every ".weight" occurrence anywhere in a key).

    Args:
        adapter_weights: mapping of saved key -> tensor, as returned by
            ``safetensors.torch.load_file``.

    Returns:
        New dict with remapped keys and the same tensor values.
    """
    remapped = {}
    for key, tensor in adapter_weights.items():
        new_key = key
        # insert the adapter name before the trailing ".weight"
        if new_key.endswith(".weight"):
            new_key = new_key[: -len(".weight")] + ".default.weight"
        if "classifier" in new_key:
            # head layers are saved via modules_to_save, which nests one level deeper
            if new_key.endswith(".bias"):
                new_key = new_key[: -len(".bias")] + ".modules_to_save.default.bias"
            elif new_key.endswith(".default.weight"):
                new_key = (
                    new_key[: -len(".default.weight")]
                    + ".modules_to_save.default.weight"
                )
        remapped[new_key] = tensor
    return remapped


new_state_dict = remap_adapter_keys(adapter_weights)
print("New keys:", list(new_state_dict.keys()))
print("\n Load weights with new keys")
# strict=False: the adapter file only contains LoRA + head weights, not the full model
model.load_state_dict(new_state_dict, strict=False)
# --- smoke-test the assembled model on one example ---
# NOTE: the original snippet used a clearly negative sentence ("I do not like
# this movie, it was bad!") while labelling it "positive"; the text is fixed
# here so the example is self-consistent.
text_pos = "I really enjoyed this movie, it was great!"
inputs_pos = tokenizer(text_pos, return_tensors="pt")
with torch.no_grad():  # inference only — no gradients needed
    outputs_pos = model(**inputs_pos)
# argmax over the two logits gives the predicted class id (0 = negative, 1 = positive
# per the IMDB convention — confirm against the training label mapping)
predicted_class_id_pos = outputs_pos.logits.argmax().item()
print(f"positive: '{text_pos}' --> prediction: {predicted_class_id_pos}")
Model tree for Qndhm/distilled-bert-imdb-lora-adapter
Base model
distilbert/distilbert-base-uncased