|
|
""" |
|
|
Client-side wrapper for font classification with proper preprocessing. |
|
|
Works with both local models and HuggingFace Inference Endpoints. |
|
|
""" |
|
|
import base64 |
|
|
import io |
|
|
|
|
|
import numpy as np |
|
|
import requests |
|
|
import torch |
|
|
import torchvision.transforms as T |
|
|
from PIL import Image |
|
|
from transformers import AutoImageProcessor, Dinov2ForImageClassification |
|
|
|
|
|
|
|
|
def pad_to_square(image):
    """
    Pad *image* with black borders so its width equals its height.

    Accepts a PIL Image, a numpy array, or a torch tensor (CHW, optionally
    with a leading batch dimension of 1) and always returns a square PIL
    Image. Padding rather than stretching preserves the aspect ratio,
    which is the crucial preprocessing step for font classification.
    """
    # Normalize tensor input down to a PIL Image first.
    if isinstance(image, torch.Tensor):
        if image.dim() == 4:
            image = image.squeeze(0)
        image = T.ToPILImage()(image)

    # Numpy arrays convert directly.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if not isinstance(image, Image.Image):
        raise ValueError(f"Expected PIL Image, got {type(image)}")

    width, height = image.size
    side = max(width, height)

    # Split the slack evenly; any odd pixel goes to the right/bottom edge.
    left = (side - width) // 2
    top = (side - height) // 2
    right = side - width - left
    bottom = side - height - top

    return T.Pad((left, top, right, bottom), fill=0)(image)
