"""Interactive inference CLI for a 3-level (2 -> 4 -> 6 digit) cascaded classifier.

Loads the config, label dictionaries, tokenizer and weights from ``MODEL_DIR``,
then repeatedly reads a description from stdin, prints the top-ranked 6-digit
codes and appends the same report to ``RESULTS_PATH``.
"""
import json
import math
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MODEL_DIR = 'model'
FULL_MODEL_PATH = os.path.join(MODEL_DIR, 'cascaded_best.pt')
CONFIG_PATH = os.path.join(MODEL_DIR, 'model_config.json')
TOKENIZER_PATH = os.path.join(MODEL_DIR, 'tokenizer')
BASE_MODEL_PATH = os.path.join(MODEL_DIR, 'base_model')
DICT_2 = os.path.join(MODEL_DIR, 'label2id_2.json')
DICT_4 = os.path.join(MODEL_DIR, 'label2id_4.json')
DICT_6 = os.path.join(MODEL_DIR, 'label2id_6.json')
RESULTS_PATH = os.path.join(MODEL_DIR, 'test_results.txt')

# Shared pieces of the report table (console output and results file).
_RULE = '-' * 80
_TABLE_HEADER = f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain"

# Re-ranking constants used when scoring 6-digit candidates.
_SCORE_EPS = 1e-6          # keeps the parent-probability factors non-zero
_BOOST_SAME_4 = 10.0       # candidate shares the predicted 4-digit parent
_BOOST_SAME_2 = 5.0        # candidate shares only the predicted 2-digit parent


