Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from huggingface_hub import snapshot_download | |
| import gradio as gr | |
# -------- Model Definition --------
class ChineseClassifier(nn.Module):
    """ResNet-50 backbone followed by an embedding head and a linear classifier.

    torchvision's ResNet-50 is used without its final fully connected layer.
    The backbone output is flattened, projected to ``embed_dim``, passed
    through BatchNorm1d and Dropout(p=0.3), and finally mapped to
    ``num_classes`` logits.

    Args:
        embed_dim: Width of the embedding produced before the classifier.
        num_classes: Number of output logits.
        pretrainedEncoder: If True, the backbone starts from
            ``ResNet50_Weights.DEFAULT``; otherwise it is randomly initialized.
        unfreezeEncoder: Value assigned to ``requires_grad`` on every backbone
            parameter (False freezes the backbone).
    """

    def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
        super().__init__()
        chosen_weights = models.ResNet50_Weights.DEFAULT if pretrainedEncoder else None
        backbone = models.resnet50(weights=chosen_weights)

        # Keep every child module except the last one (the ImageNet fc layer).
        trunk_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk_layers)
        for weight in self.resnet.parameters():
            weight.requires_grad = unfreezeEncoder

        feature_width = backbone.fc.in_features
        self.fc = nn.Linear(feature_width, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_embedding=False):
        """Run a batch of images through the network.

        Args:
            x: Image batch; presumably (N, 3, H, W) as produced by this file's
                transform pipeline — TODO confirm for other callers.
            return_embedding: If True, return the embedding (after batch norm
                and dropout) instead of the class logits.

        Returns:
            Tensor of shape (N, embed_dim) when ``return_embedding`` is True,
            otherwise (N, num_classes).
        """
        features = torch.flatten(self.resnet(x), 1)
        embedding = self.dropout(self.batch_norm(self.fc(features)))
        if return_embedding:
            return embedding
        return self.classifier(embedding)
# -------- Utility Functions --------
def get_sorted_classes(labels_dict):
    """Return the distinct label values of *labels_dict* as an ascending list.

    Python's default ordering is used (code-point order for strings), so the
    position of a label in the result can serve as its class index.
    """
    distinct_labels = {label for label in labels_dict.values()}
    ordered = sorted(distinct_labels)
    return ordered
def load_labels_json(labels_json_path):
    """Load a labels JSON file and key it by bare file name.

    The file must contain a single JSON object mapping image paths to labels.
    Each key is normalized by converting Windows ``\\`` separators to ``/``
    and then keeping only the final path component. As a result,
    ``"train\\a\\001.png"`` and ``"train/a/001.png"`` both become
    ``"001.png"`` on every platform.

    Args:
        labels_json_path: Path of the UTF-8 encoded JSON file.

    Returns:
        dict mapping file name to label, in first-seen key order.

    Note:
        When several paths share a file name, the last entry wins (as in a
        plain dict). A warning is printed if that overwrote a *different*
        label, because the overwritten label may then be missing from the
        class list derived from this dict's values.
    """
    with open(labels_json_path, "r", encoding="utf-8") as f:
        raw_labels = json.load(f)

    normalized = {}
    overwritten = 0
    for path, label in raw_labels.items():
        # Convert separators *before* taking the basename: on POSIX,
        # os.path.basename does not split on "\\", so doing it the other way
        # round left Windows-style directory prefixes in the key.
        name = os.path.basename(path.replace("\\", "/"))
        if name in normalized and normalized[name] != label:
            overwritten += 1
        normalized[name] = label

    if overwritten:
        print(
            f"Warning: {overwritten} label entries were overwritten by entries "
            f"sharing the same file name; some classes may be missing."
        )
    return normalized
