ENTUM-AI committed on
Commit
5645844
·
verified ·
1 Parent(s): 0eaf29a

Initial upload: HS Code Classifier (English, 6-digit)

Browse files
Files changed (1) hide show
  1. inference.py +254 -0
inference.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from datetime import datetime
6
+ import json
7
+ import os
8
+ import math
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ MODEL_DIR = 'models_4090_eng_6digit'
13
+ FULL_MODEL_PATH = os.path.join(MODEL_DIR, 'cascaded_best.pt')
14
+ CONFIG_PATH = os.path.join(MODEL_DIR, 'model_config.json')
15
+ TOKENIZER_PATH = os.path.join(MODEL_DIR, 'tokenizer')
16
+ BASE_MODEL_PATH = os.path.join(MODEL_DIR, 'base_model')
17
+
18
+ DICT_2 = os.path.join(MODEL_DIR, 'label2id_2.json')
19
+ DICT_4 = os.path.join(MODEL_DIR, 'label2id_4.json')
20
+ DICT_6 = os.path.join(MODEL_DIR, 'label2id_6.json')
21
+
22
+ RESULTS_PATH = os.path.join(MODEL_DIR, 'test_results.txt')
23
+
24
+
25
+ class ArcMarginProduct(nn.Module):
26
+ """ArcFace classifier (inference mode: no margin, just cosine * scale)."""
27
+ def __init__(self, in_features, out_features, s=30.0, m=0.30):
28
+ super().__init__()
29
+ self.s = s
30
+ self.m = m
31
+ self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
32
+ nn.init.xavier_uniform_(self.weight)
33
+ self.cos_m = math.cos(m)
34
+ self.sin_m = math.sin(m)
35
+ self.th = math.cos(math.pi - m)
36
+ self.mm = math.sin(math.pi - m) * m
37
+
38
+ def forward(self, x, label=None):
39
+ cosine = F.linear(F.normalize(x), F.normalize(self.weight))
40
+ if label is not None and self.training:
41
+ sine = torch.sqrt(1.0 - cosine.pow(2).clamp(0, 1))
42
+ phi = cosine * self.cos_m - sine * self.sin_m
43
+ phi = torch.where(cosine > self.th, phi, cosine - self.mm)
44
+ one_hot = torch.zeros_like(cosine)
45
+ one_hot.scatter_(1, label.view(-1, 1).long(), 1)
46
+ output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
47
+ return output * self.s
48
+ return cosine * self.s
49
+
50
+
51
class CascadedClassifier(nn.Module):
    """Hierarchical HS-code classifier: 2-digit -> 4-digit -> 6-digit.

    Predicts the 2-digit chapter from the [CLS] embedding, then the 4-digit
    heading conditioned on the chapter probabilities, then the 6-digit
    subheading conditioned on the heading probabilities. The final level
    uses an ArcFace margin head.
    """

    def __init__(self, base_model, hidden_size, n2, n4, n6,
                 dropout=0.15, arc_s=30.0, arc_m=0.3):
        super().__init__()
        self.base_model = base_model
        self.drop = nn.Dropout(dropout)

        # Level 1: 2-digit chapter head straight off the pooled embedding.
        self.head_2 = nn.Sequential(
            nn.Linear(hidden_size, 256), nn.LayerNorm(256), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(256, n2))

        # Level 2: 4-digit head, fused with the level-1 probability vector.
        self.head_4_fusion = nn.Linear(hidden_size + n2, hidden_size)
        self.head_4 = nn.Sequential(
            nn.LayerNorm(hidden_size), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_size, 256), nn.GELU(), nn.Linear(256, n4))

        # Level 3: 6-digit ArcFace head, fused with the level-2 probabilities.
        self.head_6_fusion = nn.Linear(hidden_size + n4, hidden_size)
        self.head_6_feat = nn.Sequential(
            nn.LayerNorm(hidden_size), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_size, 512), nn.GELU())
        self.head_6_arc = ArcMarginProduct(512, n6, s=arc_s, m=arc_m)

    def forward(self, input_ids, attention_mask, label_6=None):
        encoder_out = self.base_model(input_ids=input_ids,
                                      attention_mask=attention_mask)
        # [CLS] token embedding as the pooled sentence representation.
        pooled = self.drop(encoder_out.last_hidden_state[:, 0, :])

        logits_2 = self.head_2(pooled)
        probs_2 = torch.softmax(logits_2, dim=1)

        fused_4 = self.head_4_fusion(torch.cat([pooled, probs_2], dim=1))
        logits_4 = self.head_4(fused_4)
        probs_4 = torch.softmax(logits_4, dim=1)

        fused_6 = self.head_6_fusion(torch.cat([pooled, probs_4], dim=1))
        features_6 = self.head_6_feat(fused_6)
        logits_6 = self.head_6_arc(features_6, label_6)

        return logits_2, logits_4, logits_6
