Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -32,10 +32,19 @@ class ChineseClassifier(nn.Module):
|
|
| 32 |
return x
|
| 33 |
|
| 34 |
# -------- Utility Functions --------
|
| 35 |
-
def
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def prepare_transforms():
|
| 41 |
return transforms.Compose([
|
|
@@ -64,13 +73,19 @@ def load_model(model_path, embed_dim, num_classes, device, pretrained=True, unfr
|
|
| 64 |
# -------- Globals and Setup --------
|
| 65 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 66 |
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
EMBED_DIM = 512
|
| 73 |
|
|
|
|
|
|
|
|
|
|
| 74 |
# Download model weights from HF repo
|
| 75 |
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
|
| 76 |
print("Downloading model from repo...")
|
|
|
|
| 32 |
return x
|
| 33 |
|
| 34 |
# -------- Utility Functions --------
|
| 35 |
+
def load_class_list(labels_txt_path):
|
| 36 |
+
"""Load ordered list of classes (characters) from labels.txt"""
|
| 37 |
+
with open(labels_txt_path, "r", encoding="utf-8") as f:
|
| 38 |
+
classes = [line.strip() for line in f if line.strip()]
|
| 39 |
+
return classes
|
| 40 |
+
|
| 41 |
+
def load_labels_json(labels_json_path):
|
| 42 |
+
"""Load dict mapping image filename -> character label"""
|
| 43 |
+
with open(labels_json_path, "r", encoding="utf-8") as f:
|
| 44 |
+
labels_dict = json.load(f)
|
| 45 |
+
# Normalize Windows-style backslash paths to slash paths
|
| 46 |
+
labels_dict = {k.replace("\\", "/"): v for k, v in labels_dict.items()}
|
| 47 |
+
return labels_dict
|
| 48 |
|
| 49 |
def prepare_transforms():
|
| 50 |
return transforms.Compose([
|
|
|
|
| 73 |
# -------- Globals and Setup --------
|
| 74 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 75 |
|
| 76 |
+
# Paths to your label files (make sure these are accessible in your environment)
|
| 77 |
+
LABELS_TXT_PATH = "labels.txt" # Your class list: idx -> character
|
| 78 |
+
LABELS_JSON_PATH = "labels.json" # Your filename -> character mapping (optional, for evaluation)
|
| 79 |
+
|
| 80 |
+
# Load class list for prediction indexing
|
| 81 |
+
classes = load_class_list(LABELS_TXT_PATH)
|
| 82 |
+
idx_to_class = {idx: c for idx, c in enumerate(classes)}
|
| 83 |
+
num_classes = len(classes)
|
| 84 |
EMBED_DIM = 512
|
| 85 |
|
| 86 |
+
# Load the labels.json if you want (not required for prediction)
|
| 87 |
+
# filename_to_char = load_labels_json(LABELS_JSON_PATH)
|
| 88 |
+
|
| 89 |
# Download model weights from HF repo
|
| 90 |
REPO_ID = "JJJHHHH/CCR_EthicalSplit_Finetune"
|
| 91 |
print("Downloading model from repo...")
|