# hscode-classifier-en / inference.py
# HS Code Classifier (English, 6-digit): cascaded 2 -> 4 -> 6 digit inference.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from datetime import datetime
import json
import os
import math
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# All model artifacts are expected under the local 'model' directory.
MODEL_DIR = 'model'
# Full fine-tuned checkpoint (backbone + cascaded 2/4/6-digit heads).
FULL_MODEL_PATH = os.path.join(MODEL_DIR, 'cascaded_best.pt')
# Training-time hyperparameters and per-level class counts.
CONFIG_PATH = os.path.join(MODEL_DIR, 'model_config.json')
# Locally exported tokenizer and (optionally) the base encoder.
TOKENIZER_PATH = os.path.join(MODEL_DIR, 'tokenizer')
BASE_MODEL_PATH = os.path.join(MODEL_DIR, 'base_model')
# label -> id dictionaries, one per cascade level (2-, 4-, 6-digit codes).
DICT_2 = os.path.join(MODEL_DIR, 'label2id_2.json')
DICT_4 = os.path.join(MODEL_DIR, 'label2id_4.json')
DICT_6 = os.path.join(MODEL_DIR, 'label2id_6.json')
# Human-readable log of interactive test sessions.
RESULTS_PATH = os.path.join(MODEL_DIR, 'test_results.txt')
class ArcMarginProduct(nn.Module):
    """Additive-angular-margin (ArcFace) classification head.

    In eval mode, or when no labels are supplied, this reduces to the
    scaled cosine similarity between L2-normalized features and
    L2-normalized class weights; the angular margin is applied only on
    the target class during training.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.30):
        super().__init__()
        self.s = s  # logit scale
        self.m = m  # angular margin (radians)
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        # Precomputed constants for the margin rotation cos(theta + m)
        # and for the numerical-stability fallback past pi - m.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label=None):
        """Return scaled (and, in training, margin-adjusted) cosine logits."""
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # Inference path: plain scaled cosine, no margin.
        if label is None or not self.training:
            return cosine * self.s
        # Training path: rotate the target-class angle by m.
        sine = torch.sqrt(1.0 - cosine.pow(2).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        # When theta + m would exceed pi, fall back to a linear penalty.
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        target_mask = torch.zeros_like(cosine)
        target_mask.scatter_(1, label.view(-1, 1).long(), 1)
        logits = target_mask * phi + (1.0 - target_mask) * cosine
        return logits * self.s
class CascadedClassifier(nn.Module):
    """Hierarchical HS-code classifier: 2-digit -> 4-digit -> 6-digit.

    Each level's softmax distribution is concatenated with the [CLS]
    embedding and fused before the next level, so coarser predictions
    condition finer ones. The final 6-digit level uses an ArcFace head.

    NOTE: submodule attribute names and creation order are part of the
    checkpoint contract (state_dict keys) and must not change.
    """

    def __init__(self, base_model, hidden_size, n2, n4, n6,
                 dropout=0.15, arc_s=30.0, arc_m=0.3):
        super().__init__()
        self.base_model = base_model
        self.drop = nn.Dropout(dropout)
        # Level 1: 2-digit chapter head, straight off the [CLS] vector.
        self.head_2 = nn.Sequential(
            nn.Linear(hidden_size, 256), nn.LayerNorm(256), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(256, n2))
        # Level 2: fuse [CLS] with the 2-digit distribution, then classify.
        self.head_4_fusion = nn.Linear(hidden_size + n2, hidden_size)
        self.head_4 = nn.Sequential(
            nn.LayerNorm(hidden_size), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_size, 256), nn.GELU(), nn.Linear(256, n4))
        # Level 3: fuse with the 4-digit distribution, ArcFace on features.
        self.head_6_fusion = nn.Linear(hidden_size + n4, hidden_size)
        self.head_6_feat = nn.Sequential(
            nn.LayerNorm(hidden_size), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_size, 512), nn.GELU())
        self.head_6_arc = ArcMarginProduct(512, n6, s=arc_s, m=arc_m)

    def forward(self, input_ids, attention_mask, label_6=None):
        """Return raw logits as a tuple (logits_2, logits_4, logits_6)."""
        encoded = self.base_model(input_ids=input_ids,
                                  attention_mask=attention_mask)
        # First token ([CLS]) of the last hidden state, with dropout.
        pooled = self.drop(encoded.last_hidden_state[:, 0, :])
        logits_2 = self.head_2(pooled)
        fused_4 = self.head_4_fusion(
            torch.cat([pooled, torch.softmax(logits_2, dim=1)], dim=1))
        logits_4 = self.head_4(fused_4)
        fused_6 = self.head_6_fusion(
            torch.cat([pooled, torch.softmax(logits_4, dim=1)], dim=1))
        logits_6 = self.head_6_arc(self.head_6_feat(fused_6), label_6)
        return logits_2, logits_4, logits_6
def save_result(filepath, text, candidates, cascade_2, cascade_4):
    """Append a single classification result to the results text file.

    Args:
        filepath: path of the results log (opened in append mode, UTF-8).
        text: the raw input description that was classified.
        candidates: ranked list (best first) of dicts with keys
            'code', 'score', 'p2', 'p4', 'p6'; only the top 5 are written.
        cascade_2: argmax 2-digit code for the input.
        cascade_4: argmax 4-digit code for the input.
    """
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write(f"\n{'='*80}\n")
        f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Input: {text}\n")
        f.write(f"Cascade: {cascade_2} → {cascade_4}\n")
        f.write(f"{'-'*80}\n")
        if not candidates:
            # Guard: the original indexed candidates[0] unconditionally,
            # raising IndexError on an empty list.
            f.write("(no candidates)\n")
            return
        f.write(f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain\n")
        f.write(f"{'-'*80}\n")
        for i, c in enumerate(candidates[:5]):
            cd = c['code']
            ch = (f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})"
                  f"→{cd[:6]}({c['p6']:.2f})")
            f.write(f"{i+1:<4} | {cd:<12} | {c['score']:.2e} | {c['p6']:.4f} | {ch}\n")
        f.write(f"{'-'*80}\n")
        # Same confidence verdicts as the interactive printout.
        if candidates[0]['score'] > 1e-3:
            f.write("✅ Strong match.\n")
        elif candidates[0]['p6'] < 0.1:
            f.write("⚠️ Low confidence.\n")
def main():
    """Interactive HS-code classification CLI.

    Loads the cascaded model, tokenizer and label dictionaries from
    MODEL_DIR, then reads product descriptions from stdin in a loop,
    printing (and appending to RESULTS_PATH) the top-5 candidate
    6-digit codes with a cascade-consistency score.

    Fixes over the original:
      * autocast uses the actual device type instead of hard-coded 'cuda';
      * a missing checkpoint now prints a warning instead of silently
        running with random weights;
      * JSON files are opened via context managers (no handle leaks);
      * parent-level probabilities use O(1) dict lookups instead of a
        per-candidate linear scan;
      * topk is capped at the number of 6-digit classes.
    """
    print("Loading bert-base-uncased FULL FT + ArcFace model (3-level, 6-digit)...")
    if not os.path.exists(CONFIG_PATH):
        print(f"Config not found: {CONFIG_PATH}. Train first.")
        return
    try:
        with open(CONFIG_PATH, encoding='utf-8') as fh:
            config = json.load(fh)
        model_name = config['model_name']
        hidden_size = config['hidden_size']
        max_seq_len = config['max_seq_len']
        counts = config['classes']
        dropout = config.get('dropout', 0.15)
        arc_s = config.get('arcface_scale', 30.0)
        arc_m = config.get('arcface_margin', 0.3)
        with open(DICT_2, encoding='utf-8') as fh:
            l2id_2 = json.load(fh)
        with open(DICT_4, encoding='utf-8') as fh:
            l2id_4 = json.load(fh)
        with open(DICT_6, encoding='utf-8') as fh:
            l2id_6 = json.load(fh)
        # Inverse maps (class id -> code string) for readable output.
        id2l_2 = {v: k for k, v in l2id_2.items()}
        id2l_4 = {v: k for k, v in l2id_4.items()}
        id2l_6 = {v: k for k, v in l2id_6.items()}
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
        # Prefer the locally exported backbone; fall back to the hub name.
        if os.path.exists(BASE_MODEL_PATH):
            base_model = AutoModel.from_pretrained(BASE_MODEL_PATH)
        else:
            base_model = AutoModel.from_pretrained(model_name)
        model = CascadedClassifier(
            base_model=base_model, hidden_size=hidden_size,
            n2=counts['n2'], n4=counts['n4'], n6=counts['n6'],
            dropout=dropout, arc_s=arc_s, arc_m=arc_m
        ).to(device)
        if os.path.exists(FULL_MODEL_PATH):
            state_dict = torch.load(FULL_MODEL_PATH, map_location=device)
            # strict=False tolerates missing/renamed keys from older
            # checkpoints; mismatched keys are skipped silently.
            model.load_state_dict(state_dict, strict=False)
        else:
            # The original fell through silently and ran untrained weights.
            print(f"⚠️ Checkpoint not found: {FULL_MODEL_PATH} — predictions will be random.")
        model.eval()
        print(f"Loaded. Best val acc: {config.get('best_val_acc_6', 'N/A')}%")
        print(f"Mode: {config.get('training_mode', 'N/A')}")
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return
    # Start a new session block in the results log.
    with open(RESULTS_PATH, 'a', encoding='utf-8') as f:
        f.write(f"\n{'#'*80}\n")
        f.write(f"Test session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {config.get('model_name', 'N/A')}\n")
        f.write(f"Architecture: {config.get('architecture', 'N/A')}\n")
        f.write(f"Best val acc (6-digit): {config.get('best_val_acc_6', 'N/A')}%\n")
        f.write(f"{'#'*80}\n")
    print(f"\n📝 Results will be saved to: {RESULTS_PATH}")
    print("\n--- HS Code Classification (3-level, 6-digit) ---")
    print("Type description or 'q' to quit.\n")
    while True:
        try:
            text = input("Description: ")
        except (KeyboardInterrupt, EOFError):
            break
        stripped = text.strip()
        if not stripped:
            continue
        # Strip before comparing so ' q ' also quits (the original
        # required an exact, unpadded match).
        if stripped.lower() in ('q', 'quit', 'exit'):
            break
        enc = tokenizer(text, max_length=max_seq_len, padding='max_length',
                        truncation=True, return_tensors='pt')
        ids = enc['input_ids'].to(device)
        mask = enc['attention_mask'].to(device)
        with torch.no_grad():
            # Enable autocast only on CUDA; the original hard-coded
            # 'cuda', which warns and disables itself on CPU-only hosts.
            with torch.amp.autocast(device_type=device.type,
                                    enabled=device.type == 'cuda'):
                o2, o4, o6 = model(ids, mask)
            p2 = F.softmax(o2, dim=1)
            p4 = F.softmax(o4, dim=1)
            p6 = F.softmax(o6, dim=1)
        b2c = id2l_2.get(p2.argmax(dim=1).item(), "")
        b4c = id2l_4.get(p4.argmax(dim=1).item(), "")
        # Guard: torch.topk raises if k exceeds the number of classes.
        k = min(10, p6.size(1))
        top_p, top_i = torch.topk(p6, k, dim=1)
        candidates = []
        eps = 1e-6
        for j in range(k):
            idx = top_i[0][j].item()
            prob6 = top_p[0][j].item()
            code6 = id2l_6.get(idx, "Unk")
            # O(1) parent-probability lookups via the label -> id dicts
            # (the original linearly scanned id2l maps per candidate).
            i2 = l2id_2.get(code6[:2])
            pr2 = p2[0][i2].item() if i2 is not None else 0.0
            i4 = l2id_4.get(code6[:4])
            pr4 = p4[0][i4].item() if i4 is not None else 0.0
            # Cascade score: emphasize the 6-digit probability, softly
            # weight agreement of the 2-/4-digit parents.
            score = (prob6 ** 2) * ((pr4 + eps) ** 0.5) * ((pr2 + eps) ** 0.5)
            if code6.startswith(b4c):
                score *= 10.0  # consistent with the 4-digit argmax
            elif code6[:2] == b2c:
                score *= 5.0   # consistent with the 2-digit argmax only
            candidates.append({"code": code6, "score": score, "p6": prob6,
                               "p4": pr4, "p2": pr2})
        candidates.sort(key=lambda c: c["score"], reverse=True)
        print(f"\n Cascade: {b2c} → {b4c}")
        print("-" * 80)
        print(f"{'#':<4} | {'Code':<12} | {'Score':<10} | {'P(6)':<8} | Chain")
        print("-" * 80)
        for i in range(min(5, len(candidates))):
            c = candidates[i]
            cd = c["code"]
            ch = (f"{cd[:2]}({c['p2']:.2f})→{cd[:4]}({c['p4']:.2f})"
                  f"→{cd[:6]}({c['p6']:.2f})")
            print(f"{i+1:<4} | {cd:<12} | {c['score']:.2e} | {c['p6']:.4f} | {ch}")
        print("-" * 80)
        if candidates[0]['score'] > 1e-3:
            print("✅ Strong match.")
        elif candidates[0]['p6'] < 0.1:
            print("⚠️ Low confidence.")
        # Persist the same table to the results log.
        save_result(RESULTS_PATH, text, candidates, b2c, b4c)
        print(f" 📝 Saved to {RESULTS_PATH}")
# Launch the interactive CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()