87
+
88
+
89
def save_result(filepath, text, candidates, cascade_2, cascade_4):
    """Append one classification result (top-5 table + verdict) to *filepath*."""
    bar = '-' * 80
    lines = [
        f"\n{'=' * 80}\n",
        f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
        f"Input: {text}\n",
        f"Cascade: {cascade_2} → {cascade_4}\n",
        f"{bar}\n",
        f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain\n",
        f"{bar}\n",
    ]
    for rank, cand in enumerate(candidates[:5], start=1):
        code = cand['code']
        chain = (f"{code[:2]}({cand['p2']:.2f})→{code[:4]}({cand['p4']:.2f})"
                 f"→{code[:6]}({cand['p6']:.2f})")
        lines.append(f"{rank:<4} | {code:<12} | {cand['score']:.2e} "
                     f"| {cand['p6']:.4f} | {chain}\n")
    lines.append(f"{bar}\n")
    # One-line verdict mirroring what the console prints.
    if candidates[0]['score'] > 1e-3:
        lines.append("✅ Strong match.\n")
    elif candidates[0]['p6'] < 0.1:
        lines.append("⚠️ Low confidence.\n")
    with open(filepath, 'a', encoding='utf-8') as f:
        f.writelines(lines)
108
+
109
+
110
def main():
    """Interactive HS-code classification loop.

    Loads the trained cascaded model plus its tokenizer and label
    dictionaries, then repeatedly reads a product description from stdin,
    prints the top-5 ranked 6-digit HS codes and appends each result to
    RESULTS_PATH. Type 'q' / 'quit' / 'exit' (or Ctrl-C / Ctrl-D) to stop.
    """
    print("Loading bert-base-uncased FULL FT + ArcFace model (3-level, 6-digit)...")

    if not os.path.exists(CONFIG_PATH):
        print(f"Config not found: {CONFIG_PATH}. Train first.")
        return

    try:
        # Context managers close the file handles promptly; the prior
        # json.load(open(...)) pattern left descriptors to the GC.
        with open(CONFIG_PATH, encoding='utf-8') as f:
            config = json.load(f)
        model_name = config['model_name']
        hidden_size = config['hidden_size']
        max_seq_len = config['max_seq_len']
        counts = config['classes']
        dropout = config.get('dropout', 0.15)
        arc_s = config.get('arcface_scale', 30.0)
        arc_m = config.get('arcface_margin', 0.3)

        # label -> id dictionaries for each cascade level ...
        with open(DICT_2, encoding='utf-8') as f:
            l2id_2 = json.load(f)
        with open(DICT_4, encoding='utf-8') as f:
            l2id_4 = json.load(f)
        with open(DICT_6, encoding='utf-8') as f:
            l2id_6 = json.load(f)

        # ... and their id -> label inverses for decoding predictions.
        id2l_2 = {v: k for k, v in l2id_2.items()}
        id2l_4 = {v: k for k, v in l2id_4.items()}
        id2l_6 = {v: k for k, v in l2id_6.items()}

        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

        # Prefer a locally saved encoder; fall back to downloading by name.
        if os.path.exists(BASE_MODEL_PATH):
            base_model = AutoModel.from_pretrained(BASE_MODEL_PATH)
        else:
            base_model = AutoModel.from_pretrained(model_name)

        model = CascadedClassifier(
            base_model=base_model, hidden_size=hidden_size,
            n2=counts['n2'], n4=counts['n4'], n6=counts['n6'],
            dropout=dropout, arc_s=arc_s, arc_m=arc_m
        ).to(device)

        if os.path.exists(FULL_MODEL_PATH):
            state_dict = torch.load(FULL_MODEL_PATH, map_location=device)
            # strict=False tolerates harmless key mismatches between
            # checkpoint and current module layout.
            model.load_state_dict(state_dict, strict=False)

        model.eval()
        print(f"Loaded. Best val acc: {config.get('best_val_acc_6', 'N/A')}%")
        print(f"Mode: {config.get('training_mode', 'N/A')}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return

    # Start a new session section in the results file.
    with open(RESULTS_PATH, 'a', encoding='utf-8') as f:
        f.write(f"\n{'#'*80}\n")
        f.write(f"Test session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {config.get('model_name', 'N/A')}\n")
        f.write(f"Architecture: {config.get('architecture', 'N/A')}\n")
        f.write(f"Best val acc (6-digit): {config.get('best_val_acc_6', 'N/A')}%\n")
        f.write(f"{'#'*80}\n")

    print(f"\n📝 Results will be saved to: {RESULTS_PATH}")
    print("\n--- HS Code Classification (3-level, 6-digit) ---")
    print("Type description or 'q' to quit.\n")

    # Only enable autocast when actually running on CUDA; the original
    # entered the CUDA autocast context unconditionally, which is broken
    # or at best a warning-generating no-op on CPU-only machines.
    use_amp = device.type == 'cuda'

    while True:
        try:
            text = input("Description: ")
        except (KeyboardInterrupt, EOFError):
            break
        if not text.strip():            # blank line: just prompt again
            continue
        if text.lower() in ('q', 'quit', 'exit'):
            break

        enc = tokenizer(text, max_length=max_seq_len, padding='max_length',
                        truncation=True, return_tensors='pt')
        ids = enc['input_ids'].to(device)
        mask = enc['attention_mask'].to(device)

        with torch.no_grad():
            with torch.amp.autocast('cuda', enabled=use_amp):
                o2, o4, o6 = model(ids, mask)

        p2 = F.softmax(o2, dim=1)
        p4 = F.softmax(o4, dim=1)
        p6 = F.softmax(o6, dim=1)

        # Argmax prediction at the two coarse levels (empty string if the
        # index is somehow absent from the dictionary).
        b2c = id2l_2.get(torch.argmax(p2, dim=1).item(), "")
        b4c = id2l_4.get(torch.argmax(p4, dim=1).item(), "")

        candidates = _rank_candidates(p2, p4, p6, l2id_2, l2id_4, id2l_6,
                                      b2c, b4c)

        print(f"\n Cascade: {b2c} → {b4c}")
        print("-" * 80)
        print(f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain")
        print("-" * 80)
        for i in range(min(5, len(candidates))):
            c = candidates[i]
            cd = c["code"]
            ch = f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})→{cd[:6]}({c['p6']:.2f})"
            print(f"{i+1:<4} | {cd:<12} | {c['score']:.2e} | {c['p6']:.4f} | {ch}")
        print("-" * 80)

        if candidates[0]['score'] > 1e-3:
            print("✅ Strong match.")
        elif candidates[0]['p6'] < 0.1:
            print("⚠️ Low confidence.")

        # Persist the result to the session txt file as well.
        save_result(RESULTS_PATH, text, candidates, b2c, b4c)
        print(f" 📝 Saved to {RESULTS_PATH}")


