Update README.md

README.md

---
tags:
model-index:
- name: matcha-chartqa-lora-adapter
  results: []
datasets:
- HuggingFaceM4/ChartQA
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# Multimodal SLM Fine-Tuning: ChartQA with MatCha

This repository contains the code and documentation for fine-tuning a Small Language Model (SLM) on the ChartQA dataset using Parameter-Efficient Fine-Tuning (LoRA).

This project was developed to demonstrate a complete multimodal fine-tuning pipeline capable of running on a single NVIDIA T4 GPU (16GB VRAM).

## 🚀 How to Run Inference

The following standalone code snippet demonstrates how to pull the fine-tuned LoRA adapters from Hugging Face, merge them with the base model, and run inference on a custom chart image.

```python
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
from PIL import Image

# 1. Define Model IDs and Device
base_model_id = "google/matcha-base"
adapter_id = "Sairam22/matcha-chartqa-lora-adapter"  # Replace if your repo name is different
device = "cuda" if torch.cuda.is_available() else "cpu"

# 2. Load Base Model and Processor
print("Loading base model and processor...")
processor = AutoProcessor.from_pretrained(adapter_id)
base_model = AutoModelForImageTextToText.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# 3. Pull Adapter from Hugging Face and Merge
print("Pulling adapter and merging weights...")
model = PeftModel.from_pretrained(base_model, adapter_id)
model = model.merge_and_unload()
model = model.to(device)

# 4. Prepare Image and Prompt
image_path = "path_to_your_chart.png"  # Provide the path to a local chart image
image = Image.open(image_path).convert("RGB")
prompt = "Question: What is the highest value in the bar chart?\nAnswer:"

# 5. Process Inputs and Cast Dtypes
inputs = processor(images=image, text=prompt, return_tensors="pt")
# Ensure float32 tensors (like images) are cast to float16 to match the model weights
inputs = {
    k: v.to(device, dtype=torch.float16) if v.dtype == torch.float32 else v.to(device)
    for k, v in inputs.items()
}

# 6. Run Inference
print("Generating prediction...")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)

prediction = processor.decode(outputs[0], skip_special_tokens=True)
print(f"Prediction: {prediction}")
```

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.0002
- train_batch_size: 2
- training_steps: 50
- mixed_precision_training: Native AMP
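
A quick back-of-the-envelope check on these settings (illustrative only: the accumulation factor comes from the decision log below, and whether `training_steps` counts optimizer updates or micro-batches depends on the trainer configuration):

```python
# Illustrative arithmetic on the run configuration above.
train_batch_size = 2       # per-device micro-batch (from the list above)
gradient_accumulation = 4  # from the decision log in this card
training_steps = 50

effective_batch = train_batch_size * gradient_accumulation
print(effective_batch)  # 8

# If each training step is one optimizer update, total samples consumed:
samples_processed = training_steps * effective_batch
print(samples_processed)  # 400
```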

## 🧠 Decision Log & T4 Optimizations

To ensure this pipeline runs efficiently on a single NVIDIA T4 GPU (16GB VRAM) and within strict time limits, several specific parameter choices were made:

**Model Selection (`google/matcha-base`):** Chosen because it is pre-trained specifically for chart visual-language tasks (it is based on Pix2Struct) and is lightweight (~256M parameters), so it fits easily into T4 memory.

**Precision (`fp16=True`):** Casting the base model to float16 halves memory consumption, avoids `c10::Half` datatype-mismatch errors, and leverages the T4's Tensor Cores to speed up training.
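
The memory saving is easy to estimate (weights only; activations, optimizer state, and CUDA context add more on top):

```python
# Weight-memory estimate for a ~256M-parameter model.
n_params = 256_000_000
fp32_gb = n_params * 4 / 1024**3  # 4 bytes per float32 parameter
fp16_gb = n_params * 2 / 1024**3  # 2 bytes per float16 parameter
print(f"fp32: {fp32_gb:.2f} GB, fp16: {fp16_gb:.2f} GB")  # fp16 halves it
```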

**LoRA Configuration (`r=8`, `alpha=16`):** A rank of 8 keeps the trainable parameters under 5% of the total, which prevents Out-Of-Memory (OOM) errors during training by keeping the optimizer states minimal. MatCha's attention projections (`query`, `value`) were explicitly targeted.
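
The "under 5% trainable" claim can be sanity-checked with rough arithmetic; the hidden size and layer count below are illustrative guesses, not the exact MatCha dimensions:

```python
# Rough LoRA parameter count for rank-8 adapters on the query/value
# projections. d and n_layers are illustrative guesses.
r = 8
d = 768                  # assumed square d x d projection
n_layers = 12            # assumed number of attention blocks targeted
matrices_per_layer = 2   # query and value

lora_params_per_matrix = r * (d + d)  # A is r x d, B is d x r
total_lora = lora_params_per_matrix * matrices_per_layer * n_layers
print(total_lora)                # 294912
print(total_lora / 256_000_000)  # a tiny fraction of the ~256M base params
```

Even with generous guesses for the dimensions, the adapter stays orders of magnitude below the 5% bound.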

**Batch Sizing (`batch_size=2`, `gradient_accumulation=4`):** A physical batch size of 2 keeps the forward/backward pass within VRAM limits, while gradient accumulation simulates an effective batch size of 8 for stable loss convergence.
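
The accumulation schedule can be sketched as a framework-agnostic pseudo-loop (a sketch, not the actual Trainer internals):

```python
# Stand-in for the training loop: gradients from 4 micro-batches of size 2
# are accumulated before each optimizer update (effective batch = 8).
micro_batch_size = 2
accum_steps = 4
num_micro_batches = 20  # arbitrary demo length

optimizer_updates = 0
samples_in_update = 0
for i in range(num_micro_batches):
    samples_in_update += micro_batch_size  # stand-in for loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer_updates += 1             # stand-in for optimizer.step()
        samples_in_update = 0

print(optimizer_updates)  # 5 updates, each over 8 samples
```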

**Adapter Merging (`merge_and_unload()`):** Fulfills the assignment requirement and also improves inference speed by folding the adapter weights directly into the base model's weight matrices, removing the adapter-routing overhead during generation.

**Training Subset:** Due to hardware compute constraints and time limitations, a 100-sample subset was used for the final epoch to validate the end-to-end pipeline, adapter upload, and inference mechanisms.

### Framework versions