JJJHHHH committed on
Commit
3b4987c
·
verified ·
1 Parent(s): 0f763d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -32,19 +32,16 @@ class ChineseClassifier(nn.Module):
32
  return x
33
 
34
  # -------- Utility Functions --------
35
- def load_class_list(labels_txt_path):
36
- """Load ordered list of classes (characters) from labels.txt"""
37
- with open(labels_txt_path, "r", encoding="utf-8") as f:
38
- classes = [line.strip() for line in f if line.strip()]
39
- return classes
40
 
41
  def load_labels_json(labels_json_path):
42
- """Load dict mapping image filename -> character label"""
43
  with open(labels_json_path, "r", encoding="utf-8") as f:
44
  labels_dict = json.load(f)
45
- # Normalize Windows-style backslash paths to slash paths
46
- labels_dict = {k.replace("\\", "/"): v for k, v in labels_dict.items()}
47
- return labels_dict
48
 
49
  def prepare_transforms():
50
  return transforms.Compose([
@@ -70,35 +67,35 @@ def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfr
70
  model.eval()
71
  return model
72
 
73
- # -------- Globals and Setup --------
74
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
75
 
76
- # Paths to your label files (make sure these are accessible in your environment)
77
- LABELS_TXT_PATH = "labels.txt" # Your class list: idx -> character
78
- LABELS_JSON_PATH = "labels.json" # Your filename -> character mapping (optional, for evaluation)
79
-
80
- # Load class list for prediction indexing
81
- classes = load_class_list(LABELS_TXT_PATH)
82
  idx_to_class = {idx: c for idx, c in enumerate(classes)}
83
  num_classes = len(classes)
84
- EMBED_DIM = 512
85
 
86
- # Load the labels.json if you want (not required for prediction)
87
- # filename_to_char = load_labels_json(LABELS_JSON_PATH)
 
88
 
89
- # Download model weights from HF repo
90
  REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
91
  print("Downloading model from repo...")
92
  repo_dir = snapshot_download(repo_id=REPO_ID)
93
  model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
94
- print("Model path:", model_path)
95
 
96
- # Load model and transforms
97
  model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
98
  transform = prepare_transforms()
99
 
100
  # -------- Prediction Function --------
101
  def predict(pil_img):
 
102
  img_t = transform(pil_img).unsqueeze(0).to(DEVICE)
103
  with torch.no_grad():
104
  output = model(img_t)
@@ -112,5 +109,5 @@ gr.Interface(
112
  inputs=gr.Image(type="pil", label="Upload Handwritten Chinese Character"),
113
  outputs=gr.Text(label="Predicted Character"),
114
  title="Chinese Character Recognition",
115
- description="Upload an image of a handwritten Chinese character and get the predicted character."
116
- ).launch(share=True)
 
32
  return x
33
 
34
  # -------- Utility Functions --------
35
def get_sorted_classes(labels_dict):
    """Return the distinct values of ``labels_dict`` in ascending order.

    Args:
        labels_dict: Mapping whose values are the class labels.

    Returns:
        A new list holding each distinct value once, sorted ascending.
    """
    # Collapse duplicates first, then order the survivors.
    distinct_labels = set(labels_dict.values())
    ordered_labels = list(distinct_labels)
    ordered_labels.sort()
    return ordered_labels
 
 
38
 
39
  def load_labels_json(labels_json_path):
40
+ """Load and normalize labels JSON"""
41
  with open(labels_json_path, "r", encoding="utf-8") as f:
42
  labels_dict = json.load(f)
43
+ # Normalize paths and remove directory prefixes
44
+ return {os.path.basename(k).replace("\\", "/"): v for k, v in labels_dict.items()}
 
45
 
46
  def prepare_transforms():
47
  return transforms.Compose([
 
67
  model.eval()
68
  return model
69
 
70
+ # -------- Setup --------
71
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+ EMBED_DIM = 512
73
+ LABELS_JSON_PATH = "labels.json"
74
 
75
+ # 1. Load labels and extract sorted classes
76
+ labels_dict = load_labels_json(LABELS_JSON_PATH)
77
+ classes = get_sorted_classes(labels_dict)
 
 
 
78
  idx_to_class = {idx: c for idx, c in enumerate(classes)}
79
  num_classes = len(classes)
 
80
 
81
+ # Print the class count and a sample so they can be compared by hand with the training setup
82
+ print(f"Loaded {num_classes} classes")
83
+ print(f"First 5 classes: {classes[:5]}")
84
 
85
+ # 2. Download model
86
  REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
87
  print("Downloading model from repo...")
88
  repo_dir = snapshot_download(repo_id=REPO_ID)
89
  model_path = os.path.join(repo_dir, "CCR_EthicalSplit_Finetune.pth")
90
+ print(f"Model path: {model_path}")
91
 
92
+ # 3. Load model
93
  model = load_model(model_path, EMBED_DIM, num_classes, DEVICE)
94
  transform = prepare_transforms()
95
 
96
  # -------- Prediction Function --------
97
  def predict(pil_img):
98
+ """Predict character from PIL image"""
99
  img_t = transform(pil_img).unsqueeze(0).to(DEVICE)
100
  with torch.no_grad():
101
  output = model(img_t)
 
109
  inputs=gr.Image(type="pil", label="Upload Handwritten Chinese Character"),
110
  outputs=gr.Text(label="Predicted Character"),
111
  title="Chinese Character Recognition",
112
+ description="Recognizes handwritten Chinese characters with 80% accuracy",
113
+ ).launch()