def _rank_candidates(p2, p4, p6, l2id_2, l2id_4, id2l_6, b2c, b4c):
    """Score the top-k 6-digit candidates by blending all three cascade levels.

    Parameters:
        p2, p4, p6: softmaxed probability tensors of shape (1, n_level).
        l2id_2, l2id_4: label-string -> class-id dicts for levels 2 and 4.
        id2l_6: class-id -> label-string dict for level 6.
        b2c, b4c: argmax label strings at levels 2 and 4 (consistency bonus).

    Returns a list of candidate dicts sorted by descending combined score.
    """
    # Guard against 6-digit label sets smaller than 10 (topk would raise).
    k = min(10, p6.size(1))
    top_p, top_i = torch.topk(p6, k, dim=1)

    eps = 1e-6
    candidates = []
    for j in range(k):
        idx = top_i[0][j].item()
        prob6 = top_p[0][j].item()
        code6 = id2l_6.get(idx, "Unk")

        # Direct label -> id lookup. The original defined a closure that
        # linearly scanned the inverse id -> label dict for every candidate,
        # which is O(n) per lookup for no benefit.
        id2 = l2id_2.get(code6[:2])
        id4 = l2id_4.get(code6[:4])
        pr2 = p2[0][id2].item() if id2 is not None else 0.0
        pr4 = p4[0][id4].item() if id4 is not None else 0.0

        # Geometric-style blend, weighted toward the 6-digit probability.
        score = (prob6 ** 2) * ((pr4 + eps) ** 0.5) * ((pr2 + eps) ** 0.5)
        # Bonus when the candidate agrees with the argmax coarse predictions.
        if code6.startswith(b4c):
            score *= 10.0
        elif code6[:2] == b2c:
            score *= 5.0

        candidates.append({"code": code6, "score": score, "p6": prob6,
                           "p4": pr4, "p2": pr2})

    candidates.sort(key=lambda x: x["score"], reverse=True)
    return candidates
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()