| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import AutoTokenizer, AutoModel |
| | from datetime import datetime |
| | import json |
| | import os |
| | import math |
| |
|
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# All model artifacts are stored under this directory.
MODEL_DIR = 'model'
FULL_MODEL_PATH = os.path.join(MODEL_DIR, 'cascaded_best.pt')   # best checkpoint (state_dict)
CONFIG_PATH = os.path.join(MODEL_DIR, 'model_config.json')      # training-time configuration
TOKENIZER_PATH = os.path.join(MODEL_DIR, 'tokenizer')           # saved tokenizer directory
BASE_MODEL_PATH = os.path.join(MODEL_DIR, 'base_model')         # saved backbone directory

# label -> id dictionaries for the 2-, 4- and 6-digit cascade levels.
DICT_2 = os.path.join(MODEL_DIR, 'label2id_2.json')
DICT_4 = os.path.join(MODEL_DIR, 'label2id_4.json')
DICT_6 = os.path.join(MODEL_DIR, 'label2id_6.json')

# Interactive test sessions are appended to this log file.
RESULTS_PATH = os.path.join(MODEL_DIR, 'test_results.txt')
| |
|
| |
|
class ArcMarginProduct(nn.Module):
    """ArcFace margin head.

    In eval mode (or when ``label`` is None) it returns plain scaled cosine
    similarities; in training mode with labels it applies the additive
    angular margin ``m`` to the target class before scaling by ``s``.
    (The original docstring claimed "no margin" unconditionally, which was
    wrong for the training path.)

    Args:
        in_features: dimensionality of the input embedding.
        out_features: number of classes.
        s: logit scale factor.
        m: additive angular margin in radians.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.30):
        super().__init__()
        self.s = s
        self.m = m
        # Class-center weight matrix; rows are compared to inputs by cosine.
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        # Precomputed constants for cos(theta + m) and its numeric fallback.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label=None):
        """Return class logits of shape (batch, out_features)."""
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        if label is not None and self.training:
            # cos(theta + m) = cos*cos_m - sin*sin_m; clamp keeps sqrt stable.
            sine = torch.sqrt(1.0 - cosine.pow(2).clamp(0, 1))
            phi = cosine * self.cos_m - sine * self.sin_m
            # Where theta + m would pass pi, fall back to a linear penalty.
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            one_hot = torch.zeros_like(cosine)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
            # Margin only on the target class; others keep raw cosine.
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
            return output * self.s
        return cosine * self.s
| |
|
| |
|
class CascadedClassifier(nn.Module):
    """Three-level cascaded classifier: 2 -> 4 -> 6 digits.

    Each level's softmax distribution is concatenated with the [CLS]
    embedding and fused back in before the next, finer-grained head.
    The final 6-digit head classifies through an ArcFace margin layer.
    """

    def __init__(self, base_model, hidden_size, n2, n4, n6,
                 dropout=0.15, arc_s=30.0, arc_m=0.3):
        super().__init__()
        self.base_model = base_model
        self.drop = nn.Dropout(dropout)

        # Level 1: 2-digit head straight off the [CLS] vector.
        self.head_2 = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, n2),
        )

        # Level 2: fuse [CLS] with the 2-digit distribution, then classify.
        self.head_4_fusion = nn.Linear(hidden_size + n2, hidden_size)
        self.head_4 = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, n4),
        )

        # Level 3: fuse with the 4-digit distribution, embed to 512 dims,
        # then score with the ArcFace head.
        self.head_6_fusion = nn.Linear(hidden_size + n4, hidden_size)
        self.head_6_feat = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),
            nn.GELU(),
        )
        self.head_6_arc = ArcMarginProduct(512, n6, s=arc_s, m=arc_m)

    def forward(self, input_ids, attention_mask, label_6=None):
        """Return the tuple (logits_2, logits_4, logits_6)."""
        encoded = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.drop(encoded.last_hidden_state[:, 0, :])

        logits_2 = self.head_2(pooled)
        probs_2 = F.softmax(logits_2, dim=1)

        fused_4 = self.head_4_fusion(torch.cat([pooled, probs_2], dim=1))
        logits_4 = self.head_4(fused_4)
        probs_4 = F.softmax(logits_4, dim=1)

        fused_6 = self.head_6_fusion(torch.cat([pooled, probs_4], dim=1))
        logits_6 = self.head_6_arc(self.head_6_feat(fused_6), label_6)
        return logits_2, logits_4, logits_6
| |
|
| |
|
def save_result(filepath, text, candidates, cascade_2, cascade_4):
    """Append a single test result to the results txt file.

    Args:
        filepath: results file path (opened in append mode, UTF-8).
        text: raw input description that was classified.
        candidates: ranked candidate dicts (best first) with keys
            'code', 'score', 'p6', 'p4', 'p2'; only the top 5 are written.
        cascade_2: best 2-digit code from the cascade head.
        cascade_4: best 4-digit code from the cascade head.
    """
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write(f"\n{'='*80}\n")
        f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Input: {text}\n")
        f.write(f"Cascade: {cascade_2} → {cascade_4}\n")
        f.write(f"{'-'*80}\n")
        f.write(f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain\n")
        f.write(f"{'-'*80}\n")
        for i, c in enumerate(candidates[:5]):
            cd = c['code']
            ch = f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})→{cd[:6]}({c['p6']:.2f})"
            f.write(f"{i+1:<4} | {cd:<12} | {c['score']:.2e} | {c['p6']:.4f} | {ch}\n")
        f.write(f"{'-'*80}\n")
        # BUG FIX: guard the verdict lines — candidates may be empty, and
        # indexing [0] unconditionally raised IndexError.
        if candidates:
            if candidates[0]['score'] > 1e-3:
                f.write("✅ Strong match.\n")
            elif candidates[0]['p6'] < 0.1:
                f.write("⚠️ Low confidence.\n")
| |
|
| |
|
def main():
    """Interactive HS-code classification CLI.

    Loads the cascaded classifier (config, label dictionaries, tokenizer,
    backbone, checkpoint), then loops reading descriptions from stdin,
    printing the top candidates and appending each result to RESULTS_PATH.
    """
    print("Loading bert-base-uncased FULL FT + ArcFace model (3-level, 6-digit)...")

    if not os.path.exists(CONFIG_PATH):
        print(f"Config not found: {CONFIG_PATH}. Train first.")
        return

    try:
        # Training-time configuration (model name, sizes, ArcFace params).
        # BUG FIX: the original used `json.load(open(...))`, leaking file
        # handles; context managers close them deterministically.
        with open(CONFIG_PATH, encoding='utf-8') as fh:
            config = json.load(fh)
        model_name = config['model_name']
        hidden_size = config['hidden_size']
        max_seq_len = config['max_seq_len']
        counts = config['classes']
        dropout = config.get('dropout', 0.15)
        arc_s = config.get('arcface_scale', 30.0)
        arc_m = config.get('arcface_margin', 0.3)

        with open(DICT_2, encoding='utf-8') as fh:
            l2id_2 = json.load(fh)
        with open(DICT_4, encoding='utf-8') as fh:
            l2id_4 = json.load(fh)
        with open(DICT_6, encoding='utf-8') as fh:
            l2id_6 = json.load(fh)

        # Invert label->id into id->label for decoding predictions.
        id2l_2 = {v: k for k, v in l2id_2.items()}
        id2l_4 = {v: k for k, v in l2id_4.items()}
        id2l_6 = {v: k for k, v in l2id_6.items()}

        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

        # Prefer the locally saved backbone; fall back to the hub name.
        if os.path.exists(BASE_MODEL_PATH):
            base_model = AutoModel.from_pretrained(BASE_MODEL_PATH)
        else:
            base_model = AutoModel.from_pretrained(model_name)

        model = CascadedClassifier(
            base_model=base_model, hidden_size=hidden_size,
            n2=counts['n2'], n4=counts['n4'], n6=counts['n6'],
            dropout=dropout, arc_s=arc_s, arc_m=arc_m
        ).to(device)

        if os.path.exists(FULL_MODEL_PATH):
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys — presumably intentional for partial checkpoints; confirm.
            state_dict = torch.load(FULL_MODEL_PATH, map_location=device)
            model.load_state_dict(state_dict, strict=False)

        model.eval()
        print(f"Loaded. Best val acc: {config.get('best_val_acc_6', 'N/A')}%")
        print(f"Mode: {config.get('training_mode', 'N/A')}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return

    # Write a session header to the results log.
    with open(RESULTS_PATH, 'a', encoding='utf-8') as f:
        f.write(f"\n{'#'*80}\n")
        f.write(f"Test session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {config.get('model_name', 'N/A')}\n")
        f.write(f"Architecture: {config.get('architecture', 'N/A')}\n")
        f.write(f"Best val acc (6-digit): {config.get('best_val_acc_6', 'N/A')}%\n")
        f.write(f"{'#'*80}\n")

    print(f"\n📝 Results will be saved to: {RESULTS_PATH}")
    print("\n--- HS Code Classification (3-level, 6-digit) ---")
    print("Type description or 'q' to quit.\n")

    def get_prob(code_str, mapper, probs):
        # Probability that `probs` assigns to the class whose label is
        # `code_str`; 0.0 when the label is unknown. Hoisted out of the
        # candidate loop (the original redefined it every iteration).
        for k, v in mapper.items():
            if v == code_str:
                return probs[0][k].item()
        return 0.0

    while True:
        try:
            text = input("Description: ")
        except (KeyboardInterrupt, EOFError):
            break
        if not text.strip():
            continue
        if text.lower() in ('q', 'quit', 'exit'):
            break

        enc = tokenizer(text, max_length=max_seq_len, padding='max_length',
                        truncation=True, return_tensors='pt')
        ids = enc['input_ids'].to(device)
        mask = enc['attention_mask'].to(device)

        # BUG FIX: the original forced autocast('cuda') even on CPU-only
        # machines; mixed precision is now enabled only when on CUDA.
        with torch.no_grad(), torch.amp.autocast(device.type,
                                                 enabled=device.type == 'cuda'):
            o2, o4, o6 = model(ids, mask)

        p2 = F.softmax(o2, dim=1)
        p4 = F.softmax(o4, dim=1)
        p6 = F.softmax(o6, dim=1)

        _, b2 = torch.max(p2, 1)
        b2c = id2l_2.get(b2.item(), "")
        _, b4 = torch.max(p4, 1)
        b4c = id2l_4.get(b4.item(), "")

        # BUG FIX: clamp k — topk(p6, 10) raises if there are < 10 classes.
        k = min(10, p6.size(1))
        top_p, top_i = torch.topk(p6, k, dim=1)

        candidates = []
        for j in range(k):
            idx = top_i[0][j].item()
            prob6 = top_p[0][j].item()
            code6 = id2l_6.get(idx, "Unk")

            pr2 = get_prob(code6[:2], id2l_2, p2)
            pr4 = get_prob(code6[:4], id2l_4, p4)

            # Score: 6-digit prob weighted quadratically, coarse levels
            # damped by square root; chain-consistent codes are boosted.
            eps = 1e-6
            score = (prob6**2) * ((pr4+eps)**0.5) * ((pr2+eps)**0.5)
            if code6.startswith(b4c):
                score *= 10.0
            elif code6[:2] == b2c:
                score *= 5.0

            candidates.append({"code": code6, "score": score, "p6": prob6,
                               "p4": pr4, "p2": pr2})

        candidates.sort(key=lambda x: x["score"], reverse=True)

        print(f"\n Cascade: {b2c} → {b4c}")
        print("-" * 80)
        print(f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain")
        print("-" * 80)
        for i in range(min(5, len(candidates))):
            c = candidates[i]
            cd = c["code"]
            ch = f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})→{cd[:6]}({c['p6']:.2f})"
            print(f"{i+1:<4} | {cd:<12} | {c['score']:.2e} | {c['p6']:.4f} | {ch}")
        print("-" * 80)

        if candidates[0]['score'] > 1e-3:
            print("✅ Strong match.")
        elif candidates[0]['p6'] < 0.1:
            print("⚠️ Low confidence.")

        save_result(RESULTS_PATH, text, candidates, b2c, b4c)
        print(f" 📁 Saved to {RESULTS_PATH}")
| |
|
| |
|
# Script entry point: start the interactive classification loop.
if __name__ == "__main__":
    main()
| |
|