| | """ |
| | AIFinder Interactive Classifier |
| | Loads trained model and provides an interactive REPL for classifying text. |
| | |
| | Usage: python3 classify.py |
| | """ |
| |
|
| | import os |
| | import sys |
| | import time |
| | import joblib |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from config import MODEL_DIR, DATASET_REGISTRY, DEEPSEEK_AM_DATASETS |
| | from model import AIFinderNet |
| |
|
| |
|
def load_models():
    """Load the feature pipeline, provider encoder, and trained network.

    Returns:
        tuple: ``(pipeline, net, provider_enc, checkpoint, device)`` where
        ``net`` is an ``AIFinderNet`` placed on ``device`` in eval mode and
        ``checkpoint`` is the raw checkpoint dict (kept so corrections can
        be saved back later).

    Exits the process with status 1 if any model file is missing.
    """
    try:
        pipeline = joblib.load(os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
        provider_enc = joblib.load(os.path.join(MODEL_DIR, "provider_enc.joblib"))

        # weights_only=True restricts unpickling to tensors and primitive
        # containers — the safe way to load a torch checkpoint.
        checkpoint = torch.load(
            os.path.join(MODEL_DIR, "classifier.pt"),
            map_location="cpu",
            weights_only=True,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = AIFinderNet(
            input_dim=checkpoint["input_dim"],
            num_providers=checkpoint["num_providers"],
            hidden_dim=checkpoint["hidden_dim"],
            embed_dim=checkpoint["embed_dim"],
            dropout=checkpoint["dropout"],
        ).to(device)
        # NOTE(review): strict=False tolerates key mismatches between the
        # checkpoint and the current architecture, but it also silences
        # genuinely missing weights — confirm this laxness is intended.
        net.load_state_dict(checkpoint["state_dict"], strict=False)
        net.eval()

        return pipeline, net, provider_enc, checkpoint, device
    except FileNotFoundError:
        print(f"Error: Models not found in {MODEL_DIR}")
        # Fixed: was an f-string with no placeholders.
        print("Run 'python3 train.py' first to train the models.")
        sys.exit(1)
| |
|
| |
|
def classify_text(text, pipeline, net, provider_enc, device):
    """Run one text through the feature pipeline and the network.

    Returns a dict with the best provider name, its confidence in percent,
    and up to five ``(name, percent)`` pairs sorted by descending confidence.
    """
    start = time.time()
    features = pipeline.transform([text])
    features_t = torch.tensor(features.toarray(), dtype=torch.float32).to(device)
    print(f" (featurize: {time.time() - start:.2f}s)", end="")

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = net(features_t)

    probs = torch.softmax(logits.float(), dim=1)[0].cpu().numpy()

    # Take up to five providers, highest probability first.
    ranked = np.argsort(probs)[::-1][:5]
    top_providers = []
    for idx in ranked:
        label = provider_enc.inverse_transform([idx])[0]
        top_providers.append((label, probs[idx] * 100))

    print(f" (total classify: {time.time() - start:.2f}s)")

    best_label, best_conf = top_providers[0]
    return {
        "provider": best_label,
        "provider_confidence": best_conf,
        "top_providers": top_providers,
    }
| |
|
| |
|
def print_results(results):
    """Render classification results as a boxed 20-slot bar chart."""
    print()
    print(" βββββββββββββββββββββββββββββββββββββββββββββββββ")
    print(
        f" β Provider: {results['provider']} ({results['provider_confidence']:.1f}%)"
    )
    for label, conf in results["top_providers"]:
        # Treat NaN confidence as 0 so formatting never blows up.
        pct = conf if not np.isnan(conf) else 0.0
        filled = int(pct / 5)  # 20 slots at 5% per slot
        bar = "β" * filled + "β" * (20 - filled)
        print(f" β {label:.<25s} {pct:5.1f}% {bar}")

    print(" βββββββββββββββββββββββββββββββββββββββββββββββββ")
    print()
| |
|
| |
|
| | def correct_provider( |
| | net, |
| | X_t, |
| | correct_provider_name, |
| | provider_enc, |
| | optimizer, |
| | device, |
| | ): |
| | """Do a backward pass to correct the provider on a single example.""" |
| | try: |
| | prov_idx = provider_enc.transform([correct_provider_name])[0] |
| | except ValueError as e: |
| | print(f" (label not in encoder: {e})") |
| | return False |
| |
|
| | y_prov = torch.tensor([prov_idx], dtype=torch.long).to(device) |
| |
|
| | was_training = net.training |
| | net.train() |
| |
|
| | |
| | if X_t.shape[0] <= 1: |
| | for module in net.modules(): |
| | if isinstance(module, nn.modules.batchnorm._BatchNorm): |
| | module.eval() |
| |
|
| | optimizer.zero_grad(set_to_none=True) |
| | prov_criterion = nn.CrossEntropyLoss() |
| |
|
| | prov_logits = net(X_t) |
| | loss = prov_criterion(prov_logits, y_prov) |
| | loss.backward() |
| | torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0) |
| | optimizer.step() |
| |
|
| | if was_training: |
| | net.train() |
| | else: |
| | net.eval() |
| |
|
| | print(f" β Corrected β {correct_provider_name} (loss={loss.item():.4f})") |
| | return True |
| |
|
| |
|
def prompt_correction(known_providers):
    """Ask the user which provider was actually correct.

    Accepts a 1-based menu number or a case-insensitive substring that
    matches exactly one provider name. Returns the chosen name, or None
    when the user skips, hits EOF, or gives an unresolvable answer.
    """
    print(" Wrong? Enter correct provider number (or Enter to skip):")
    for number, name in enumerate(known_providers, 1):
        print(f" {number:>2d}. {name}")
    try:
        answer = input(" Provider > ").strip()
    except EOFError:
        return None
    if not answer:
        return None

    chosen = None
    try:
        pos = int(answer) - 1
    except ValueError:
        # Not a number — fall back to substring matching, unique hits only.
        hits = [p for p in known_providers if answer.lower() in p.lower()]
        if len(hits) == 1:
            chosen = hits[0]
    else:
        if 0 <= pos < len(known_providers):
            chosen = known_providers[pos]

    if not chosen:
        print(" (invalid choice, skipping)")
        return None

    return chosen
| |
|
| |
|
def _save_corrections(checkpoint, net, corrections_made):
    """Persist corrected network weights back into the checkpoint file.

    No-op when no corrections were made this session.
    """
    if corrections_made > 0:
        print(f" Saving {corrections_made} correction(s) to checkpoint...")
        checkpoint["state_dict"] = net.state_dict()
        torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
        print(" β Saved.")


def main():
    """Interactive REPL: classify pasted text and optionally apply corrections."""
    print()
    print(" βββββββββββββββββββββββββββββββββββββββββ")
    print(" β AIFinder - AI Response Classifier β")
    print(" βββββββββββββββββββββββββββββββββββββββββ")
    print()

    print(" Loading models...")
    t0 = time.time()
    pipeline, net, provider_enc, checkpoint, device = load_models()
    print(f" Models loaded in {time.time() - t0:.1f}s.")

    # Small LR so a single correction nudges the weights rather than
    # overwriting what training learned.
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4)
    known_providers = sorted(provider_enc.classes_.tolist())
    corrections_made = 0

    print()
    print(" Paste text to classify (submit with TWO empty lines).")
    print(" Type 'quit' to exit.\n")

    while True:
        print(" βββ Paste text below βββ")
        lines = []
        empty_count = 0
        eof = False
        while True:
            try:
                line = input()
            except EOFError:
                eof = True
                break
            if line.strip() == "":
                empty_count += 1
                if empty_count >= 2:
                    break
                lines.append(line)
            else:
                empty_count = 0
                if line.strip().lower() == "quit":
                    _save_corrections(checkpoint, net, corrections_made)
                    print(" Goodbye!")
                    return
                lines.append(line)

        text = "\n".join(lines).strip()
        if not text:
            if eof:
                # BUG FIX: when stdin is exhausted (piped input, Ctrl-D) the
                # old code looped forever printing "(empty input, try again)"
                # and never saved pending corrections. Exit cleanly instead.
                _save_corrections(checkpoint, net, corrections_made)
                print(" Goodbye!")
                return
            print(" (empty input, try again)")
            continue

        if len(text) < 20:
            print(" (text too short, need at least 20 chars)")
            continue

        results = classify_text(text, pipeline, net, provider_enc, device)
        print_results(results)

        # Re-featurize once so the correction step has the feature tensor.
        X = pipeline.transform([text])
        X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

        chosen = prompt_correction(known_providers)
        if chosen:
            ok = correct_provider(
                net,
                X_t,
                chosen,
                provider_enc,
                optimizer,
                device,
            )
            if ok:
                corrections_made += 1


if __name__ == "__main__":
    main()
| |
|