Sairam22 committed
Commit 2519bf0 (verified) · Parent(s): 542283d

Update README.md

Files changed (1):
  1. README.md +70 -20
README.md CHANGED
@@ -9,31 +9,69 @@ tags:
 model-index:
 - name: matcha-chartqa-lora-adapter
   results: []
 ---
 
 <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 should probably proofread and complete it, then remove this comment. -->
 
- # matcha-chartqa-lora-adapter
-
- This model is a fine-tuned version of [google/matcha-base](https://huggingface.co/google/matcha-base) on an unknown dataset.
-
- ## Model description
-
- More information needed
-
- ## Intended uses & limitations
-
- More information needed
-
- ## Training and evaluation data
-
- More information needed
-
- ## Training procedure
-
 ### Training hyperparameters
-
 The following hyperparameters were used during training:
 - learning_rate: 0.0002
 - train_batch_size: 2
@@ -46,8 +84,20 @@ The following hyperparameters were used during training:
 - training_steps: 50
 - mixed_precision_training: Native AMP
 
- ### Training results
 
  ### Framework versions
 
 model-index:
 - name: matcha-chartqa-lora-adapter
   results: []
+ datasets:
+ - HuggingFaceM4/ChartQA
 ---
 
 <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 should probably proofread and complete it, then remove this comment. -->
 
+ # Multimodal SLM Fine-Tuning: ChartQA with MatCha
+
+ This repository contains the code and documentation for fine-tuning a Small Language Model (SLM) on the ChartQA dataset with parameter-efficient fine-tuning (LoRA).
+
+ The project demonstrates a complete multimodal fine-tuning pipeline that runs on a single NVIDIA T4 GPU (16 GB VRAM).
+
+ ## 🚀 How to Run Inference
+
+ The following standalone snippet pulls the fine-tuned LoRA adapter from Hugging Face, merges it into the base model, and runs inference on a custom chart image.
+
+ ```python
+ import torch
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from peft import PeftModel
+ from PIL import Image
+
+ # 1. Define model IDs and device
+ base_model_id = "google/matcha-base"
+ adapter_id = "Sairam22/matcha-chartqa-lora-adapter"  # Replace if your repo name is different
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # 2. Load base model and processor
+ print("Loading base model and processor...")
+ processor = AutoProcessor.from_pretrained(adapter_id)
+ base_model = AutoModelForImageTextToText.from_pretrained(
+     base_model_id,
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True
+ )
+
+ # 3. Pull adapter from Hugging Face and merge
+ print("Pulling adapter and merging weights...")
+ model = PeftModel.from_pretrained(base_model, adapter_id)
+ model = model.merge_and_unload()
+ model = model.to(device)
+
+ # 4. Prepare image and prompt
+ image_path = "path_to_your_chart.png"  # Provide the path to a local chart image
+ image = Image.open(image_path).convert("RGB")
+ prompt = "Question: What is the highest value in the bar chart?\nAnswer:"
+
+ # 5. Process inputs and cast dtypes
+ inputs = processor(images=image, text=prompt, return_tensors="pt")
+ # Ensure float32 tensors (like images) are cast to float16 to match the model weights
+ inputs = {k: v.to(device, dtype=torch.float16) if v.dtype == torch.float32 else v.to(device)
+           for k, v in inputs.items()}
+
+ # 6. Run inference
+ print("Generating prediction...")
+ with torch.no_grad():
+     outputs = model.generate(**inputs, max_new_tokens=32)
+
+ prediction = processor.decode(outputs[0], skip_special_tokens=True)
+ print(f"Prediction: {prediction}")
+ ```
 ### Training hyperparameters
 
 The following hyperparameters were used during training:
 - learning_rate: 0.0002
 - train_batch_size: 2
 
 - training_steps: 50
 - mixed_precision_training: Native AMP
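The card does not include the training script itself; as a hedged sketch, the hyperparameters listed here could map onto `transformers.TrainingArguments` roughly as follows (`output_dir` is a placeholder, and `gradient_accumulation_steps` comes from the decision log below, not from this list):

```python
from transformers import TrainingArguments

# Sketch only: maps the listed hyperparameters onto standard
# TrainingArguments fields; the actual trainer setup is not shown in this card.
training_args = TrainingArguments(
    output_dir="matcha-chartqa-lora",  # placeholder path
    learning_rate=2e-4,                # 0.0002
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,     # effective batch size 2 * 4 = 8
    max_steps=50,                      # training_steps: 50
    fp16=True,                         # Native AMP (requires a CUDA device)
)
```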
 
+ ## 🧠 Decision Log & T4 Optimizations
+
+ To ensure this pipeline runs efficiently on a single NVIDIA T4 GPU (16 GB VRAM) and within strict time limits, several specific parameter choices were made:
+
+ **Model selection (`google/matcha-base`):** chosen because it is pre-trained specifically for chart visual-language tasks (it builds on Pix2Struct) and is lightweight (~256M parameters), so it fits easily into T4 memory.
+
+ **Precision (`fp16=True`):** casting the base model to float16 halves memory consumption, prevents datatype-mismatch errors (`c10::Half`), and leverages the T4's Tensor Cores to speed up training.
+
+ **LoRA configuration (`r=8`, `alpha=16`):** a rank of 8 keeps trainable parameters below 5% of the total, which keeps optimizer state minimal and prevents out-of-memory (OOM) errors during training. MatCha's attention projections (`query`, `value`) were explicitly targeted.
+
+ **Batch sizing (`batch_size=2`, `gradient_accumulation=4`):** a physical batch size of 2 keeps the forward/backward pass within VRAM limits, while gradient accumulation simulates an effective batch size of 8 for stable loss convergence.
+
+ **Adapter merging (`merge_and_unload()`):** fulfills the assignment requirement while also improving inference speed by folding the adapter weights directly into the base model matrices, removing adapter-routing overhead during generation.
+
+ **Training subset:** due to hardware and time constraints, a 100-sample subset was used for the final epoch to validate the end-to-end pipeline, adapter upload, and inference.
 
  ### Framework versions