class ArcMarginProduct(nn.Module):
    """ArcFace-style classification layer.

    Computes the cosine similarity between the L2-normalised input and the
    L2-normalised class weights. When a ``label`` is supplied *and* the module
    is in training mode, an additive angular margin ``m`` is applied to the
    target class; otherwise the plain cosine is used. The result is always
    multiplied by the scale ``s``.

    Args:
        in_features: Size of the input feature vector.
        out_features: Number of classes.
        s: Scale applied to the output logits.
        m: Angular margin (radians) applied to the target class in training.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.30):
        super().__init__()
        self.s = s
        self.m = m
        # Allocated uninitialised, then filled by Xavier initialisation below.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cosine > th  <=>  theta < pi - m, i.e. theta + m is still below pi.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label=None):
        """Return scaled logits; the margin is applied only when training with labels."""
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        if label is None or not self.training:
            return cosine * self.s

        sine = torch.sqrt(1.0 - cosine.pow(2).clamp(0, 1))
        # cos(theta + m) via the angle-addition formula.
        phi = cosine * self.cos_m - sine * self.sin_m
        # Outside the range where theta + m < pi, fall back to cosine - mm.
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s


class CascadedClassifier(nn.Module):
    """3-level cascaded classifier: 2 → 4 → 6 with ArcFace on level 6.

    The softmax output of each coarser level is concatenated to the encoder's
    first-token embedding and fused before the next, finer level is predicted.

    Args:
        base_model: Encoder whose output exposes ``last_hidden_state``.
        hidden_size: Width of the encoder's hidden state.
        n2, n4, n6: Number of classes at the 2-, 4- and 6-digit levels.
        dropout: Dropout probability used on the embedding and in the heads.
        arc_s, arc_m: Scale and margin forwarded to ``ArcMarginProduct``.
    """

    def __init__(self, base_model, hidden_size, n2, n4, n6,
                 dropout=0.15, arc_s=30.0, arc_m=0.3):
        super().__init__()
        self.base_model = base_model
        self.drop = nn.Dropout(dropout)

        # NOTE: attribute names and layer order define the state-dict keys of
        # the saved checkpoint; do not rename or reorder them.
        self.head_2 = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, n2))

        self.head_4_fusion = nn.Linear(hidden_size + n2, hidden_size)
        self.head_4 = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, n4))

        self.head_6_fusion = nn.Linear(hidden_size + n4, hidden_size)
        self.head_6_feat = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),
            nn.GELU())
        self.head_6_arc = ArcMarginProduct(512, n6, s=arc_s, m=arc_m)

    def forward(self, input_ids, attention_mask, label_6=None):
        """Return the logits ``(l2, l4, l6)`` for the three levels.

        ``label_6`` is only forwarded to the ArcFace head, which uses it to
        apply the margin during training.
        """
        out = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Embedding of the first token of every sequence.
        cls_out = self.drop(out.last_hidden_state[:, 0, :])

        l2 = self.head_2(cls_out)
        p2 = torch.softmax(l2, dim=1)

        f4 = self.head_4_fusion(torch.cat([cls_out, p2], dim=1))
        l4 = self.head_4(f4)
        p4 = torch.softmax(l4, dim=1)

        f6 = self.head_6_fusion(torch.cat([cls_out, p4], dim=1))
        feat6 = self.head_6_feat(f6)
        l6 = self.head_6_arc(feat6, label_6)
        return l2, l4, l6


def _load_json(path):
    """Read and parse a JSON file, closing the handle afterwards."""
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


def _parent_prob(prefix, label2id, probs):
    """Return the probability of label ``prefix`` in ``probs[0]``, or 0.0 if unknown."""
    idx = label2id.get(prefix)
    if idx is None:
        return 0.0
    return probs[0][idx].item()


def _rank_candidates(p2, p4, p6, l2id_2, l2id_4, id2l_6, best_2, best_4, top_k=10):
    """Score the top 6-digit predictions and return them best-first.

    The score is ``p6**2 * sqrt(p4 + eps) * sqrt(p2 + eps)``, boosted when the
    candidate agrees with the predicted 4-digit (or, failing that, 2-digit)
    parent. ``top_k`` is clamped to the number of 6-digit classes.
    """
    k = min(top_k, p6.shape[1])
    top_p, top_i = torch.topk(p6, k, dim=1)

    candidates = []
    for prob6, idx in zip(top_p[0].tolist(), top_i[0].tolist()):
        code6 = id2l_6.get(idx, "Unk")
        pr2 = _parent_prob(code6[:2], l2id_2, p2)
        pr4 = _parent_prob(code6[:4], l2id_4, p4)

        score = (prob6 ** 2) * ((pr4 + _SCORE_EPS) ** 0.5) * ((pr2 + _SCORE_EPS) ** 0.5)
        # Guard against an empty parent label: "".startswith-style matches
        # would otherwise boost every candidate.
        if best_4 and code6.startswith(best_4):
            score *= _BOOST_SAME_4
        elif best_2 and code6[:2] == best_2:
            score *= _BOOST_SAME_2

        candidates.append(
            {"code": code6, "score": score, "p6": prob6, "p4": pr4, "p2": pr2})

    candidates.sort(key=lambda c: c["score"], reverse=True)
    return candidates


def _report_lines(candidates, limit=5):
    """Build the table (rules, header, rows, verdict) shared by console and file output."""
    lines = [_RULE, _TABLE_HEADER, _RULE]
    for rank, cand in enumerate(candidates[:limit], start=1):
        code = cand['code']
        chain = (f"{code[:2]}({cand['p2']:.2f})→{code[:4]}({cand['p4']:.2f})"
                 f"→{code[:6]}({cand['p6']:.2f})")
        lines.append(
            f"{rank:<4} | {code:<12} | {cand['score']:.2e} | {cand['p6']:.4f} | {chain}")
    lines.append(_RULE)

    # No verdict when there is nothing to judge (avoids indexing an empty list).
    if candidates:
        best = candidates[0]
        if best['score'] > 1e-3:
            lines.append("✅ Strong match.")
        elif best['p6'] < 0.1:
            lines.append("⚠️ Low confidence.")
    return lines


def save_result(filepath, text, candidates, cascade_2, cascade_4):
    """Append a single test result to the results txt file.

    Args:
        filepath: File to append to (created if missing).
        text: The user's input description.
        candidates: Ranked candidate dicts with keys ``code``, ``score``,
            ``p6``, ``p4`` and ``p2``; only the first five are written.
        cascade_2: Predicted 2-digit label.
        cascade_4: Predicted 4-digit label.
    """
    lines = [
        f"\n{'=' * 80}",
        f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Input: {text}",
        f"Cascade: {cascade_2} → {cascade_4}",
    ]
    lines.extend(_report_lines(candidates))
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Load the model artifacts and run the interactive classification loop."""
    print("Loading bert-base-uncased FULL FT + ArcFace model (3-level, 6-digit)...")

    if not os.path.exists(CONFIG_PATH):
        print(f"Config not found: {CONFIG_PATH}. Train first.")
        return

    # Top-level boundary: any failure while loading is reported and aborts.
    try:
        config = _load_json(CONFIG_PATH)
        model_name = config['model_name']
        hidden_size = config['hidden_size']
        max_seq_len = config['max_seq_len']
        counts = config['classes']
        dropout = config.get('dropout', 0.15)
        arc_s = config.get('arcface_scale', 30.0)
        arc_m = config.get('arcface_margin', 0.3)

        l2id_2 = _load_json(DICT_2)
        l2id_4 = _load_json(DICT_4)
        l2id_6 = _load_json(DICT_6)
        id2l_2 = {v: k for k, v in l2id_2.items()}
        id2l_4 = {v: k for k, v in l2id_4.items()}
        id2l_6 = {v: k for k, v in l2id_6.items()}

        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

        # Prefer the locally saved encoder; otherwise resolve it by name.
        if os.path.exists(BASE_MODEL_PATH):
            base_model = AutoModel.from_pretrained(BASE_MODEL_PATH)
        else:
            base_model = AutoModel.from_pretrained(model_name)

        model = CascadedClassifier(
            base_model=base_model,
            hidden_size=hidden_size,
            n2=counts['n2'],
            n4=counts['n4'],
            n6=counts['n6'],
            dropout=dropout,
            arc_s=arc_s,
            arc_m=arc_m,
        ).to(device)

        if os.path.exists(FULL_MODEL_PATH):
            # SECURITY NOTE: torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(FULL_MODEL_PATH, map_location=device)
            load_info = model.load_state_dict(state_dict, strict=False)
            # strict=False hides mismatches, so surface them explicitly.
            if load_info.missing_keys:
                print(f"Warning: {len(load_info.missing_keys)} parameter(s) "
                      f"missing from checkpoint: {load_info.missing_keys[:5]}")
            if load_info.unexpected_keys:
                print(f"Warning: {len(load_info.unexpected_keys)} unexpected "
                      f"key(s) in checkpoint: {load_info.unexpected_keys[:5]}")
        else:
            print(f"Warning: checkpoint not found at {FULL_MODEL_PATH}; "
                  f"classification heads are randomly initialised.")

        model.eval()
        print(f"Loaded. Best val acc: {config.get('best_val_acc_6', 'N/A')}%")
        print(f"Mode: {config.get('training_mode', 'N/A')}")
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return

    # Initialize results file
    with open(RESULTS_PATH, 'a', encoding='utf-8') as f:
        f.write(f"\n{'#'*80}\n")
        f.write(f"Test session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {config.get('model_name', 'N/A')}\n")
        f.write(f"Architecture: {config.get('architecture', 'N/A')}\n")
        f.write(f"Best val acc (6-digit): {config.get('best_val_acc_6', 'N/A')}%\n")
        f.write(f"{'#'*80}\n")

    print(f"\n📝 Results will be saved to: {RESULTS_PATH}")
    print("\n--- HS Code Classification (3-level, 6-digit) ---")
    print("Type description or 'q' to quit.\n")

    # Mixed precision only makes sense (and only avoids a warning) on CUDA.
    use_amp = device.type == 'cuda'

    while True:
        try:
            text = input("Description: ")
        except (KeyboardInterrupt, EOFError):
            break

        query = text.strip()
        if not query:
            continue
        if query.lower() in ('q', 'quit', 'exit'):
            break

        enc = tokenizer(text, max_length=max_seq_len, padding='max_length',
                        truncation=True, return_tensors='pt')
        ids = enc['input_ids'].to(device)
        mask = enc['attention_mask'].to(device)

        with torch.no_grad():
            with torch.amp.autocast(device.type, enabled=use_amp):
                o2, o4, o6 = model(ids, mask)
            # Softmax in float32 regardless of the autocast output dtype.
            p2 = F.softmax(o2.float(), dim=1)
            p4 = F.softmax(o4.float(), dim=1)
            p6 = F.softmax(o6.float(), dim=1)

        b2c = id2l_2.get(p2.argmax(dim=1).item(), "")
        b4c = id2l_4.get(p4.argmax(dim=1).item(), "")

        candidates = _rank_candidates(
            p2, p4, p6, l2id_2, l2id_4, id2l_6, b2c, b4c)

        print(f"\n Cascade: {b2c} → {b4c}")
        for line in _report_lines(candidates):
            print(line)

        # Save result to txt file
        save_result(RESULTS_PATH, text, candidates, b2c, b4c)
        print(f" 📝 Saved to {RESULTS_PATH}")


if __name__ == "__main__":
    main()