| """ |
| Client-side wrapper for font classification with proper preprocessing. |
| Works with both local models and HuggingFace Inference Endpoints. |
| """ |
| import base64 |
| import io |
|
|
| import numpy as np |
| import requests |
| import torch |
| import torchvision.transforms as T |
| from PIL import Image |
| from transformers import AutoImageProcessor, Dinov2ForImageClassification |
|
|
|
|
def pad_to_square(image, fill=0):
    """
    Pad an image to a square canvas, centring the original content.

    The shorter side is padded on both ends (any odd remainder goes to the
    right/bottom edge), so a subsequent square resize does not stretch glyphs.

    Args:
        image: PIL Image, numpy array (anything ``Image.fromarray`` accepts),
            or a ``torch.Tensor`` accepted by ``T.ToPILImage``, optionally
            with a leading batch dimension of size 1.
        fill: Fill value for the padded border, passed to ``T.Pad``
            (default 0, i.e. black).

    Returns:
        PIL Image whose width equals its height.

    Raises:
        ValueError: If *image* is not a supported type, or is a 4-D tensor
            whose batch dimension is not 1.
    """
    if isinstance(image, torch.Tensor):
        if image.dim() == 4:
            # Only a single-image batch can be turned into one PIL image.
            if image.shape[0] != 1:
                raise ValueError(
                    f"Expected a single image, got a batch of {image.shape[0]}"
                )
            image = image.squeeze(0)
        image = T.ToPILImage()(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if not isinstance(image, Image.Image):
        raise ValueError(
            f"Expected PIL Image, numpy array or torch.Tensor, got {type(image)}"
        )

    width, height = image.size
    side = max(width, height)
    left = (side - width) // 2
    top = (side - height) // 2
    # A 4-element padding for T.Pad is (left, top, right, bottom).
    padding = (left, top, side - width - left, side - height - top)
    return T.Pad(padding, fill=fill)(image)
|
|
class FontClassifierClient:
    """
    Client for font classification that ensures correct preprocessing.

    Works with either a local ``Dinov2ForImageClassification`` checkpoint or a
    HuggingFace Inference Endpoint. In both cases the input is converted to
    RGB and padded to a square before resizing, so glyph aspect ratios are not
    distorted by the resize.
    """

    # Spatial size (height, width) the image is resized to before inference.
    IMAGE_SIZE = (224, 224)
    # Per-channel normalization statistics used by the local pipeline.
    NORMALIZE_MEAN = [0.485, 0.456, 0.406]
    NORMALIZE_STD = [0.229, 0.224, 0.225]

    def __init__(self, model_name_or_path=None, api_url=None, api_token=None,
                 timeout=60):
        """
        Initialize font classifier client.

        Args:
            model_name_or_path: Local model path or HuggingFace model name.
                Required when ``api_url`` is not given.
            api_url: Inference Endpoint URL (alternative to a local model).
                When set, no local model is loaded.
            api_token: Optional bearer token for the Inference Endpoint.
            timeout: Seconds to wait on the endpoint before ``requests``
                raises a timeout error; ``None`` waits indefinitely.

        Raises:
            ValueError: If neither ``model_name_or_path`` nor ``api_url`` is
                provided.
        """
        self.api_url = api_url
        self.api_token = api_token
        self.timeout = timeout

        if api_url:
            self.model = None
            self.processor = None
            self.headers = {
                "Authorization": f"Bearer {api_token}",
                "Content-Type": "application/json",
            } if api_token else {}
        else:
            if model_name_or_path is None:
                raise ValueError(
                    "Provide either model_name_or_path (local model) or "
                    "api_url (Inference Endpoint)"
                )
            self.model = Dinov2ForImageClassification.from_pretrained(model_name_or_path)
            # Kept as a public attribute; this class's own preprocessing uses
            # preprocess_transform rather than the processor.
            self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
            self.model.eval()

        # Geometry-only steps (pad, then resize). The API path uses this so the
        # endpoint receives exactly the framing the local model sees.
        self.geometry_transform = T.Compose([
            pad_to_square,
            T.Resize(self.IMAGE_SIZE),
        ])
        # Full local pipeline: RGB -> pad -> resize -> tensor -> normalize.
        self.preprocess_transform = T.Compose([
            T.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            pad_to_square,
            T.Resize(self.IMAGE_SIZE),
            T.ToTensor(),
            T.Normalize(mean=self.NORMALIZE_MEAN, std=self.NORMALIZE_STD),
        ])

    @staticmethod
    def _to_rgb_pil(image):
        """
        Coerce a supported input into an RGB PIL image.

        Accepts a file path (str), PIL Image, numpy array, or torch.Tensor.
        File paths are opened in a ``with`` block and decoded before the handle
        is closed. Converting every input type to RGB here ensures grayscale or
        RGBA arrays/tensors reach the 3-channel ``Normalize`` step with 3
        channels.

        Raises:
            ValueError: If *image* is not a supported type.
        """
        if isinstance(image, str):
            with Image.open(image) as opened:
                # convert() decodes the pixels, so the result outlives the file.
                return opened.convert("RGB")
        if isinstance(image, torch.Tensor):
            if image.dim() == 4:
                image = image.squeeze(0)
            image = T.ToPILImage()(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if not isinstance(image, Image.Image):
            raise ValueError(
                f"Expected file path, PIL Image, numpy array or torch.Tensor, "
                f"got {type(image)}"
            )
        return image.convert("RGB")

    def preprocess_image(self, image):
        """
        Apply the full local preprocessing pipeline to *image*.

        Returns:
            Normalized float tensor of shape ``(3, *IMAGE_SIZE)``.
        """
        return self.preprocess_transform(self._to_rgb_pil(image))

    def predict_local(self, image, top_k=5):
        """
        Make a prediction using the local model.

        Args:
            image: File path, PIL Image, numpy array, or torch.Tensor.
            top_k: Maximum number of predictions to return; clamped to the
                number of classes the model has.

        Returns:
            List of ``(label, probability)`` tuples, most likely first.

        Raises:
            ValueError: If no local model is loaded.
        """
        if self.model is None:
            raise ValueError("No local model loaded")

        pixel_values = self.preprocess_image(image).unsqueeze(0)  # add batch dim

        with torch.no_grad():
            logits = self.model(pixel_values=pixel_values).logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k=k)
        id2label = self.model.config.id2label
        return [
            (id2label[idx.item()], prob.item())
            for prob, idx in zip(top_probs, top_indices)
        ]

    def predict_api(self, image, top_k=5):
        """
        Make a prediction using the Inference Endpoint API.

        The image is converted to RGB, padded to square, and resized, then sent
        as a base64-encoded PNG. Normalization is left to the endpoint.

        Args:
            image: File path, PIL Image, numpy array, or torch.Tensor.
            top_k: Number of predictions requested from and returned by the
                endpoint.

        Returns:
            List of ``(label, score)`` tuples as reported by the endpoint.

        Raises:
            ValueError: If no API URL is configured or the response does not
                contain a non-empty list of predictions.
            requests.HTTPError: If the endpoint returns an error status.
        """
        if not self.api_url:
            raise ValueError("No API URL provided")

        # Do NOT send the normalized tensor. Its values lie outside [0, 1], so
        # ToPILImage's mul(255)->uint8 conversion would corrupt the pixels.
        processed_pil = self.geometry_transform(self._to_rgb_pil(image))

        img_buffer = io.BytesIO()
        processed_pil.save(img_buffer, format='PNG')
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode()

        payload = {
            "inputs": img_base64,
            "parameters": {"top_k": top_k},
        }

        response = requests.post(
            self.api_url, headers=self.headers, json=payload, timeout=self.timeout
        )
        response.raise_for_status()
        results = response.json()

        # Tolerate a batched [[{...}, ...]] response by unwrapping the first item.
        if isinstance(results, list) and results and isinstance(results[0], list):
            results = results[0]

        if isinstance(results, list) and len(results) > 0:
            return [(item["label"], item["score"]) for item in results[:top_k]]
        raise ValueError(f"Unexpected API response format: {results}")

    def predict(self, image, top_k=5):
        """
        Make a prediction with automatic backend selection.

        Uses the Inference Endpoint when ``api_url`` was provided, otherwise
        the local model.

        Args:
            image: PIL Image, file path, numpy array, or torch.Tensor
            top_k: Number of top predictions to return

        Returns:
            List of (label, confidence) tuples
        """
        if self.api_url:
            return self.predict_api(image, top_k)
        return self.predict_local(image, top_k)

    @classmethod
    def from_local_model(cls, model_name_or_path):
        """Create a client backed by a local model."""
        return cls(model_name_or_path=model_name_or_path)

    @classmethod
    def from_inference_endpoint(cls, api_url, api_token=None, timeout=60):
        """Create a client backed by an Inference Endpoint."""
        return cls(api_url=api_url, api_token=api_token, timeout=timeout)
|
|
| |
def predict_font_local(model_name, image_path, top_k=5):
    """
    One-shot convenience: classify *image_path* with a locally loaded model.

    Each call builds a new client, which loads the model again. For repeated
    predictions, create a ``FontClassifierClient`` once and reuse it.
    """
    return FontClassifierClient.from_local_model(model_name).predict(image_path, top_k)
|
|
def predict_font_api(api_url, image_path, api_token=None, top_k=5):
    """
    One-shot convenience: classify *image_path* via an Inference Endpoint.

    A new endpoint client is constructed for every call.
    """
    endpoint_client = FontClassifierClient.from_inference_endpoint(api_url, api_token)
    predictions = endpoint_client.predict(image_path, top_k)
    return predictions
|
|
| |
if __name__ == "__main__":
    # Running the module directly only prints usage guidance; no model is
    # loaded and no network request is made.
    print(
        "Font Classifier Client ready. "
        "Use FontClassifierClient.from_local_model() or "
        "FontClassifierClient.from_inference_endpoint()"
    )