|
|
|
|
|
class FontClassifierClient:
    """
    Client for font classification that ensures correct preprocessing.

    Works with both local models and Inference Endpoints. Exactly one
    backend is active per instance: when ``api_url`` is supplied,
    predictions go over HTTP; otherwise a local DINOv2 classifier is
    loaded from ``model_name_or_path``.
    """

    def __init__(self, model_name_or_path=None, api_url=None, api_token=None):
        """
        Initialize font classifier client.

        Args:
            model_name_or_path: Local model path or HuggingFace model name
            api_url: Inference Endpoint URL (alternative to local model)
            api_token: API token for Inference Endpoints
        """
        self.api_url = api_url
        self.api_token = api_token

        if api_url:
            # Remote backend: no local weights are loaded at all.
            self.model = None
            self.processor = None
            # Auth headers only when a token was supplied; an empty dict
            # lets requests send its defaults.
            self.headers = {
                "Authorization": f"Bearer {api_token}",
                "Content-Type": "application/json"
            } if api_token else {}
        else:
            # Local backend. NOTE(review): if model_name_or_path is also
            # None this fails inside from_pretrained() with a library error
            # rather than a clear message — consider validating earlier.
            self.model = Dinov2ForImageClassification.from_pretrained(model_name_or_path)
            # NOTE(review): loaded but unused — predict_local() relies on
            # self.preprocess_transform instead of this processor.
            self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
            self.model.eval()

        # Canonical pipeline: force RGB -> pad to square -> 224x224 resize
        # -> tensor -> ImageNet mean/std normalization.
        self.preprocess_transform = T.Compose([
            T.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            pad_to_square,
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def preprocess_image(self, image):
        """Apply correct preprocessing to image.

        Accepts a file path (opened via PIL) or anything the transform
        pipeline handles (PIL Image, numpy array, tensor); returns a
        normalized 3x224x224 float tensor.
        """
        if isinstance(image, str):
            image = Image.open(image)
        return self.preprocess_transform(image)

    def predict_local(self, image, top_k: int = 5):
        """Make prediction using local model.

        Args:
            image: PIL Image, file path, numpy array, or tensor.
            top_k: Number of top predictions to return.

        Returns:
            List of (label, confidence) tuples, best first.

        Raises:
            ValueError: If this client was constructed for an API endpoint
                (no local model loaded).
        """
        if self.model is None:
            raise ValueError("No local model loaded")

        processed_image = self.preprocess_image(image)
        # Add the batch dimension the model expects: (1, 3, 224, 224).
        pixel_values = processed_image.unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)

        # Top-k taken over raw logits; softmax is monotonic, so the ranking
        # matches the probabilities reported as confidences below.
        top_k_indices = torch.topk(logits, k=top_k).indices[0]
        top_k_labels = [self.model.config.id2label[idx.item()] for idx in top_k_indices]
        top_k_confidences = [probabilities[0][idx].item() for idx in top_k_indices]

        return list(zip(top_k_labels, top_k_confidences))

    def predict_api(self, image, top_k: int = 5):
        """Make prediction using Inference Endpoint API.

        The image is preprocessed locally, re-encoded as a base64 PNG, and
        POSTed to the configured endpoint.

        Args:
            image: PIL Image, file path, numpy array, or tensor.
            top_k: Number of top predictions to return (also forwarded to
                the endpoint as a parameter).

        Returns:
            List of (label, score) tuples, best first.

        Raises:
            ValueError: If no API URL is configured, or if the response
                is not a non-empty list of predictions.
            requests.HTTPError: On non-2xx responses (raise_for_status).
        """
        if not self.api_url:
            raise ValueError("No API URL provided")

        processed_image = self.preprocess_image(image)

        # NOTE(review): processed_image is already ImageNet-normalized, so
        # its values fall outside [0, 1]; ToPILImage will clip them when
        # converting to an 8-bit image. Confirm the endpoint expects this
        # (versus a raw image it preprocesses itself) — otherwise the
        # server receives a distorted picture.
        processed_pil = T.ToPILImage()(processed_image)

        # Serialize to an in-memory PNG.
        img_buffer = io.BytesIO()
        processed_pil.save(img_buffer, format='PNG')
        img_bytes = img_buffer.getvalue()

        img_base64 = base64.b64encode(img_bytes).decode()

        payload = {
            "inputs": img_base64,
            "parameters": {"top_k": top_k}
        }

        response = requests.post(self.api_url, headers=self.headers, json=payload)
        response.raise_for_status()

        results = response.json()

        # Expected response shape: [{"label": ..., "score": ...}, ...],
        # best first; truncate defensively to top_k.
        if isinstance(results, list) and len(results) > 0:
            predictions = [(item["label"], item["score"]) for item in results[:top_k]]
            return predictions
        else:
            raise ValueError(f"Unexpected API response format: {results}")

    def predict(self, image, top_k: int = 5):
        """
        Make prediction with automatic backend selection.

        Dispatches to the API backend when an endpoint URL was configured,
        otherwise to the local model.

        Args:
            image: PIL Image, file path, or numpy array
            top_k: Number of top predictions to return

        Returns:
            List of (label, confidence) tuples
        """
        if self.api_url:
            return self.predict_api(image, top_k)
        else:
            return self.predict_local(image, top_k)

    @classmethod
    def from_local_model(cls, model_name_or_path):
        """Create client for local model."""
        return cls(model_name_or_path=model_name_or_path)

    @classmethod
    def from_inference_endpoint(cls, api_url, api_token=None):
        """Create client for Inference Endpoint."""
        return cls(api_url=api_url, api_token=api_token)
|
|
|
|
|
|
|
|
def predict_font_local(model_name, image_path, top_k=5):
    """One-shot convenience wrapper: load a local model and classify one image."""
    return FontClassifierClient.from_local_model(model_name).predict(image_path, top_k)
|
|
|
|
|
def predict_font_api(api_url, image_path, api_token=None, top_k=5):
    """One-shot convenience wrapper: classify one image via an Inference Endpoint."""
    endpoint_client = FontClassifierClient.from_inference_endpoint(api_url, api_token)
    return endpoint_client.predict(image_path, top_k)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # No demo run here: exercising either backend needs external resources
    # (model weights or a live endpoint), so just point the user at the
    # factory constructors.
    print("Font Classifier Client ready. Use FontClassifierClient.from_local_model() or FontClassifierClient.from_inference_endpoint()")