""" AIFinder Interactive Classifier Loads trained model and provides an interactive REPL for classifying text. Usage: python3 classify.py """ import os import sys import time import joblib import numpy as np import torch import torch.nn as nn from config import MODEL_DIR, DATASET_REGISTRY, DEEPSEEK_AM_DATASETS from model import AIFinderNet def load_models(): """Load all model components from the model directory.""" try: pipeline = joblib.load(os.path.join(MODEL_DIR, "feature_pipeline.joblib")) provider_enc = joblib.load(os.path.join(MODEL_DIR, "provider_enc.joblib")) checkpoint = torch.load( os.path.join(MODEL_DIR, "classifier.pt"), map_location="cpu", weights_only=True, ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = AIFinderNet( input_dim=checkpoint["input_dim"], num_providers=checkpoint["num_providers"], hidden_dim=checkpoint["hidden_dim"], embed_dim=checkpoint["embed_dim"], dropout=checkpoint["dropout"], ).to(device) net.load_state_dict(checkpoint["state_dict"], strict=False) net.eval() return pipeline, net, provider_enc, checkpoint, device except FileNotFoundError: print(f"Error: Models not found in {MODEL_DIR}") print(f"Run 'python3 train.py' first to train the models.") sys.exit(1) def classify_text(text, pipeline, net, provider_enc, device): """Classify a single text and return provider results.""" t0 = time.time() X = pipeline.transform([text]) X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device) print(f" (featurize: {time.time() - t0:.2f}s)", end="") with torch.no_grad(): prov_logits = net(X_t) prov_proba = torch.softmax(prov_logits.float(), dim=1)[0].cpu().numpy() # Provider top-5 top_prov_idxs = np.argsort(prov_proba)[::-1][:5] top_providers = [ (provider_enc.inverse_transform([i])[0], prov_proba[i] * 100) for i in top_prov_idxs ] elapsed = time.time() - t0 print(f" (total classify: {elapsed:.2f}s)") return { "provider": top_providers[0][0], "provider_confidence": top_providers[0][1], "top_providers": top_providers, } def print_results(results): """Pretty-print classification results.""" print() print(" ┌───────────────────────────────────────────────┐") print( f" │ Provider: {results['provider']} ({results['provider_confidence']:.1f}%)" ) for name, conf in results["top_providers"]: c = 0.0 if np.isnan(conf) else conf bar = "█" * int(c / 5) + "░" * (20 - int(c / 5)) print(f" │ {name:.<25s} {c:5.1f}% {bar}") print(" └───────────────────────────────────────────────┘") print() def correct_provider( net, X_t, correct_provider_name, provider_enc, optimizer, device, ): """Do a backward pass to correct the provider on a single example.""" try: prov_idx = provider_enc.transform([correct_provider_name])[0] except ValueError as e: print(f" (label not in encoder: {e})") return False y_prov = torch.tensor([prov_idx], dtype=torch.long).to(device) was_training = net.training net.train() # Disable batchnorm for single-sample training if X_t.shape[0] <= 1: for module in net.modules(): if isinstance(module, nn.modules.batchnorm._BatchNorm): module.eval() optimizer.zero_grad(set_to_none=True) prov_criterion = nn.CrossEntropyLoss() prov_logits = net(X_t) loss = prov_criterion(prov_logits, y_prov) loss.backward() torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0) optimizer.step() if was_training: net.train() else: net.eval() print(f" ✓ Corrected → {correct_provider_name} (loss={loss.item():.4f})") return True def prompt_correction(known_providers): """Ask user for the correct provider.""" print(" Wrong? Enter correct provider number (or Enter to skip):") for i, name in enumerate(known_providers, 1): print(f" {i:>2d}. {name}") try: prov_choice = input(" Provider > ").strip() except EOFError: return None if not prov_choice: return None correct_provider = None try: idx = int(prov_choice) - 1 if 0 <= idx < len(known_providers): correct_provider = known_providers[idx] except ValueError: matches = [m for m in known_providers if prov_choice.lower() in m.lower()] if len(matches) == 1: correct_provider = matches[0] if not correct_provider: print(" (invalid choice, skipping)") return None return correct_provider def main(): print() print(" ╔═══════════════════════════════════════╗") print(" ║ AIFinder - AI Response Classifier ║") print(" ╚═══════════════════════════════════════╝") print() print(" Loading models...") t0 = time.time() pipeline, net, provider_enc, checkpoint, device = load_models() print(f" Models loaded in {time.time() - t0:.1f}s.") # Prepare online learning components optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4) known_providers = sorted(provider_enc.classes_.tolist()) corrections_made = 0 print() print(" Paste text to classify (submit with TWO empty lines).") print(" Type 'quit' to exit.\n") last_X_t = None while True: print(" ─── Paste text below ───") lines = [] empty_count = 0 while True: try: line = input() except EOFError: break if line.strip() == "": empty_count += 1 if empty_count >= 2: break lines.append(line) else: empty_count = 0 if line.strip().lower() == "quit": if corrections_made > 0: print( f" Saving {corrections_made} correction(s) to checkpoint..." ) checkpoint["state_dict"] = net.state_dict() torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt")) print(" ✓ Saved.") print(" Goodbye!") return lines.append(line) text = "\n".join(lines).strip() if not text: print(" (empty input, try again)") continue if len(text) < 20: print(" (text too short, need at least 20 chars)") continue results = classify_text(text, pipeline, net, provider_enc, device) print_results(results) X = pipeline.transform([text]) last_X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device) correct_prov = prompt_correction(known_providers) if correct_prov: ok = correct_provider( net, last_X_t, correct_prov, provider_enc, optimizer, device, ) if ok: corrections_made += 1 if __name__ == "__main__": main()