---
license: mit
base_model: gpt2-large
tags:
- natural-language-inference
- lora
- peft
- gpt2
- multinli
- text-classification
datasets:
- nyu-mll/multi_nli
language:
- en
pipeline_tag: text-classification
---
# GPT2-Large LoRA Fine-tuned for Natural Language Inference
This model is a LoRA (Low-Rank Adaptation) fine-tuned version of GPT2-large for Natural Language Inference (NLI) on the MultiNLI dataset.
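For reference, the standard LoRA formulation (which PEFT implements by default) keeps each targeted pretrained weight $W_0$ frozen and adds a trainable low-rank update, so an adapted layer computes:

$$
h = W_0 x + \frac{\alpha}{r} B A x, \qquad A \in \mathbb{R}^{r \times d_{\text{in}}},\; B \in \mathbb{R}^{d_{\text{out}} \times r}
$$

Here $r$ and $\alpha$ correspond to `lora_r` (16) and `lora_alpha` (32) in the training configuration, so the update is scaled by $\alpha / r = 2$. Only $A$ and $B$ are trained, adding $r(d_{\text{in}} + d_{\text{out}})$ parameters per adapted layer (LoRA dropout on the input to $A$ is omitted here for clarity).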
## Model Details
- **Base Model**: GPT2-large (774M parameters)
- **Fine-tuning Method**: LoRA (Low-Rank Adaptation)
- **Trainable Parameters**: ~2.3M (0.3% of total parameters)
- **Dataset**: MultiNLI (50K training samples)
- **Task**: Natural Language Inference (3-class classification)
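The trainable-parameter count can be estimated from the LoRA settings, since each adapted layer of shape d_in → d_out adds r(d_in + d_out) parameters. The sketch below assumes GPT2-large's standard dimensions (hidden size 1280, 36 blocks, MLP width 4 × 1280) and that PEFT's suffix-based name matching attaches adapters to every module ending in `c_attn`, `c_proj`, or `c_fc`, which includes both the attention and the MLP `c_proj` layers. Under those assumptions it gives about 11.8M parameters (~1.5%), noticeably more than the ~2.3M listed above, so the released adapter may use a narrower configuration; `adapter_config.json` and the tensor shapes in `adapter_model.safetensors` are the authoritative check. The helper names are hypothetical.

```python
# Hypothetical helper: estimate LoRA trainable parameters for GPT2-large,
# assuming every module named c_attn, c_proj, or c_fc receives an adapter.

# GPT2-large dimensions (hidden size 1280, 36 transformer blocks).
N_EMBD = 1280
N_LAYER = 36
RANK = 16  # lora_r from the training configuration

# (d_in, d_out) of each adapted Conv1D layer in one GPT-2 block.
# "c_proj" matches both the attention output and the MLP output projection.
ADAPTED_LAYERS = {
    "attn.c_attn": (N_EMBD, 3 * N_EMBD),
    "attn.c_proj": (N_EMBD, N_EMBD),
    "mlp.c_fc": (N_EMBD, 4 * N_EMBD),
    "mlp.c_proj": (4 * N_EMBD, N_EMBD),
}


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA adds A with shape (rank, d_in) and B with shape (d_out, rank)."""
    return rank * (d_in + d_out)


def total_lora_params(rank: int = RANK) -> int:
    """Sum the A/B parameters over all adapted layers in all blocks."""
    per_block = sum(lora_params(d_in, d_out, rank) for d_in, d_out in ADAPTED_LAYERS.values())
    return per_block * N_LAYER


if __name__ == "__main__":
    n = total_lora_params()
    print(f"{n:,} trainable LoRA parameters ({n / 774e6:.2%} of 774M)")
```

Run as a script, this prints `11,796,480 trainable LoRA parameters (1.52% of 774M)`. Restricting `target_modules` to `c_attn` alone would give 36 × 81,920 = 2,949,120.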
## Performance
- **Accuracy (Matched, MultiNLI dev set)**: ~79.22%
- **Accuracy (Mismatched, MultiNLI dev set)**: ~80.38%
- **Training Method**: Parameter-efficient fine-tuning with LoRA
- **Hardware**: Trained on a 36 GB vGPU
## Usage
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
base_model = AutoModelForCausalLM.from_pretrained("gpt2-large")

# Load the LoRA adapter on top of the base model and switch to inference mode
model = PeftModel.from_pretrained(base_model, "hilaryc112/LoRA-GPT2-Project")
model.eval()

# Format input using the same template as the training data
premise = "A person is outdoors, on a horse."
hypothesis = "A person is at a diner, ordering an omelette."
input_text = f"Premise: {premise}\nHypothesis: {hypothesis}\nRelationship:"

# Tokenize and generate
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)

# Decode only the newly generated tokens
prediction = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
print(f"Prediction: {prediction}")  # Expected: one of entailment, neutral, contradiction
```
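With `max_new_tokens=10`, the model may keep generating after the label (for example, starting a new "Premise:" line), so it helps to normalize the output before comparing it with gold labels. The helpers below are a hypothetical sketch: `parse_label` takes the first word of the generation, and `accuracy` counts unparseable outputs as wrong. The card does not say how the reported accuracies treated such cases, so this scoring rule is an assumption.

```python
# Hypothetical helpers for turning generated text into an NLI label.
from typing import Optional, Sequence

LABELS = ("entailment", "neutral", "contradiction")


def parse_label(generated: str) -> Optional[str]:
    """Return the label named by the first word of the generation, or None."""
    words = generated.strip().lower().split()
    if not words:
        return None
    # Drop trailing/leading punctuation such as "neutral." or "entailment,"
    first = words[0].strip(".,;:!\"'")
    return first if first in LABELS else None


def accuracy(generations: Sequence[str], gold: Sequence[str]) -> float:
    """Fraction of examples whose parsed label equals the gold label.

    Unparseable generations count as errors.
    """
    if len(generations) != len(gold):
        raise ValueError("generations and gold must have the same length")
    if not gold:
        return 0.0
    correct = sum(parse_label(g) == y for g, y in zip(generations, gold))
    return correct / len(gold)
```

For example, `parse_label(" Contradiction\nPremise: ...")` returns `"contradiction"`, while `parse_label("maybe")` returns `None`.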
## Training Configuration
```json
{
  "model_name": "gpt2-large",
  "max_length": 512,
  "lora_r": 16,
  "lora_alpha": 32,
  "lora_dropout": 0.1,
  "target_modules": [
    "c_attn",
    "c_proj",
    "c_fc"
  ],
  "num_epochs": 3,
  "train_batch_size": 4,
  "eval_batch_size": 12,
  "gradient_accumulation_steps": 6,
  "learning_rate": 0.0002,
  "weight_decay": 0.01,
  "max_grad_norm": 1.0,
  "use_fp16": true,
  "gradient_checkpointing": true,
  "logging_steps": 100,
  "eval_steps": 500,
  "save_steps": 500,
  "save_total_limit": 3,
  "early_stopping_patience": 5,
  "data_dir": "./processed_data",
  "output_dir": "./gpt2_lora_multinli",
  "seed": 42,
  "use_wandb": false,
  "_comments": {
    "effective_batch_size": "6 * 6 = 36 (optimized for 36G vGPU)",
    "memory_optimization": "FP16 + gradient checkpointing enabled",
    "lora_config": "Rank 16 with alpha 32 for good performance/efficiency balance",
    "target_modules": "GPT2 attention and MLP layers for comprehensive adaptation",
    "training_data": "Uses 50K samples from MultiNLI training set (configured in preprocessing)",
    "evaluation_data": "Uses local dev files for matched/mismatched evaluation",
    "training_adjustments": "Reduced epochs to 2 and LR to 1e-4 for better training with real data",
    "eval_frequency": "Less frequent evaluation (every 500 steps) due to larger dataset"
  }
}
```

Note that some `_comments` entries do not match the values above: `train_batch_size` 4 × `gradient_accumulation_steps` 6 gives an effective batch size of 24 rather than 36, and the listed `num_epochs` (3) and `learning_rate` (2e-4) differ from the 2 epochs and 1e-4 described under `training_adjustments`.
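The optimizer-step budget follows from the batch settings. This sketch takes the configured values at face value (batch size 4, 6 accumulation steps, 3 epochs, evaluation every 500 steps, 50K training samples) and assumes single-device training; the `schedule` helper is hypothetical. Exact counts depend on how the trainer rounds the final partial batch, and early stopping (patience 5) may end training sooner.

```python
import math

# Values from the training configuration; the 50K training-set size comes
# from the Model Details section. Single-device training is assumed.
TRAIN_SAMPLES = 50_000
TRAIN_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 6
NUM_EPOCHS = 3
EVAL_STEPS = 500


def schedule(num_devices: int = 1) -> dict:
    """Estimate effective batch size and optimizer-step counts."""
    effective_batch = TRAIN_BATCH_SIZE * GRAD_ACCUM_STEPS * num_devices
    # Round up so a final partial batch still counts as one optimizer step.
    steps_per_epoch = math.ceil(TRAIN_SAMPLES / effective_batch)
    total_steps = steps_per_epoch * NUM_EPOCHS
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        # Evaluations at steps 500, 1000, ... up to total_steps
        "evaluations": total_steps // EVAL_STEPS,
    }


print(schedule())
```

This prints `{'effective_batch': 24, 'steps_per_epoch': 2084, 'total_steps': 6252, 'evaluations': 12}`.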
## Dataset Format
The model was trained on examples in the following text-to-text format:
```
Premise: [premise text]
Hypothesis: [hypothesis text]
Relationship: [entailment/neutral/contradiction]
```
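A MultiNLI record can be turned into this format with a small helper. This is a hypothetical sketch, not the author's preprocessing script: it assumes the Hub dataset's label order (0 = entailment, 1 = neutral, 2 = contradiction) and that the label follows "Relationship:" after a single space. With `label=None` it produces the inference prompt used in the Usage example.

```python
# Hypothetical sketch: convert a MultiNLI record into the training text above.
# Assumptions: label ids follow the Hub dataset's order, and the target label
# is appended after "Relationship:" with a single space.
from typing import Optional

ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}


def format_example(premise: str, hypothesis: str, label: Optional[int] = None) -> str:
    """Build a training string (with label) or an inference prompt (label=None)."""
    prompt = f"Premise: {premise}\nHypothesis: {hypothesis}\nRelationship:"
    if label is None:
        return prompt
    # Reject anything outside the three classes (some NLI datasets use -1
    # for "no gold label").
    if label not in ID2LABEL:
        raise ValueError(f"unexpected label id: {label}")
    return f"{prompt} {ID2LABEL[label]}"
```

For example, `format_example("A person is outdoors, on a horse.", "A person is at a diner, ordering an omelette.", 2)` ends with `Relationship: contradiction`.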
## Files
- `adapter_config.json`: LoRA adapter configuration
- `adapter_model.safetensors`: LoRA adapter weights
- `training_config.json`: Training hyperparameters and settings
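To check which LoRA settings the published adapter actually uses (for example, against the training configuration above), `adapter_config.json` can be read directly. This standard-library sketch assumes the file was written by PEFT with its usual keys (`r`, `lora_alpha`, `target_modules`, `base_model_name_or_path`); the `summarize_adapter_config` helper is hypothetical, and the reported scaling assumes PEFT's default α/r rule (no rsLoRA).

```python
import json
from pathlib import Path


def summarize_adapter_config(path: str = "adapter_config.json") -> dict:
    """Read a PEFT LoRA adapter config and report its key settings."""
    cfg = json.loads(Path(path).read_text())
    r = cfg["r"]
    alpha = cfg["lora_alpha"]
    target_modules = cfg.get("target_modules") or []
    # target_modules may be stored as a single regex string instead of a list
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    return {
        "base_model": cfg.get("base_model_name_or_path"),
        "r": r,
        "lora_alpha": alpha,
        "scaling": alpha / r,  # default LoRA scaling, assuming rsLoRA is off
        "target_modules": sorted(target_modules),
    }
```

One way to obtain the file locally is `huggingface_hub.hf_hub_download(repo_id="hilaryc112/LoRA-GPT2-Project", filename="adapter_config.json")`, which returns the downloaded path.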
## Citation
If you use this model, please cite:
```bibtex
@misc{gpt2-lora-multinli,
title={GPT2-Large LoRA Fine-tuned for Natural Language Inference},
author={HilaryKChen},
year={2024},
howpublished={\url{https://huggingface.co/hilaryc112/LoRA-GPT2-Project}}
}
```
## License
This model is released under the MIT License.