File size: 6,911 Bytes
518728c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
Client-side wrapper for font classification with proper preprocessing.
Works with both local models and HuggingFace Inference Endpoints.
"""
import base64
import io

import numpy as np
import requests
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoImageProcessor, Dinov2ForImageClassification


def pad_to_square(image, fill=0):
    """
    Pad an image to a square canvas, centering the original content.

    The longer side is kept and the shorter side is padded on both ends
    (when the difference is odd, the extra pixel goes to the right/bottom),
    so the content keeps its aspect ratio instead of being stretched by a
    later resize. This is the crucial preprocessing step for font
    classification.

    Args:
        image: PIL Image, numpy array (anything ``Image.fromarray`` accepts),
            or a torch tensor accepted by ``ToPILImage``; a 4-D tensor must
            have a batch dimension of exactly 1.
        fill: Fill value for the padded border, passed to ``T.Pad``.
            Defaults to 0 (black), matching the previous behavior.

    Returns:
        PIL Image whose width equals its height.

    Raises:
        ValueError: If ``image`` is not a supported type, or is a batched
            tensor holding more than one image.
    """
    if isinstance(image, torch.Tensor):
        if image.dim() == 4:
            # squeeze(0) is a no-op for batch > 1, which would otherwise
            # surface later as an obscure ToPILImage dimension error.
            if image.shape[0] != 1:
                raise ValueError(
                    f"Expected a single image, got a batch of {image.shape[0]}"
                )
            image = image.squeeze(0)
        image = T.ToPILImage()(image)

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if not isinstance(image, Image.Image):
        raise ValueError(f"Expected PIL Image, got {type(image)}")

    width, height = image.size
    side = max(width, height)
    left = (side - width) // 2
    top = (side - height) // 2
    # T.Pad takes a 4-tuple as (left, top, right, bottom).
    padding = (left, top, side - width - left, side - height - top)
    return T.Pad(padding, fill=fill)(image)

class FontClassifierClient:
    """
    Client for font classification that ensures correct preprocessing.

    Works either with a local ``Dinov2ForImageClassification`` checkpoint or
    with a HuggingFace Inference Endpoint. In both modes the image is
    converted to RGB, padded to a square (preserving glyph aspect ratio) and
    resized to 224x224 on the client side.
    """

    def __init__(self, model_name_or_path=None, api_url=None, api_token=None, timeout=60):
        """
        Initialize font classifier client.

        Args:
            model_name_or_path: Local model path or HuggingFace model name.
                Ignored when ``api_url`` is given.
            api_url: Inference Endpoint URL (alternative to local model).
            api_token: API token for Inference Endpoints (optional).
            timeout: Seconds to wait for the Inference Endpoint before
                ``requests`` raises; ``None`` waits indefinitely.

        Raises:
            ValueError: If neither ``api_url`` nor ``model_name_or_path``
                is provided.
        """
        self.api_url = api_url
        self.api_token = api_token
        self.timeout = timeout

        if api_url:
            # Using Inference Endpoint
            self.model = None
            self.processor = None
            self.headers = {
                "Authorization": f"Bearer {api_token}",
                "Content-Type": "application/json",
            } if api_token else {}
        else:
            # Using local model
            if model_name_or_path is None:
                raise ValueError("Provide either model_name_or_path or api_url")
            self.model = Dinov2ForImageClassification.from_pretrained(model_name_or_path)
            self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
            self.model.eval()
            # Always defined so predict_api fails on the URL check, not on a
            # missing attribute.
            self.headers = {}

        # Geometric preprocessing only. The output is still a PIL image, so
        # it can be PNG-encoded for the endpoint without loss.
        self.pil_transform = T.Compose([
            T.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            pad_to_square,
            T.Resize((224, 224)),
        ])
        # Full preprocessing for the local model: the geometric steps, then
        # tensor conversion and ImageNet mean/std normalization.
        self.preprocess_transform = T.Compose([
            self.pil_transform,
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _as_pil(image):
        """Convert numpy arrays / tensors to PIL so the RGB conversion applies to them."""
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        if isinstance(image, torch.Tensor):
            if image.dim() == 4:  # Batch dimension
                image = image.squeeze(0)
            return T.ToPILImage()(image)
        return image

    def _run_transform(self, transform, image):
        """Apply ``transform`` to an image given as a path, PIL image, array or tensor."""
        if isinstance(image, str):
            # Close the file handle deterministically. The RGB conversion
            # inside the transform reads the pixels before the file closes.
            with Image.open(image) as opened:
                return transform(opened)
        return transform(self._as_pil(image))

    def preprocess_image(self, image):
        """
        Apply correct preprocessing to image.

        Args:
            image: File path, PIL Image, numpy array or torch tensor.

        Returns:
            Normalized float tensor produced by ``preprocess_transform``.
        """
        return self._run_transform(self.preprocess_transform, image)

    def predict_local(self, image, top_k=5):
        """
        Make prediction using local model.

        Args:
            image: File path, PIL Image, numpy array or torch tensor.
            top_k: Maximum number of predictions to return; clamped to the
                number of classes.

        Returns:
            List of (label, probability) tuples, highest probability first.

        Raises:
            ValueError: If no local model is loaded.
        """
        if self.model is None:
            raise ValueError("No local model loaded")

        pixel_values = self.preprocess_image(image).unsqueeze(0)  # Add batch dimension

        with torch.no_grad():
            logits = self.model(pixel_values=pixel_values).logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)

        # torch.topk raises if k exceeds the number of classes.
        k = min(top_k, logits.shape[-1])
        top_k_indices = torch.topk(logits[0], k=k).indices
        id2label = self.model.config.id2label
        return [(id2label[idx.item()], probabilities[0][idx].item()) for idx in top_k_indices]

    def predict_api(self, image, top_k=5):
        """
        Make prediction using Inference Endpoint API.

        The image is padded/resized client-side and sent as a base64 PNG.

        Args:
            image: File path, PIL Image, numpy array or torch tensor.
            top_k: Number of predictions requested and returned.

        Returns:
            List of (label, score) tuples as returned by the endpoint.

        Raises:
            ValueError: If no API URL is configured or the response is not
                a non-empty list.
            requests.HTTPError: On a non-2xx response.
        """
        if not self.api_url:
            raise ValueError("No API URL provided")

        # Send the geometric-only PIL output. ToPILImage on the normalized
        # tensor would corrupt it: its values lie outside [0, 1], and
        # to_pil_image scales by 255 and casts to uint8. Normalization is
        # presumably done by the endpoint's own processor — confirm.
        processed_pil = self._run_transform(self.pil_transform, image)

        img_buffer = io.BytesIO()
        processed_pil.save(img_buffer, format='PNG')
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode()

        payload = {
            "inputs": img_base64,
            "parameters": {"top_k": top_k},
        }

        response = requests.post(
            self.api_url, headers=self.headers, json=payload, timeout=self.timeout
        )
        response.raise_for_status()

        results = response.json()

        if isinstance(results, list) and results:
            return [(item["label"], item["score"]) for item in results[:top_k]]
        raise ValueError(f"Unexpected API response format: {results}")

    def predict(self, image, top_k=5):
        """
        Make prediction with automatic backend selection.

        Args:
            image: PIL Image, file path, numpy array or torch tensor
            top_k: Number of top predictions to return

        Returns:
            List of (label, confidence) tuples
        """
        if self.api_url:
            return self.predict_api(image, top_k)
        return self.predict_local(image, top_k)

    @classmethod
    def from_local_model(cls, model_name_or_path):
        """Create client for local model."""
        return cls(model_name_or_path=model_name_or_path)

    @classmethod
    def from_inference_endpoint(cls, api_url, api_token=None, timeout=60):
        """Create client for Inference Endpoint (``timeout`` in seconds, None = wait forever)."""
        return cls(api_url=api_url, api_token=api_token, timeout=timeout)

# Convenience functions
def predict_font_local(model_name, image_path, top_k=5):
    """
    One-shot font prediction with a local model.

    Builds a fresh FontClassifierClient, which loads the model on every call,
    and returns its top-``top_k`` (label, confidence) pairs for ``image_path``.
    """
    return FontClassifierClient.from_local_model(model_name).predict(image_path, top_k)

def predict_font_api(api_url, image_path, api_token=None, top_k=5):
    """
    One-shot font prediction through an Inference Endpoint.

    Creates a throwaway endpoint client and returns its top-``top_k``
    (label, confidence) pairs for ``image_path``.
    """
    endpoint_client = FontClassifierClient.from_inference_endpoint(
        api_url=api_url, api_token=api_token
    )
    return endpoint_client.predict(image_path, top_k=top_k)

# Script entry point: prints a short hint; see the examples below.
if __name__ == "__main__":
    # Usage examples:
    #
    #   Local model:
    #     client = FontClassifierClient.from_local_model("dchen0/font-classifier-v4")
    #     results = client.predict("test_image.png")
    #
    #   Inference Endpoint:
    #     client = FontClassifierClient.from_inference_endpoint("https://your-endpoint.com")
    #     results = client.predict("test_image.png")
    ready_message = "Font Classifier Client ready. Use FontClassifierClient.from_local_model() or FontClassifierClient.from_inference_endpoint()"
    print(ready_message)