JJJHHHH committed on
Commit
faac63e
·
verified ·
1 Parent(s): 585a1e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -35
app.py CHANGED
@@ -1,18 +1,17 @@
1
  import os
 
 
2
  import torch
3
  import torch.nn as nn
4
- from PIL import Image
5
  from torchvision import models, transforms
6
  from huggingface_hub import snapshot_download
7
  import gradio as gr
8
 
9
- # --------------------------
10
- # Model Architecture
11
- # --------------------------
12
  class ChineseClassifier(nn.Module):
13
  def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
14
  super().__init__()
15
- resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT if pretrainedEncoder else None)
16
  self.resnet = nn.Sequential(*list(resnet.children())[:-1])
17
  for param in self.resnet.parameters():
18
  param.requires_grad = unfreezeEncoder
@@ -29,14 +28,15 @@ class ChineseClassifier(nn.Module):
29
  x = self.dropout(x)
30
  if return_embedding:
31
  return x
32
- return self.classifier(x)
 
33
 
34
- # --------------------------
35
- # Helper Functions
36
- # --------------------------
37
- def load_labels(path):
38
- with open(path, "r", encoding="utf-8") as f:
39
- return [line.strip() for line in f if line.strip()]
40
 
41
  def prepare_transforms():
42
  return transforms.Compose([
@@ -46,41 +46,56 @@ def prepare_transforms():
46
  std=[0.229, 0.224, 0.225]),
47
  ])
48
 
49
- def load_model(path, embed_dim, num_classes, device):
50
- model = ChineseClassifier(embed_dim, num_classes).to(device)
51
- model.load_state_dict(torch.load(path, map_location=device))
 
 
 
 
 
 
 
 
 
 
52
  model.eval()
53
  return model
54
 
55
- # --------------------------
56
- # Init
57
- # --------------------------
 
 
 
 
 
 
 
 
 
 
 
58
  REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
 
59
  repo_dir = snapshot_download(repo_id=REPO_ID)
60
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
-
62
- labels_path = os.path.join(repo_dir, "labels.txt")
63
  model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
 
64
 
65
- class_names = load_labels(labels_path)
 
66
  transform = prepare_transforms()
67
- model = load_model(model_path, embed_dim=512, num_classes=len(class_names), device=device)
68
 
69
- # --------------------------
70
- # Inference
71
- # --------------------------
72
- def predict(image: Image.Image):
73
- image = image.convert("RGB")
74
- input_tensor = transform(image).unsqueeze(0).to(device)
75
  with torch.no_grad():
76
- output = model(input_tensor)
77
  pred_idx = output.argmax(dim=1).item()
78
- pred_label = class_names[pred_idx]
79
- return f"Prediction: {pred_label}"
80
 
81
- # --------------------------
82
- # Gradio UI
83
- # --------------------------
84
  gr.Interface(
85
  fn=predict,
86
  inputs=gr.Image(type="pil", label="Upload Image"),
 
1
  import os
2
+ import json
3
+ from PIL import Image
4
  import torch
5
  import torch.nn as nn
 
6
  from torchvision import models, transforms
7
  from huggingface_hub import snapshot_download
8
  import gradio as gr
9
 
10
+ # -------- Model Definition --------
 
 
11
  class ChineseClassifier(nn.Module):
12
  def __init__(self, embed_dim, num_classes, pretrainedEncoder=True, unfreezeEncoder=True):
13
  super().__init__()
14
+ resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) if pretrainedEncoder else models.resnet50()
15
  self.resnet = nn.Sequential(*list(resnet.children())[:-1])
16
  for param in self.resnet.parameters():
17
  param.requires_grad = unfreezeEncoder
 
28
  x = self.dropout(x)
29
  if return_embedding:
30
  return x
31
+ x = self.classifier(x)
32
+ return x
33
 
34
+ # -------- Utility Functions --------
35
def load_labels(labels_path):
    """Read the labels file and return its parsed JSON content.

    The file is expected to be JSON, e.g. a mapping of
    filename -> label. Raises json.JSONDecodeError if the
    file is not valid JSON.
    """
    with open(labels_path, "r", encoding="utf-8") as fh:
        return json.load(fh)
40
 
41
  def prepare_transforms():
42
  return transforms.Compose([
 
46
  std=[0.229, 0.224, 0.225]),
47
  ])
48
 
49
def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfreeze=True):
    """Build a ChineseClassifier and restore weights from a checkpoint.

    Supports two checkpoint layouts: a dict wrapping the weights under
    "model_state_dict", or a bare state dict. When the wrapped weights do
    not fit (e.g. classifier head shape mismatch), falls back to loading
    everything except the classifier layer with strict=False.

    Returns the model in eval() mode, moved to `device`.
    """
    net = ChineseClassifier(
        embed_dim, num_classes,
        pretrainedEncoder=pretrained, unfreezeEncoder=unfreeze,
    ).to(device)
    state = torch.load(model_path, map_location=device)

    if "model_state_dict" not in state:
        # Bare state-dict checkpoint.
        net.load_state_dict(state)
    else:
        weights = state["model_state_dict"]
        try:
            net.load_state_dict(weights)
        except RuntimeError as e:
            # Best-effort partial restore: drop the classifier head weights.
            print("Warning:", e)
            print("Loading partial weights, skipping classifier layer...")
            kept = {k: v for k, v in weights.items() if not k.startswith("classifier.")}
            net.load_state_dict(kept, strict=False)

    net.eval()
    return net
64
 
65
+ # -------- Globals and Setup --------
66
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
+
68
+ # Load labels locally from Space repo root
69
+ labels_path = "labels.txt"
70
+ labels_dict = load_labels(labels_path)
71
+ # Build an alphabetically sorted list of unique labels (assuming labels_dict maps filename -> label)
72
+ classes = sorted(set(labels_dict.values()))
73
+ class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
74
+ idx_to_class = {v: k for k, v in class_to_idx.items()}
75
+ num_classes = len(classes)
76
+ EMBED_DIM = 512
77
+
78
+ # Download model weights from HF repo
79
  REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
80
+ print("Downloading model from HF repo...")
81
  repo_dir = snapshot_download(repo_id=REPO_ID)
 
 
 
82
  model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
83
+ print("Model path:", model_path)
84
 
85
+ # Prepare model and transforms
86
+ model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
87
  transform = prepare_transforms()
 
88
 
89
+ # -------- Prediction Function --------
90
def predict(pil_img):
    """Classify an uploaded image and return the predicted class label.

    Args:
        pil_img: a PIL.Image.Image as supplied by gr.Image(type="pil").

    Returns:
        The predicted class label string.
    """
    # Force 3 channels: Gradio may supply RGBA or grayscale images, which
    # would break the 3-channel Normalize in the transform pipeline.
    # (The previous version of this file performed this conversion; the
    # refactor dropped it.)
    rgb = pil_img.convert("RGB")
    img_t = transform(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(img_t)
    pred_idx = output.argmax(dim=1).item()
    pred_label = idx_to_class[pred_idx]
    return pred_label
97
 
98
+ # -------- Gradio Interface --------
 
 
99
  gr.Interface(
100
  fn=predict,
101
  inputs=gr.Image(type="pil", label="Upload Image"),