---
license: mit
base_model: gpt2-large
tags:
- natural-language-inference
- lora
- peft
- gpt2
- multinli
- text-classification
datasets:
- nyu-mll/multi_nli
language:
- en
pipeline_tag: text-classification
---

# GPT2-Large LoRA Fine-tuned for Natural Language Inference

This model is a LoRA (Low-Rank Adaptation) fine-tuned version of GPT2-large for Natural Language Inference (NLI) on the MultiNLI dataset.

## Model Details

- **Base Model**: GPT2-large (774M parameters)
- **Fine-tuning Method**: LoRA (Low-Rank Adaptation)
- **Trainable Parameters**: ~2.3M (0.3% of total parameters)
- **Dataset**: MultiNLI (50K training samples)
- **Task**: Natural Language Inference (3-class classification)

## Performance

- **Test Accuracy (Matched)**: ~79.22%
- **Test Accuracy (Mismatched)**: ~80.38%
- **Training Method**: Parameter-efficient fine-tuning with LoRA
- **Hardware**: Trained on a 36 GB vGPU

## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
base_model = AutoModelForCausalLM.from_pretrained("gpt2-large")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "hilaryc112/LoRA-GPT2-Project")
model.eval()

# Format input
premise = "A person is outdoors, on a horse."
hypothesis = "A person is at a diner, ordering an omelette."
input_text = f"Premise: {premise}\nHypothesis: {hypothesis}\nRelationship:"

# Tokenize and generate
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)

# Decode only the newly generated tokens
prediction = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
print(f"Prediction: {prediction}")  # Should output: contradiction, neutral, or entailment
```
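
For a quick accuracy spot-check in the same prompt format, a slice of the MultiNLI `validation_matched` split can be scored with the loop below. This is a minimal sketch assuming the `model` and `tokenizer` objects from the snippet above and the standard MultiNLI label ids (0 = entailment, 1 = neutral, 2 = contradiction); it is not the original evaluation script.

```python
# Minimal accuracy spot-check on a slice of MultiNLI validation_matched
# (sketch only; assumes `model` and `tokenizer` from the Usage snippet above).
from datasets import load_dataset

label_names = ["entailment", "neutral", "contradiction"]
val = load_dataset("nyu-mll/multi_nli", split="validation_matched").select(range(50))

correct = 0
for ex in val:
    prompt = f"Premise: {ex['premise']}\nHypothesis: {ex['hypothesis']}\nRelationship:"
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=10, pad_token_id=tokenizer.eos_token_id)
    pred = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip().lower()
    correct += int(pred.startswith(label_names[ex["label"]]))

print(f"Accuracy on {len(val)} examples: {correct / len(val):.2%}")
```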

## Training Configuration

```json
{
  "model_name": "gpt2-large",
  "max_length": 512,
  "lora_r": 16,
  "lora_alpha": 32,
  "lora_dropout": 0.1,
  "target_modules": [
    "c_attn",
    "c_proj",
    "c_fc"
  ],
  "num_epochs": 3,
  "train_batch_size": 4,
  "eval_batch_size": 12,
  "gradient_accumulation_steps": 6,
  "learning_rate": 0.0002,
  "weight_decay": 0.01,
  "max_grad_norm": 1.0,
  "use_fp16": true,
  "gradient_checkpointing": true,
  "logging_steps": 100,
  "eval_steps": 500,
  "save_steps": 500,
  "save_total_limit": 3,
  "early_stopping_patience": 5,
  "data_dir": "./processed_data",
  "output_dir": "./gpt2_lora_multinli",
  "seed": 42,
  "use_wandb": false,
  "_comments": {
    "effective_batch_size": "6 * 6 = 36 (optimized for 36G vGPU)",
    "memory_optimization": "FP16 + gradient checkpointing enabled",
    "lora_config": "Rank 16 with alpha 32 for good performance/efficiency balance",
    "target_modules": "GPT2 attention and MLP layers for comprehensive adaptation",
    "training_data": "Uses 50K samples from MultiNLI training set (configured in preprocessing)",
    "evaluation_data": "Uses local dev files for matched/mismatched evaluation",
    "training_adjustments": "Reduced epochs to 2 and LR to 1e-4 for better training with real data",
    "eval_frequency": "Less frequent evaluation (every 500 steps) due to larger dataset"
  }
}
```
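
The LoRA hyperparameters above map directly onto a `peft` `LoraConfig`. The snippet below is a hedged sketch of how an equivalent adapter could be set up; the `task_type` and the surrounding training loop are assumptions, not the original training script.

```python
# Sketch: a peft LoraConfig mirroring the hyperparameters listed above.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

lora_config = LoraConfig(
    r=16,                                         # lora_r
    lora_alpha=32,                                # lora_alpha
    lora_dropout=0.1,                             # lora_dropout
    target_modules=["c_attn", "c_proj", "c_fc"],  # GPT2 attention + MLP layers
    task_type="CAUSAL_LM",                        # assumption: causal LM objective
)

base = AutoModelForCausalLM.from_pretrained("gpt2-large")
peft_model = get_peft_model(base, lora_config)
peft_model.print_trainable_parameters()  # prints trainable vs. total parameter counts
```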

## Dataset Format

The model was trained on a text-to-text format:

```
Premise: [premise text]
Hypothesis: [hypothesis text]
Relationship: [entailment/neutral/contradiction]
```
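
As an illustration, a raw MultiNLI example could be converted into this format with a helper like the following (a sketch, assuming the standard label mapping 0 = entailment, 1 = neutral, 2 = contradiction):

```python
# Sketch: build a prompt in the text-to-text format above from a MultiNLI example.
LABELS = ["entailment", "neutral", "contradiction"]  # standard MultiNLI label ids

def format_example(example, include_label=True):
    text = (
        f"Premise: {example['premise']}\n"
        f"Hypothesis: {example['hypothesis']}\n"
        f"Relationship:"
    )
    if include_label:  # include the gold label for training; omit at inference
        text += f" {LABELS[example['label']]}"
    return text

print(format_example({"premise": "A man plays guitar.", "hypothesis": "A man makes music.", "label": 0}))
```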

## Files

- `adapter_config.json`: LoRA adapter configuration
- `adapter_model.safetensors`: LoRA adapter weights
- `training_config.json`: Training hyperparameters and settings
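
These files can be fetched locally with `huggingface_hub`, for example (sketch; the download location is whatever the Hub cache resolves to):

```python
# Sketch: download the adapter files from the Hub for offline use.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="hilaryc112/LoRA-GPT2-Project")
print(local_dir)  # path containing adapter_config.json, adapter_model.safetensors, ...
```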

## Citation

If you use this model, please cite:

```bibtex
@misc{gpt2-lora-multinli,
  title={GPT2-Large LoRA Fine-tuned for Natural Language Inference},
  author={HilaryKChen},
  year={2024},
  howpublished={\url{https://huggingface.co/hilaryc112/LoRA-GPT2-Project}}
}
```

## License

This model is released under the MIT License.