def prepare_transforms():
    """Build the preprocessing pipeline applied to every input image.

    Steps:
    1. Resize to a fixed 224x224; the aspect ratio is not preserved.
    2. Convert the PIL image to a float tensor (values in [0, 1] for 8-bit images).
    3. Normalize each channel with the ImageNet mean/std constants below.

    The statistics have three entries, so the input image should be RGB.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def _strip_module_prefix(state_dict):
    """Drop a leading ``module.`` from all keys when every key carries it.

    nn.DataParallel / DistributedDataParallel register the wrapped model as
    ``.module``, so checkpoints saved from the wrapper have prefixed keys that
    would otherwise match nothing in an unwrapped model.
    """
    prefix = "module."
    if isinstance(state_dict, dict) and state_dict and all(k.startswith(prefix) for k in state_dict):
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
    """Build a ChineseClassifier and load trained weights from a checkpoint.

    Args:
        model_path: Path of a file written by ``torch.save``. It holds either a
            raw state dict or a dict with the weights under ``"model_state_dict"``.
        embed_dim: Embedding width; must match the checkpoint.
        num_classes: Number of outputs of the classifier head.
        device: Device the model is created on and the checkpoint is mapped to.
        pretrained: Forwarded as ``pretrainedEncoder``. When True, torchvision's
            default ResNet-50 weights are fetched first and then overwritten by
            any matching checkpoint weights.
        unfreeze: Forwarded as ``unfreezeEncoder``.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        RuntimeError: If a raw state dict does not match the architecture, or
            if a non-classifier tensor has the wrong shape.
    """
    model = ChineseClassifier(embed_dim, num_classes, pretrainedEncoder=pretrained, unfreezeEncoder=unfreeze).to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    if "model_state_dict" in checkpoint:
        state_dict = _strip_module_prefix(checkpoint["model_state_dict"])
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            # Best-effort fallback, presumably for a classifier head whose size
            # differs from num_classes. The head keeps its random initialization,
            # so predictions are not meaningful in this case.
            print("Warning:", e)
            print("Loading partial weights, skipping classifier layer...")
            filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("classifier.")}
            result = model.load_state_dict(filtered_state_dict, strict=False)
            # strict=False silently ignores unmatched names; surface them so a
            # load that matched (almost) nothing does not go unnoticed.
            missing = [k for k in result.missing_keys if not k.startswith("classifier.")]
            if missing:
                print(f"Warning: {len(missing)} non-classifier weights missing from checkpoint: {missing}")
            if result.unexpected_keys:
                print(f"Warning: {len(result.unexpected_keys)} unexpected checkpoint keys ignored: {result.unexpected_keys}")
    else:
        model.load_state_dict(_strip_module_prefix(checkpoint))
    model.eval()
    return model
# -------- Setup --------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 512
LABELS_JSON_PATH = "labels.json"
MODEL_FILENAME = "CCR_EthicalSplit_Finetune.pth"

# 1. Load labels and extract sorted classes. Class index i is the i-th label in
#    sorted order; presumably this must match the ordering used in training.
labels_dict = load_labels_json(LABELS_JSON_PATH)
classes = get_sorted_classes(labels_dict)
if not classes:
    # Fail at startup rather than building a 0-way head that errors on every request.
    raise ValueError(f"No labels found in {LABELS_JSON_PATH!r}; cannot build the classifier head.")
idx_to_class = dict(enumerate(classes))
num_classes = len(classes)
# The class count is only printed, not checked: compare it with the training run.
# A mismatch with the checkpoint's head makes load_model skip the classifier weights.
print(f"Loaded {num_classes} classes")
print(f"First 5 classes: {classes[:5]}")

# 2. Download model (only the checkpoint file, not the whole repo)
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
print("Downloading model from repo...")
repo_dir = snapshot_download(repo_id=REPO_ID, allow_patterns=[MODEL_FILENAME])
model_path = os.path.join(repo_dir, MODEL_FILENAME)
print(f"Model path: {model_path}")

# 3. Load model
model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
transform = prepare_transforms()
# -------- Prediction Function --------
def predict(pil_img):
    """Predict the character shown in an uploaded image.

    Args:
        pil_img: A PIL image, or None when the Gradio form is submitted
            without an upload.

    Returns:
        The predicted label value from labels.json, or a short message
        asking for an image when ``pil_img`` is None.
    """
    if pil_img is None:
        # Gradio passes None when Submit is pressed with an empty image input.
        return "Please upload an image."
    # The transform normalizes three channels and the ResNet stem expects
    # 3-channel input, so grayscale/RGBA/palette images are converted first.
    img_t = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
    pred_idx = output.argmax(dim=1).item()
    return idx_to_class[pred_idx]
# -------- Gradio Interface --------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Handwritten Chinese Character"),
    outputs=gr.Text(label="Predicted Character"),
    title="Chinese Character Recognition",
    description="Recognizes handwritten Chinese characters with 80% accuracy",
)

# Launch only when executed as a script (presumably how the hosting Space runs
# this file — confirm). Importing the module, e.g. via the `gradio` reload-mode
# CLI which looks for a variable named `demo`, does not start a second server.
if __name__ == "__main__":
    demo.launch()