dchen0 commited on
Commit
e8208a0
·
verified ·
1 Parent(s): 9d161f4

Add merged model + processor

Browse files
Files changed (5) hide show
  1. config.json +3 -6
  2. handler.py +63 -0
  3. model.safetensors +2 -2
  4. requirements.txt +2 -0
  5. train_model.py +347 -0
config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "apply_layernorm": true,
3
  "architectures": [
4
- "FontClassifierWithPreprocessing"
5
  ],
6
  "attention_probs_dropout_prob": 0.0,
7
  "drop_path_rate": 0.0,
@@ -837,8 +837,5 @@
837
  "torch_dtype": "float32",
838
  "transformers_version": "4.52.4",
839
  "use_mask_token": true,
840
- "use_swiglu_ffn": false,
841
- "auto_map": {
842
- "AutoModelForImageClassification": "font_classifier_with_preprocessing.FontClassifierWithPreprocessing"
843
- }
844
- }
 
1
  {
2
  "apply_layernorm": true,
3
  "architectures": [
4
+ "Dinov2ForImageClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.0,
7
  "drop_path_rate": 0.0,
 
837
  "torch_dtype": "float32",
838
  "transformers_version": "4.52.4",
839
  "use_mask_token": true,
840
+ "use_swiglu_ffn": false
841
+ }
 
 
 
handler.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # to be bundled with the model on upload to HF Inference Endpoints
2
+
3
+ import base64
4
+ import io
5
+ from typing import Any, Dict
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import AutoImageProcessor, Dinov2ForImageClassification
10
+
11
+ from train_model import get_inference_transform
12
+
13
+
14
+ class EndpointHandler:
15
+ """
16
+ HF Inference Endpoints entry‑point.
17
+ Loads model/processor once, then uses your *imported* preprocessing
18
+ on every request.
19
+ """
20
+
21
+ def __init__(self, path: str = "", image_size: int = 224):
22
+ # Weights + processor --------------------------------------------------------
23
+ self.processor = AutoImageProcessor.from_pretrained(path or ".")
24
+ self.model = (
25
+ Dinov2ForImageClassification.from_pretrained(path or ".")
26
+ .eval()
27
+ )
28
+
29
+ # Re‑use the exact transform from your code ---------------------------------
30
+ self.transform = get_inference_transform(self.processor, image_size)
31
+
32
+ self.id2label = self.model.config.id2label
33
+
34
+ # -------------------------------------------------------------------------------
35
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
36
+ """
37
+ Expects {"inputs": "<base64‑encoded image>"}.
38
+ Returns the top prediction + per‑class probabilities.
39
+ """
40
+ if "inputs" not in data:
41
+ raise ValueError("Request JSON must contain an 'inputs' field.")
42
+
43
+ # Decode base64 → PIL
44
+ img_bytes = base64.b64decode(data["inputs"])
45
+ image = Image.open(io.BytesIO(img_bytes))
46
+
47
+ # Preprocess with your own transform
48
+ pixel_values = self.transform(image).unsqueeze(0) # [1, C, H, W]
49
+
50
+ with torch.no_grad():
51
+ logits = self.model(pixel_values).logits
52
+ probs = logits.softmax(dim=-1)[0]
53
+
54
+ top_idx = int(probs.argmax())
55
+ top_label = self.id2label[top_idx]
56
+
57
+ return {
58
+ "predicted_label": top_label,
59
+ "scores": {
60
+ self.id2label[i]: float(p)
61
+ for i, p in enumerate(probs)
62
+ }
63
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0eeabb74d1af47629e61d6d4dd48bbf3eb74121db29c8ba8b644b41b8c481a6d
3
- size 348770168
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecb73dab2fc1203ab36f177c7a6c5a23e472f5fff58b8ce5f8fc51f20f0480e1
3
+ size 348769976
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torchvision>=0.19
2
+ Pillow>=10
train_model.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchvision.transforms as T
12
+ from datasets import load_dataset
13
+ from huggingface_hub import HfApi
14
+ from peft import LoraConfig, PeftModel, get_peft_model
15
+ from PIL import Image
16
+ from safetensors import safe_open
17
+ from transformers import (
18
+ AutoImageProcessor,
19
+ Dinov2ForImageClassification,
20
+ Trainer,
21
+ TrainingArguments,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ MODEL = "facebook/dinov2-base-imagenet1k-1-layer"
27
+
28
+ def parse_args():
29
+ parser = argparse.ArgumentParser(description='Train a DINOv2 model for font classification')
30
+ parser.add_argument('--data_dir', type=str, default=None,
31
+ help='Directory containing the font dataset')
32
+ parser.add_argument('--output_dir', type=str, default=None,
33
+ help='Directory to save the model')
34
+ parser.add_argument('--checkpoint', type=str, default=None,
35
+ help='Path to checkpoint to resume training from')
36
+ parser.add_argument('--batch_size', type=int, default=32,
37
+ help='Training and evaluation batch size')
38
+ parser.add_argument('--epochs', type=int, default=1,
39
+ help='Number of training epochs')
40
+ parser.add_argument('--learning_rate', type=float, default=1e-4,
41
+ help='Learning rate for training')
42
+ parser.add_argument('--lora_rank', type=int, default=8,
43
+ help='LoRA rank for parameter-efficient fine-tuning')
44
+ parser.add_argument('--lora_alpha', type=int, default=16,
45
+ help='LoRA alpha parameter')
46
+ parser.add_argument('--lora_dropout', type=float, default=0.1,
47
+ help='LoRA dropout rate')
48
+ parser.add_argument('--test_size', type=float, default=0.1,
49
+ help='Proportion of data to use for validation')
50
+ parser.add_argument('--seed', type=int, default=42,
51
+ help='Random seed for reproducibility')
52
+ parser.add_argument('--log_level', type=str, default='INFO',
53
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
54
+ help='Logging level')
55
+ parser.add_argument('--huggingface_model_name', type=str, default=None,
56
+ help='Name of the model to push to the Hub')
57
+ return parser.parse_args()
58
+
59
+
60
+ def load_checkpoint_with_size_mismatch_handling(base_model, checkpoint_path, peft_config):
61
+ """
62
+ Load PEFT checkpoint with automatic handling of size mismatches.
63
+ This uses PEFT's built-in loading but with strict=False to handle size mismatches gracefully.
64
+
65
+ Basically, if we have a different number of labels than in the checkpoint, we re-initialize the classifier head to relearn it.
66
+
67
+ Args:
68
+ base_model: The base model with the new classifier size
69
+ checkpoint_path: Path to the checkpoint
70
+ peft_config: LoraConfig object with the desired configuration
71
+
72
+ Returns:
73
+ PeftModel with loaded weights (mismatched layers will be skipped)
74
+ """
75
+ logger.info(f"Loading checkpoint with automatic size mismatch handling: {checkpoint_path}")
76
+
77
+ try:
78
+ # Try the normal PEFT loading first
79
+ model = PeftModel.from_pretrained(
80
+ base_model,
81
+ checkpoint_path,
82
+ is_trainable=True
83
+ )
84
+ logger.info("Successfully loaded checkpoint without size mismatches")
85
+ return model
86
+ except Exception as e:
87
+ logger.info(f"Standard loading failed ({str(e)}), using fallback loading method")
88
+
89
+ # Fallback: Create fresh PEFT model and load compatible weights
90
+ # Note: PeftModel.from_pretrained might have partially modified base_model before failing,
91
+ # so we recreate a clean base model to avoid double-loading warnings
92
+ fresh_base = Dinov2ForImageClassification.from_pretrained(
93
+ MODEL,
94
+ num_labels=base_model.config.num_labels,
95
+ ignore_mismatched_sizes=True,
96
+ )
97
+
98
+ model = get_peft_model(fresh_base, peft_config)
99
+
100
+ # Load checkpoint state dict
101
+ checkpoint_file = os.path.join(checkpoint_path, "adapter_model.safetensors")
102
+
103
+ if not os.path.exists(checkpoint_file):
104
+ raise ValueError(f"Checkpoint file {checkpoint_file} does not exist")
105
+
106
+ checkpoint_state_dict = {}
107
+ with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
108
+ for key in f.keys():
109
+ checkpoint_state_dict[key] = f.get_tensor(key)
110
+
111
+ # Load only compatible weights
112
+ missing_keys, unexpected_keys = model.load_state_dict(checkpoint_state_dict, strict=False)
113
+
114
+ logger.info(f"Loaded checkpoint with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
115
+ logger.info(f"The following keys were in the checkpoint but are now missing: {missing_keys}")
116
+ logger.info(f"The following keys are new i.e. unexpected: {unexpected_keys}")
117
+ logger.info("Missing keys (likely new classifier parameters): will be randomly initialized")
118
+
119
+ return model
120
+
121
+ def get_inference_transform(processor: AutoImageProcessor, size: int):
122
+ """Get the raw validation transform for direct inference on PIL images."""
123
+ normalize = T.Normalize(mean=processor.image_mean, std=processor.image_std)
124
+
125
+ to_rgb = T.Lambda(lambda img: img.convert('RGB'))
126
+
127
+ def pad_to_square(img):
128
+ w, h = img.size
129
+ max_size = max(w, h)
130
+ pad_w = (max_size - w) // 2
131
+ pad_h = (max_size - h) // 2
132
+ padding = (pad_w, pad_h, max_size - w - pad_w, max_size - h - pad_h)
133
+ return T.Pad(padding, fill=0)(img)
134
+
135
+ aug = T.Compose([
136
+ to_rgb,
137
+ pad_to_square,
138
+ T.Resize(size),
139
+ T.ToTensor(),
140
+ normalize
141
+ ])
142
+
143
+ return aug
144
+
145
+
146
+
147
+ def get_transform(processor: AutoImageProcessor, size: int):
148
+ aug = get_inference_transform(processor, size)
149
+
150
+ def transform(example, train=True):
151
+ # The dataset uses 'image' as the key for PIL images
152
+ # Use the processor directly - it handles pad_to_square + standard preprocessing
153
+ inputs = processor(images=example["image"], return_tensors="pt")
154
+ example["pixel_values"] = inputs["pixel_values"].squeeze(0) # Remove batch dimension for dataset
155
+ return example
156
+
157
+ return transform
158
+
159
+
160
+ if __name__ == "__main__":
161
+ args = parse_args()
162
+
163
+ # Configure logging with timestamps
164
+ logging.basicConfig(
165
+ level=args.log_level,
166
+ format='%(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d',
167
+ datefmt='%Y-%m-%d %H:%M:%S'
168
+ )
169
+
170
+ ######################################################################
171
+ # Directory layout expected by ImageFolder:
172
+ # fonts/
173
+ # ├─ Arial/
174
+ # │ ├─ img001.png
175
+ # │ └─ ...
176
+ # ├─ TimesNewRoman/
177
+ # └─ ...
178
+
179
+ logger.info(f"Loading dataset from {args.data_dir}")
180
+ # Get label names from directory names and sort them alphabetically
181
+ # to match the order used by the imagefolder dataset loader
182
+ label_names = sorted(os.listdir(f"{args.data_dir}/train"))
183
+ logger.info(f"Found {len(label_names)} labels")
184
+
185
+ if len(label_names) <= 1:
186
+ raise ValueError(f"Expected at least 2 labels, got {label_names=}, imagefolder will not label the dataset if there are less than 2 labels.")
187
+
188
+ # READ: the label ids assigned are in alphabetical order.
189
+ train_dataset = None
190
+ test_dataset = None
191
+
192
+
193
+ logger.info("Setting up image processor and augmentations")
194
+ processor = AutoImageProcessor.from_pretrained(MODEL) # 224 px
195
+ size = processor.size["shortest_edge"] # 224 by default
196
+
197
+ if args.epochs > 0:
198
+ dataset = load_dataset(
199
+ "imagefolder",
200
+ data_dir=args.data_dir,
201
+ )
202
+
203
+ logger.info(f"Train size: {len(dataset['train'])}, Validation size: {len(dataset['test'])}")
204
+
205
+ transform = get_transform(processor, size)
206
+
207
+ logger.info("Applying data transformations")
208
+ train_dataset = dataset["train"].map(
209
+ lambda x: transform(x, train=True),
210
+ remove_columns=["image"],
211
+ desc="Transforming training data"
212
+ )
213
+ test_dataset = dataset["test"].map(
214
+ lambda x: transform(x, train=False),
215
+ remove_columns=["image"],
216
+ desc="Transforming test data"
217
+ )
218
+
219
+ # Set the format to torch tensors
220
+ train_dataset.set_format(type="torch", columns=["pixel_values", "label"])
221
+ test_dataset.set_format(type="torch", columns=["pixel_values", "label"])
222
+
223
+ logger.info("Data preprocessing complete")
224
+
225
+ logger.info("Loading DINOv2 model")
226
+
227
+ base = Dinov2ForImageClassification.from_pretrained(
228
+ MODEL,
229
+ num_labels=len(label_names),
230
+ ignore_mismatched_sizes=True,
231
+ )
232
+
233
+ logger.info("Configuring LoRA adapters")
234
+ peft_cfg = LoraConfig(
235
+ r = args.lora_rank,
236
+ lora_alpha = args.lora_alpha,
237
+ target_modules = ["query", "value"], # Q & V proj in ViT blocks
238
+ lora_dropout = args.lora_dropout,
239
+ bias = "none",
240
+ modules_to_save = ["classifier"], # IMPORTANT: Save classification head too!
241
+ )
242
+
243
+ if args.checkpoint:
244
+ model = load_checkpoint_with_size_mismatch_handling(base, args.checkpoint, peft_cfg)
245
+ else:
246
+ model = get_peft_model(base, peft_cfg) # fresh LoRA wrap
247
+
248
+ model.print_trainable_parameters()
249
+
250
+ def collate(batch):
251
+ # The transform function has already converted images to tensors and stored them in pixel_values
252
+ pixel_values = torch.stack([item["pixel_values"] for item in batch])
253
+ labels = torch.tensor([item["label"] for item in batch])
254
+ return {"pixel_values": pixel_values, "labels": labels}
255
+
256
+ # Add compute_metrics function for accuracy calculation
257
+ def compute_metrics(eval_pred):
258
+ predictions, labels = eval_pred
259
+ predictions = predictions.argmax(axis=-1)
260
+ accuracy = (predictions == labels).mean()
261
+ return {"accuracy": accuracy}
262
+
263
+ logger.info("Setting up training arguments")
264
+ # Check if we're on MPS (Apple Silicon)
265
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
266
+ logger.info(f"Using device: {device}")
267
+
268
+ training_args = TrainingArguments(
269
+ output_dir = args.output_dir,
270
+ per_device_train_batch_size = args.batch_size,
271
+ per_device_eval_batch_size = args.batch_size,
272
+ # Tell Trainer which key in each batch holds the ground‑truth labels.
273
+ # Without it (especially with PEFT/LoRA wrappers), Trainer thinks there
274
+ # are no labels, skips compute_metrics, and never logs eval_accuracy.
275
+ label_names=["labels"],
276
+ eval_strategy = "steps" if args.epochs > 0 else "no",
277
+ eval_steps = 500,
278
+ save_strategy = "steps" if args.epochs > 0 else "no",
279
+ save_steps = 500,
280
+ num_train_epochs = args.epochs,
281
+ learning_rate = args.learning_rate,
282
+ weight_decay = 0.05,
283
+ fp16 = device.type == "cuda",
284
+ save_total_limit = 3,
285
+ logging_dir = os.path.join(args.output_dir, "logs") if args.output_dir else None,
286
+ logging_steps = 10,
287
+ report_to = "tensorboard",
288
+ load_best_model_at_end = True,
289
+ metric_for_best_model = "eval_accuracy",
290
+ greater_is_better = True,
291
+ # Pass the actual checkpoint path for proper resumption
292
+ resume_from_checkpoint = args.checkpoint if args.checkpoint else None,
293
+ )
294
+
295
+ trainer = Trainer(
296
+ model = model,
297
+ args = training_args,
298
+ train_dataset = train_dataset,
299
+ eval_dataset = test_dataset,
300
+ data_collator = collate,
301
+ compute_metrics = compute_metrics,
302
+ )
303
+
304
+ logger.info("Starting training")
305
+ if args.checkpoint:
306
+ logger.info(f"Resuming training from checkpoint: {args.checkpoint}")
307
+
308
+ if args.epochs > 0:
309
+ trainer.train()
310
+ logger.info("Training complete")
311
+
312
+ # Saves the result model to the output directory
313
+ # The reason this is important is if we configure load_best_model_at_end=True,
314
+ # the best model will be saved out of all checkpoints.
315
+ # So, even though the trainer already saves the last model as a checkpoint, that one is not necessarily the best.
316
+ if args.output_dir:
317
+ logger.info("Saving result model to the output directory")
318
+ trainer.save_model(f"{args.output_dir}/result_model")
319
+
320
+ if args.huggingface_model_name:
321
+ logger.info(f"Pushing model to the Hub: {args.huggingface_model_name}")
322
+
323
+ trainer.hub_model_id = args.huggingface_model_name
324
+
325
+ with tempfile.TemporaryDirectory() as tmp:
326
+ # Merge the PEFT weights into the base model so that we upload an independent complete model.
327
+ merged = trainer.model.merge_and_unload()
328
+ id2label = {i: name for i, name in enumerate(label_names)}
329
+ label2id = {name: i for i, name in enumerate(label_names)}
330
+
331
+ merged.config.id2label = id2label
332
+ merged.config.label2id = label2id
333
+ merged.config.pipeline_tag = "image-classification"
334
+ merged.save_pretrained(tmp, safe_serialization=True)
335
+ processor.save_pretrained(tmp)
336
+
337
+ # bundle handler and code
338
+ shutil.copy("train_model.py", tmp)
339
+ shutil.copy("handler.py", tmp)
340
+ Path(tmp, "requirements.txt").write_text("torchvision>=0.19\nPillow>=10\n")
341
+
342
+ HfApi().upload_folder(
343
+ repo_id=args.huggingface_model_name,
344
+ folder_path=tmp,
345
+ commit_message="Add merged model + processor",
346
+ token=os.environ["HUGGINGFACE_API_KEY"],
347
+ )