JJJHHHH commited on
Commit
0f763d7
·
verified ·
1 Parent(s): cfd1837

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -32,10 +32,19 @@ class ChineseClassifier(nn.Module):
32
  return x
33
 
34
  # -------- Utility Functions --------
35
- def load_labels(labels_path):
36
- with open(labels_path, "r", encoding="utf-8") as f:
37
- labels = [line.strip() for line in f if line.strip()]
38
- return labels
 
 
 
 
 
 
 
 
 
39
 
40
  def prepare_transforms():
41
  return transforms.Compose([
@@ -64,13 +73,19 @@ def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfr
64
  # -------- Globals and Setup --------
65
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
 
67
- # Load labels from JSON
68
- labels_path = "labels.txt"
69
- labels = load_labels(labels_path) # List of characters ordered by index
70
- idx_to_class = {idx: label for idx, label in enumerate(labels)}
71
- num_classes = len(labels)
 
 
 
72
  EMBED_DIM = 512
73
 
 
 
 
74
  # Download model weights from HF repo
75
  REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
76
  print("Downloading model from repo...")
 
32
  return x
33
 
34
  # -------- Utility Functions --------
35
+ def load_class_list(labels_txt_path):
36
+ """Load ordered list of classes (characters) from labels.txt"""
37
+ with open(labels_txt_path, "r", encoding="utf-8") as f:
38
+ classes = [line.strip() for line in f if line.strip()]
39
+ return classes
40
+
41
+ def load_labels_json(labels_json_path):
42
+ """Load dict mapping image filename -> character label"""
43
+ with open(labels_json_path, "r", encoding="utf-8") as f:
44
+ labels_dict = json.load(f)
45
+ # Normalize Windows-style backslash paths to slash paths
46
+ labels_dict = {k.replace("\\", "/"): v for k, v in labels_dict.items()}
47
+ return labels_dict
48
 
49
  def prepare_transforms():
50
  return transforms.Compose([
 
73
  # -------- Globals and Setup --------
74
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
 
76
+ # Paths to your label files (make sure these are accessible in your environment)
77
+ LABELS_TXT_PATH = "labels.txt" # Your class list: idx -> character
78
+ LABELS_JSON_PATH = "labels.json" # Your filename -> character mapping (optional, for evaluation)
79
+
80
+ # Load class list for prediction indexing
81
+ classes = load_class_list(LABELS_TXT_PATH)
82
+ idx_to_class = {idx: c for idx, c in enumerate(classes)}
83
+ num_classes = len(classes)
84
  EMBED_DIM = 512
85
 
86
+ # Load the labels.json if you want (not required for prediction)
87
+ # filename_to_char = load_labels_json(LABELS_JSON_PATH)
88
+
89
  # Download model weights from HF repo
90
  REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
91
  print("Downloading model from repo...")