# AIFinder / classify.py
# (HuggingFace upload-page residue removed: "CompactAI's picture /
#  Upload 13 files / f52234e verified" — not part of the source.)
"""
AIFinder Interactive Classifier
Loads trained model and provides an interactive REPL for classifying text.
Usage: python3 classify.py
"""
import os
import sys
import time
import joblib
import numpy as np
import torch
import torch.nn as nn
from config import MODEL_DIR, DATASET_REGISTRY, DEEPSEEK_AM_DATASETS
from model import AIFinderNet
def load_models():
    """Load the trained feature pipeline, label encoder, and network.

    Returns:
        tuple: (pipeline, net, provider_enc, checkpoint, device) where
            `pipeline` is the fitted feature extractor, `net` is the
            trained AIFinderNet in eval mode on `device`, `provider_enc`
            maps provider indices to names, and `checkpoint` is the raw
            dict loaded from disk (kept so callers can re-save it).

    Exits the process with status 1 if any model file is missing.
    """
    try:
        pipeline = joblib.load(os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
        provider_enc = joblib.load(os.path.join(MODEL_DIR, "provider_enc.joblib"))
        # weights_only=True restricts deserialization to tensors/containers,
        # avoiding arbitrary code execution from a tampered checkpoint.
        checkpoint = torch.load(
            os.path.join(MODEL_DIR, "classifier.pt"),
            map_location="cpu",
            weights_only=True,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = AIFinderNet(
            input_dim=checkpoint["input_dim"],
            num_providers=checkpoint["num_providers"],
            hidden_dim=checkpoint["hidden_dim"],
            embed_dim=checkpoint["embed_dim"],
            dropout=checkpoint["dropout"],
        ).to(device)
        # NOTE(review): strict=False silently ignores missing/unexpected
        # state-dict keys — confirm this is intentional (e.g. to tolerate
        # older checkpoints) rather than masking a shape mismatch.
        net.load_state_dict(checkpoint["state_dict"], strict=False)
        net.eval()
        return pipeline, net, provider_enc, checkpoint, device
    except FileNotFoundError:
        print(f"Error: Models not found in {MODEL_DIR}")
        # Fixed: original used an f-string with no placeholders (F541).
        print("Run 'python3 train.py' first to train the models.")
        sys.exit(1)
def classify_text(text, pipeline, net, provider_enc, device):
    """Classify one text and return the top predicted providers.

    Featurizes `text` with the fitted pipeline, scores it with `net`,
    and prints timing diagnostics along the way.

    Returns:
        dict with keys "provider" (best provider name),
        "provider_confidence" (percentage, 0-100), and "top_providers"
        (list of up to 5 (name, percentage) pairs, best first).
    """
    start = time.time()
    features = pipeline.transform([text])
    batch = torch.tensor(features.toarray(), dtype=torch.float32).to(device)
    print(f" (featurize: {time.time() - start:.2f}s)", end="")

    with torch.no_grad():
        logits = net(batch)
    probabilities = torch.softmax(logits.float(), dim=1)[0].cpu().numpy()

    # Rank providers by probability, keeping at most the five most likely.
    ranking = np.argsort(probabilities)[::-1][:5]
    top_providers = []
    for idx in ranking:
        name = provider_enc.inverse_transform([idx])[0]
        top_providers.append((name, probabilities[idx] * 100))

    elapsed = time.time() - start
    print(f" (total classify: {elapsed:.2f}s)")

    best_name, best_confidence = top_providers[0]
    return {
        "provider": best_name,
        "provider_confidence": best_confidence,
        "top_providers": top_providers,
    }
def print_results(results):
    """Render a classification result as a boxed table on stdout.

    Expects the dict produced by classify_text: "provider",
    "provider_confidence", and "top_providers" (name, percent pairs).
    """
    print()
    print(" β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
    header = (
        f" β”‚ Provider: {results['provider']} ({results['provider_confidence']:.1f}%)"
    )
    print(header)
    for provider_name, confidence in results["top_providers"]:
        # Guard against NaN so the bar arithmetic never misbehaves.
        pct = 0.0 if np.isnan(confidence) else confidence
        filled = int(pct / 5)
        gauge = "β–ˆ" * filled + "β–‘" * (20 - filled)
        print(f" β”‚ {provider_name:.<25s} {pct:5.1f}% {gauge}")
    print(" β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")
    print()
def correct_provider(
    net,
    X_t,
    correct_provider_name,
    provider_enc,
    optimizer,
    device,
):
    """Apply one online-learning gradient step toward the given provider.

    Encodes `correct_provider_name`, runs a forward/backward pass on the
    single example `X_t`, and steps `optimizer`. The net's original
    train/eval mode is restored afterwards.

    Returns:
        bool: True on success, False when the name is unknown to the encoder.
    """
    try:
        target_idx = provider_enc.transform([correct_provider_name])[0]
    except ValueError as e:
        print(f" (label not in encoder: {e})")
        return False

    targets = torch.tensor([target_idx], dtype=torch.long).to(device)
    previously_training = net.training
    net.train()
    # A single-sample batch would crash or skew batchnorm statistics,
    # so keep every batchnorm layer in eval mode for this step.
    if X_t.shape[0] <= 1:
        for layer in net.modules():
            if isinstance(layer, nn.modules.batchnorm._BatchNorm):
                layer.eval()

    optimizer.zero_grad(set_to_none=True)
    loss = nn.CrossEntropyLoss()(net(X_t), targets)
    loss.backward()
    # Clip to keep a single noisy correction from blowing up the weights.
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    optimizer.step()

    net.train(previously_training)
    print(f" βœ“ Corrected β†’ {correct_provider_name} (loss={loss.item():.4f})")
    return True
def prompt_correction(known_providers):
    """Show the provider menu and read the user's correction choice.

    Accepts either a 1-based number from the printed list or a
    case-insensitive substring that uniquely matches one provider name.
    (Local renamed from `correct_provider` to avoid shadowing the
    sibling function of that name.)

    Returns:
        The chosen provider name, or None if the user skips (empty
        input / EOF) or the choice is invalid or ambiguous.
    """
    print(" Wrong? Enter correct provider number (or Enter to skip):")
    for number, provider in enumerate(known_providers, 1):
        print(f" {number:>2d}. {provider}")

    try:
        raw = input(" Provider > ").strip()
    except EOFError:
        return None
    if not raw:
        return None

    selection = None
    try:
        position = int(raw) - 1
        if 0 <= position < len(known_providers):
            selection = known_providers[position]
    except ValueError:
        # Not a number: fall back to substring matching, accepted only
        # when exactly one provider matches.
        hits = [p for p in known_providers if raw.lower() in p.lower()]
        if len(hits) == 1:
            selection = hits[0]

    if not selection:
        print(" (invalid choice, skipping)")
        return None
    return selection
def main():
    """Interactive REPL: classify pasted text and learn from corrections.

    Loads the trained models, then loops forever: read a multi-line
    paste (terminated by two consecutive empty lines), classify it,
    display the results, and optionally apply a one-example online
    correction. Typing 'quit' exits, persisting updated weights to the
    checkpoint if any corrections were made.
    """
    print()
    print(" ╔═══════════════════════════════════════╗")
    print(" β•‘ AIFinder - AI Response Classifier β•‘")
    print(" β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•")
    print()
    print(" Loading models...")
    t0 = time.time()
    pipeline, net, provider_enc, checkpoint, device = load_models()
    print(f" Models loaded in {time.time() - t0:.1f}s.")
    # Online-learning optimizer: low learning rate so a single correction
    # nudges rather than overwrites the trained weights.
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4)
    known_providers = sorted(provider_enc.classes_.tolist())
    corrections_made = 0
    print()
    print(" Paste text to classify (submit with TWO empty lines).")
    print(" Type 'quit' to exit.\n")
    last_X_t = None
    while True:
        print(" ─── Paste text below ───")
        lines = []
        empty_count = 0
        # Collect input lines until two consecutive blank lines (or EOF).
        while True:
            try:
                line = input()
            except EOFError:
                break
            if line.strip() == "":
                empty_count += 1
                if empty_count >= 2:
                    break
                lines.append(line)
            else:
                empty_count = 0
                if line.strip().lower() == "quit":
                    # Persist online corrections back into the checkpoint
                    # before exiting so they survive restarts.
                    if corrections_made > 0:
                        print(
                            f" Saving {corrections_made} correction(s) to checkpoint..."
                        )
                        checkpoint["state_dict"] = net.state_dict()
                        torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
                        print(" βœ“ Saved.")
                    print(" Goodbye!")
                    return
                lines.append(line)
        text = "\n".join(lines).strip()
        if not text:
            print(" (empty input, try again)")
            continue
        if len(text) < 20:
            # Very short snippets carry too little signal to classify.
            print(" (text too short, need at least 20 chars)")
            continue
        results = classify_text(text, pipeline, net, provider_enc, device)
        print_results(results)
        # Re-featurize so a possible correction trains on the same input
        # representation the classifier just saw. NOTE(review): this
        # duplicates the transform already done inside classify_text.
        X = pipeline.transform([text])
        last_X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)
        correct_prov = prompt_correction(known_providers)
        if correct_prov:
            ok = correct_provider(
                net,
                last_X_t,
                correct_prov,
                provider_enc,
                optimizer,
                device,
            )
            if ok:
                corrections_made += 1
# Run the REPL only when executed as a script, not when imported.
if __name__ == "__main__":
    main()