Spaces:
Sleeping
Sleeping
| import json | |
| import joblib | |
| import numpy as np | |
| from PIL import Image | |
class LRModel:
    """
    Inference pipeline for a Logistic Regression model trained on
    flattened grayscale images (64x64 by default).

    Attributes:
        model: estimator loaded with ``joblib.load``; must provide
            ``predict_proba``.
        labels: mapping from integer class id to class name.
        image_size: side length (pixels) images are resized to before
            flattening.
    """

    def __init__(self, model_path: str, labels_path: str, image_size: int = 64):
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only load model files from a trusted source.
        self.model = joblib.load(model_path)
        self.labels = self._load_labels(labels_path)
        self.image_size = image_size

    def _load_labels(self, labels_path: str) -> dict:
        """
        Read the class-id -> class-name mapping from a JSON file.

        Accepts either a JSON object (``{"0": "cat", "1": "dog"}``) or a
        JSON list (``["cat", "dog"]``, interpreted as index -> name).

        Returns:
            dict mapping int class id to class name.

        Raises:
            ValueError: if a JSON object key is not an integer string.
        """
        # JSON text is UTF-8 by spec; don't depend on the platform default.
        with open(labels_path, "r", encoding="utf-8") as f:
            raw = json.load(f)
        if isinstance(raw, list):
            return dict(enumerate(raw))
        # JSON object keys are always strings; convert them to int indices.
        return {int(k): v for k, v in raw.items()}

    def preprocess(self, image: "Image.Image") -> np.ndarray:
        """
        Convert a PIL image into a single model input row (matches training):

        - Resize to (image_size, image_size)
        - Convert to grayscale (mode "L")
        - Scale pixel values to [0, 1] as float32
        - Flatten to shape (1, image_size * image_size)
        """
        img = image.resize((self.image_size, self.image_size))
        img = img.convert("L")  # grayscale
        arr = np.array(img, dtype=np.float32) / 255.0
        return arr.reshape(1, -1)

    def _column_class_ids(self, n_columns: int) -> list:
        """
        Return the integer class id for each ``predict_proba`` column.

        scikit-learn orders ``predict_proba`` columns by ``model.classes_``.
        When the model exposes integer ``classes_`` of matching length those
        ids are used (so non-contiguous labels map correctly); otherwise the
        column position itself is used as the id.
        """
        classes = getattr(self.model, "classes_", None)
        if classes is not None:
            classes = np.asarray(classes)
            if classes.shape == (n_columns,) and np.issubdtype(classes.dtype, np.integer):
                return [int(c) for c in classes]
        return list(range(n_columns))

    def predict(self, image: "Image.Image", top_k: int = 5):
        """
        Classify one image.

        Args:
            image: input PIL image (any mode/size; see ``preprocess``).
            top_k: number of ranked predictions to return; clamped to
                [0, number of classes].

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }

        Raises:
            KeyError: if a predicted class id has no entry in ``self.labels``.
        """
        x = self.preprocess(image)
        probs = np.asarray(self.model.predict_proba(x)[0])
        class_ids = self._column_class_ids(probs.shape[0])

        pred_col = int(np.argmax(probs))
        pred_id = class_ids[pred_col]
        pred_name = self.labels[pred_id]

        prob_dict = {
            self.labels[cid]: float(p) for cid, p in zip(class_ids, probs)
        }

        # Descending order. The stable sort keeps lower column indices first
        # on ties, so top_k[0] always agrees with np.argmax above.
        order = np.argsort(-probs, kind="stable")
        k = max(0, min(top_k, len(order)))
        top_k_list = [
            {
                "class_id": class_ids[col],
                "class_name": self.labels[class_ids[col]],
                "probability": float(probs[col]),
            }
            for col in order[:k]
        ]

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }