dchen0 commited on
Commit 518728c · verified · 1 Parent(s): e72ee82

Add model with preprocessing instructions and client wrapper

Files changed (3)
  1. README.md +104 -30
  2. font_classifier_client.py +193 -0
  3. preprocessor_config.json +2 -5
README.md CHANGED
@@ -1,41 +1,115 @@
 
  ---
- base_model:
- - facebook/dinov2-base
  pipeline_tag: image-classification
- license: mit
- language:
- - en
- library_name: transformers
- tags:
- - dinov2
  ---

- # dchen0/font-classifier
- Merged DINOv2‑base checkpoint with LoRA weights for font classification.

- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
- should probably proofread and complete it, then remove this comment. -->

- ## Training procedure

- ### Training hyperparameters

- The following hyperparameters were used during training:
- - learning_rate: 1e-4
- - lora_rank: 8
- - lora_alpha: 16
- - lora_dropout: 0.1
- - train_batch_size: 32
- - eval_batch_size: 32
- - seed: 42
- - optimizer: AdamW (torch) with betas=(0.9, 0.999), epsilon=1e-08, no additional optimizer arguments
- - lr_scheduler_type: linear

- ### Framework versions

- - PEFT 0.15.2
- - Transformers 4.52.4
- - PyTorch 2.7.1
- - Datasets 3.6.0
- - Tokenizers 0.21.1
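The LoRA hyperparameters removed above (rank 8, alpha 16) imply a scaling of lora_alpha / lora_rank = 2 on the adapter update. A minimal sketch of how those numbers combine under the standard LoRA formulation (the dimensions are illustrative; DINOv2-base uses a hidden size of 768, but this is not code from the repo):

```python
import numpy as np

# Illustrative dimensions: DINOv2-base hidden size is 768, LoRA rank is 8
d, r = 768, 8
lora_alpha = 16

# Standard LoRA: the effective weight is W + (alpha / r) * B @ A,
# where A (r x d) and B (d x r) are the trained low-rank factors
A = np.zeros((r, d))
B = np.zeros((d, r))
scaling = lora_alpha / r
delta_w = scaling * (B @ A)
print(scaling, delta_w.shape)  # 2.0 (768, 768)
```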
 
+
  ---
+ license: apache-2.0
  pipeline_tag: image-classification
  ---

+ # Font Classifier DINOv2
+
+ A fine-tuned DINOv2 model for font classification, trained on Google Fonts.
+
+ ⚠️ **Critical: This model requires custom preprocessing for optimal accuracy.**
+
+ ## Performance
+
+ - **With correct preprocessing**: ~86% accuracy
+ - **Without preprocessing**: ~30% accuracy
+
+ ## Required Preprocessing
+
+ Images must be **padded to square** (preserving aspect ratio) before being resized to 224×224.
+
+ ### Option 1: Use our client wrapper (Recommended)
+
+ ```python
+ from font_classifier_client import FontClassifierClient
+
+ # For local usage
+ client = FontClassifierClient.from_local_model("dchen0/font-classifier-v4")
+ results = client.predict("your_image.png")
+
+ # For Inference Endpoints (automatically handles preprocessing)
+ client = FontClassifierClient.from_inference_endpoint("https://your-endpoint-url")
+ results = client.predict("your_image.png")
+ print(f"Predicted font: {results[0][0]} ({results[0][1]:.2%} confidence)")
+ ```
+
+ ### Option 2: Manual preprocessing
+
+ ```python
+ import torchvision.transforms as T
+ from PIL import Image
+ from transformers import pipeline
+
+ def pad_to_square(image):
+     """Pad to a centered square so the aspect ratio is preserved."""
+     w, h = image.size
+     max_size = max(w, h)
+     pad_w = (max_size - w) // 2
+     pad_h = (max_size - h) // 2
+     padding = (pad_w, pad_h, max_size - w - pad_w, max_size - h - pad_h)
+     return T.Pad(padding, fill=0)(image)
+
+ # Preprocess image
+ image = Image.open("your_image.png").convert("RGB")
+ image = pad_to_square(image)
+
+ # Use with pipeline (resizing and normalization happen inside the pipeline)
+ classifier = pipeline("image-classification", model="dchen0/font-classifier-v4")
+ results = classifier(image)
+ ```
+
+ ## Model Details
+
+ - **Base Model**: facebook/dinov2-base-imagenet1k-1-layer
+ - **Training**: LoRA fine-tuning on the Google Fonts dataset
+ - **Labels**: 394 font families
+ - **Architecture**: Vision Transformer (ViT-B/14)
+
+ ## Training Details
+
+ The model was trained with images that were:
+ 1. **Padded to square**, preserving aspect ratio
+ 2. Resized to 224×224
+ 3. Normalized with ImageNet statistics
+ 4. Augmented with various transformations
+
+ ## Usage with Inference Endpoints
+
+ When using HuggingFace Inference Endpoints:
+
+ 1. **Deploy the model** to an Inference Endpoint
+ 2. **Use the client wrapper**, which automatically handles preprocessing:
+
+ ```python
+ from font_classifier_client import FontClassifierClient
+
+ # The client handles all preprocessing automatically
+ client = FontClassifierClient.from_inference_endpoint(
+     api_url="https://your-endpoint.com",
+     api_token="your-token",  # if required
+ )
+
+ results = client.predict("test_image.png")
+ print(f"Top prediction: {results[0][0]} ({results[0][1]:.2%})")
+ ```
+
+ The client wrapper ensures that images are properly padded to square before being sent to the endpoint.
+
+ ## Files
+
+ - `font_classifier_client.py`: Client wrapper with preprocessing
+ - Standard HuggingFace model files
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```
+ @misc{font-classifier-dinov2,
+   title={Font Classifier DINOv2},
+   author={Your Name},
+   year={2024},
+   url={https://huggingface.co/dchen0/font-classifier-v4}
+ }
+ ```
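The pad-to-square rule the README requires is plain integer arithmetic, and can be checked without any imaging library. A minimal sketch of just the padding computation (the helper name `square_padding` is illustrative, not part of the repo):

```python
def square_padding(w, h):
    """Compute (left, top, right, bottom) padding that centers a w x h
    image inside a max(w, h) x max(w, h) square."""
    max_size = max(w, h)
    pad_w = (max_size - w) // 2
    pad_h = (max_size - h) // 2
    # The right/bottom edges absorb the extra pixel when the difference is odd
    return (pad_w, pad_h, max_size - w - pad_w, max_size - h - pad_h)

# A wide 300x100 text crop gets 100 px of padding above and below:
print(square_padding(300, 100))  # (0, 100, 0, 100)

# Odd differences split unevenly, with the extra pixel on the bottom:
print(square_padding(5, 2))  # (0, 1, 0, 2)
```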
font_classifier_client.py ADDED
@@ -0,0 +1,193 @@
+ """
+ Client-side wrapper for font classification with proper preprocessing.
+ Works with both local models and HuggingFace Inference Endpoints.
+ """
+ import base64
+ import io
+
+ import numpy as np
+ import requests
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from transformers import AutoImageProcessor, Dinov2ForImageClassification
+
+
+ def pad_to_square(image):
+     """
+     Pad image to square while preserving aspect ratio.
+     This is the crucial preprocessing step for font classification.
+     """
+     if isinstance(image, torch.Tensor):
+         # Convert tensor to PIL for processing
+         if image.dim() == 4:  # Batch dimension
+             image = image.squeeze(0)
+         image = T.ToPILImage()(image)
+
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     if not isinstance(image, Image.Image):
+         raise ValueError(f"Expected PIL Image, got {type(image)}")
+
+     w, h = image.size
+     max_size = max(w, h)
+     pad_w = (max_size - w) // 2
+     pad_h = (max_size - h) // 2
+     padding = (pad_w, pad_h, max_size - w - pad_w, max_size - h - pad_h)
+     return T.Pad(padding, fill=0)(image)
+
+
+ class FontClassifierClient:
+     """
+     Client for font classification that ensures correct preprocessing.
+     Works with both local models and Inference Endpoints.
+     """
+
+     def __init__(self, model_name_or_path=None, api_url=None, api_token=None):
+         """
+         Initialize font classifier client.
+
+         Args:
+             model_name_or_path: Local model path or HuggingFace model name
+             api_url: Inference Endpoint URL (alternative to local model)
+             api_token: API token for Inference Endpoints
+         """
+         self.api_url = api_url
+         self.api_token = api_token
+
+         if api_url:
+             # Using Inference Endpoint
+             self.model = None
+             self.processor = None
+             self.headers = {
+                 "Authorization": f"Bearer {api_token}",
+                 "Content-Type": "application/json",
+             } if api_token else {}
+         else:
+             # Using local model
+             self.model = Dinov2ForImageClassification.from_pretrained(model_name_or_path)
+             self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
+             self.model.eval()
+
+         # Set up preprocessing transform
+         self.preprocess_transform = T.Compose([
+             T.Lambda(lambda x: x.convert("RGB") if hasattr(x, "convert") else x),
+             pad_to_square,
+             T.Resize((224, 224)),
+             T.ToTensor(),
+             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+     def preprocess_image(self, image):
+         """Apply correct preprocessing to image."""
+         if isinstance(image, str):
+             image = Image.open(image)
+         return self.preprocess_transform(image)
+
+     def predict_local(self, image, top_k=5):
+         """Make prediction using local model."""
+         if self.model is None:
+             raise ValueError("No local model loaded")
+
+         # Preprocess image and add batch dimension
+         processed_image = self.preprocess_image(image)
+         pixel_values = processed_image.unsqueeze(0)
+
+         # Get prediction
+         with torch.no_grad():
+             outputs = self.model(pixel_values=pixel_values)
+             logits = outputs.logits
+             probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+         # Get top-k predictions
+         top_k_indices = torch.topk(logits, k=top_k).indices[0]
+         top_k_labels = [self.model.config.id2label[idx.item()] for idx in top_k_indices]
+         top_k_confidences = [probabilities[0][idx].item() for idx in top_k_indices]
+
+         return list(zip(top_k_labels, top_k_confidences))
+
+     def predict_api(self, image, top_k=5):
+         """Make prediction using Inference Endpoint API."""
+         if not self.api_url:
+             raise ValueError("No API URL provided")
+
+         # Preprocess image
+         processed_image = self.preprocess_image(image)
+
+         # Convert back to PIL for API transmission
+         processed_pil = T.ToPILImage()(processed_image)
+
+         # Convert to PNG bytes
+         img_buffer = io.BytesIO()
+         processed_pil.save(img_buffer, format="PNG")
+         img_bytes = img_buffer.getvalue()
+
+         # Encode as base64
+         img_base64 = base64.b64encode(img_bytes).decode()
+
+         # Make API request
+         payload = {
+             "inputs": img_base64,
+             "parameters": {"top_k": top_k},
+         }
+
+         response = requests.post(self.api_url, headers=self.headers, json=payload)
+         response.raise_for_status()
+
+         results = response.json()
+
+         # Format results
+         if isinstance(results, list) and len(results) > 0:
+             return [(item["label"], item["score"]) for item in results[:top_k]]
+         raise ValueError(f"Unexpected API response format: {results}")
+
+     def predict(self, image, top_k=5):
+         """
+         Make prediction with automatic backend selection.
+
+         Args:
+             image: PIL Image, file path, or numpy array
+             top_k: Number of top predictions to return
+
+         Returns:
+             List of (label, confidence) tuples
+         """
+         if self.api_url:
+             return self.predict_api(image, top_k)
+         return self.predict_local(image, top_k)
+
+     @classmethod
+     def from_local_model(cls, model_name_or_path):
+         """Create client for local model."""
+         return cls(model_name_or_path=model_name_or_path)
+
+     @classmethod
+     def from_inference_endpoint(cls, api_url, api_token=None):
+         """Create client for Inference Endpoint."""
+         return cls(api_url=api_url, api_token=api_token)
+
+
+ # Convenience functions
+ def predict_font_local(model_name, image_path, top_k=5):
+     """Quick prediction with local model."""
+     client = FontClassifierClient.from_local_model(model_name)
+     return client.predict(image_path, top_k)
+
+
+ def predict_font_api(api_url, image_path, api_token=None, top_k=5):
+     """Quick prediction with Inference Endpoint."""
+     client = FontClassifierClient.from_inference_endpoint(api_url, api_token)
+     return client.predict(image_path, top_k)
+
+
+ # Example usage:
+ if __name__ == "__main__":
+     # Local usage
+     # client = FontClassifierClient.from_local_model("dchen0/font-classifier-v4")
+     # results = client.predict("test_image.png")
+
+     # API usage
+     # client = FontClassifierClient.from_inference_endpoint("https://your-endpoint.com")
+     # results = client.predict("test_image.png")
+
+     print("Font Classifier Client ready. Use FontClassifierClient.from_local_model() or FontClassifierClient.from_inference_endpoint()")
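The request payload that `predict_api` builds can be reproduced with the standard library alone. A minimal sketch of the encoding step, using a stand-in byte string in place of real PNG data:

```python
import base64
import json

# Stand-in for the PNG bytes produced by processed_pil.save(...)
img_bytes = b"\x89PNG\r\n\x1a\nfake-image-data"

# Encode exactly as the client does before POSTing to the endpoint
img_base64 = base64.b64encode(img_bytes).decode()
payload = {"inputs": img_base64, "parameters": {"top_k": 5}}

# The endpoint side can recover the original bytes losslessly
decoded = base64.b64decode(payload["inputs"])
assert decoded == img_bytes
print(json.dumps(payload)[:60])
```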
preprocessor_config.json CHANGED
@@ -13,7 +13,7 @@
    0.456,
    0.406
  ],
- "image_processor_type": "FontClassifierImageProcessor",
+ "image_processor_type": "BitImageProcessor",
  "image_std": [
    0.229,
    0.224,
@@ -23,8 +23,5 @@
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 256
- },
- "auto_map": {
-   "AutoImageProcessor": "font_classifier_processor.FontClassifierImageProcessor"
  }
- }
+ }
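The `rescale_factor` retained in the config is 1/255, the factor that maps 8-bit pixel values into [0, 1] before ImageNet normalization. A quick check of the arithmetic:

```python
# rescale_factor from preprocessor_config.json is 1/255
rescale_factor = 0.00392156862745098
assert abs(rescale_factor - 1 / 255) < 1e-15

# An 8-bit pixel value of 255 maps to ~1.0 before ImageNet normalization
scaled = 255 * rescale_factor
normalized = (scaled - 0.485) / 0.229  # mean/std of the first channel
print(scaled, normalized)
```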