CompactAI commited on
Commit
f52234e
·
verified ·
1 Parent(s): 2fc7a6d

Upload 13 files

Browse files
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder API Server
3
+ Serves classification and training endpoints for the frontend.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import joblib
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from flask import Flask, request, jsonify, send_from_directory
14
+ from flask_cors import CORS
15
+
16
+ from config import MODEL_DIR
17
+ from model import AIFinderNet
18
+ from features import FeaturePipeline
19
+
20
+ app = Flask(__name__, static_folder="static", static_url_path="")
21
+ CORS(app)
22
+
23
+ pipeline = None
24
+ provider_enc = None
25
+ net = None
26
+ device = None
27
+ checkpoint = None
28
+
29
+
30
def load_models():
    """Load feature pipeline, label encoder, and classifier into module globals.

    Populates: pipeline, provider_enc, net, device, checkpoint.
    Must be called once before any request handler runs.
    """
    global pipeline, provider_enc, net, device, checkpoint

    # sklearn feature pipeline and provider LabelEncoder persisted by training
    pipeline = joblib.load(os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
    provider_enc = joblib.load(os.path.join(MODEL_DIR, "provider_enc.joblib"))

    # weights_only=True restricts unpickling to tensors/primitives (safe load)
    checkpoint = torch.load(
        os.path.join(MODEL_DIR, "classifier.pt"),
        map_location="cpu",
        weights_only=True,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Architecture hyperparameters are stored alongside the weights
    net = AIFinderNet(
        input_dim=checkpoint["input_dim"],
        num_providers=checkpoint["num_providers"],
        hidden_dim=checkpoint["hidden_dim"],
        embed_dim=checkpoint["embed_dim"],
        dropout=checkpoint["dropout"],
    ).to(device)
    # strict=False tolerates key mismatches — NOTE(review): may silently skip weights
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net.eval()
51
+
52
+
53
@app.route("/")
def index():
    """Serve the frontend entry page from the static folder."""
    directory, page = "static", "index.html"
    return send_from_directory(directory, page)
56
+
57
+
58
@app.route("/api/providers", methods=["GET"])
def get_providers():
    """Return the alphabetically sorted list of known provider labels."""
    names = provider_enc.classes_.tolist()
    names.sort()
    return jsonify({"providers": names})
62
+
63
+
64
@app.route("/api/classify", methods=["POST"])
def classify():
    """Classify text and return provider predictions.

    Expects a JSON body {"text": str}; returns the top provider plus the
    five highest-confidence providers with confidence percentages.
    Returns 400 for a missing body or text shorter than 20 characters.
    """
    # get_json(silent=True) returns None instead of raising on a missing or
    # malformed JSON body; `request.json` would 500 in that case.
    data = request.get_json(silent=True) or {}
    text = data.get("text", "")

    # Reject non-string payloads as well as too-short text.
    if not isinstance(text, str) or len(text) < 20:
        return jsonify({"error": "Text too short (minimum 20 characters)"}), 400

    X = pipeline.transform([text])
    X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

    with torch.no_grad():
        prov_logits = net(X_t)

    prov_proba = torch.softmax(prov_logits.float(), dim=1)[0].cpu().numpy()

    # Top-5 providers by descending probability.
    top_prov_idxs = np.argsort(prov_proba)[::-1][:5]
    top_providers = [
        {
            "name": provider_enc.inverse_transform([i])[0],
            "confidence": float(prov_proba[i] * 100),
        }
        for i in top_prov_idxs
    ]

    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
97
+
98
+
99
@app.route("/api/correct", methods=["POST"])
def correct():
    """Fine-tune the model on a single corrected example.

    Expects a JSON body {"text": str, "correct_provider": str}. Performs
    one gradient step and stores the new weights in the in-memory
    checkpoint (persisted only via /api/save). Returns 400 for a missing
    body, missing fields, or an unknown provider label.
    """
    # Tolerate a missing/invalid JSON body instead of raising a 500.
    data = request.get_json(silent=True) or {}
    text = data.get("text", "")
    correct_provider = data.get("correct_provider", "")

    if not text or not correct_provider:
        return jsonify({"error": "Missing text or correct_provider"}), 400

    try:
        prov_idx = provider_enc.transform([correct_provider])[0]
    except ValueError as e:
        return jsonify({"error": f"Unknown provider: {e}"}), 400

    X = pipeline.transform([text])
    X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)
    y_prov = torch.tensor([prov_idx], dtype=torch.long).to(device)

    net.train()
    # BatchNorm cannot update running stats from a batch of one — keep it in eval.
    for module in net.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.eval()

    # NOTE(review): a fresh optimizer per request discards AdamW moment state;
    # acceptable for isolated single-step corrections.
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4)
    optimizer.zero_grad(set_to_none=True)

    prov_criterion = nn.CrossEntropyLoss()
    prov_logits = net(X_t)
    loss = prov_criterion(prov_logits, y_prov)
    loss.backward()
    # Clip so a single outlier example cannot destabilize the weights.
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    optimizer.step()

    net.eval()

    # Keep the in-memory checkpoint in sync with the updated weights.
    checkpoint["state_dict"] = net.state_dict()

    return jsonify({"success": True, "loss": float(loss.item())})
138
+
139
+
140
@app.route("/api/save", methods=["POST"])
def save_model():
    """Save the current model state to a file for export.

    The client-supplied filename is reduced to its base name so a request
    cannot write outside MODEL_DIR via path traversal (e.g. "../../x.pt").
    """
    global checkpoint
    data = request.get_json(silent=True) or {}
    requested = data.get("filename", "aifinder_model.pt")
    # Strip any directory components; fall back to the default on empty result.
    filename = os.path.basename(str(requested)) or "aifinder_model.pt"

    save_path = os.path.join(MODEL_DIR, filename)
    torch.save(checkpoint, save_path)

    return jsonify({"success": True, "filename": filename})
151
+
152
+
153
@app.route("/models/<filename>")
def download_model(filename):
    """Serve an exported model file from MODEL_DIR.

    send_from_directory validates the name, so traversal attempts are rejected.
    """
    export_dir = MODEL_DIR
    return send_from_directory(export_dir, filename)
157
+
158
+
159
@app.route("/api/status", methods=["GET"])
def status():
    """Report whether models are loaded and which device is in use."""
    payload = {
        "loaded": net is not None,
        "device": None if device is None else str(device),
    }
    return jsonify(payload)
168
+
169
+
170
if __name__ == "__main__":
    # Load all artifacts before accepting requests, then serve on all interfaces.
    print("Loading models...")
    load_models()
    print(f"Ready on {device}")
    app.run(host="0.0.0.0", port=7860)
classify.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Interactive Classifier
3
+ Loads trained model and provides an interactive REPL for classifying text.
4
+
5
+ Usage: python3 classify.py
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import time
11
+ import joblib
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from config import MODEL_DIR, DATASET_REGISTRY, DEEPSEEK_AM_DATASETS
17
+ from model import AIFinderNet
18
+
19
+
20
def load_models():
    """Load all model components from the model directory.

    Returns:
        Tuple (pipeline, net, provider_enc, checkpoint, device).

    Exits the process with status 1 if any artifact file is missing.
    """
    try:
        pipeline = joblib.load(os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
        provider_enc = joblib.load(os.path.join(MODEL_DIR, "provider_enc.joblib"))

        # weights_only=True restricts unpickling to tensors/primitives (safe load)
        checkpoint = torch.load(
            os.path.join(MODEL_DIR, "classifier.pt"),
            map_location="cpu",
            weights_only=True,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Architecture hyperparameters are stored alongside the weights
        net = AIFinderNet(
            input_dim=checkpoint["input_dim"],
            num_providers=checkpoint["num_providers"],
            hidden_dim=checkpoint["hidden_dim"],
            embed_dim=checkpoint["embed_dim"],
            dropout=checkpoint["dropout"],
        ).to(device)
        # strict=False tolerates key mismatches — NOTE(review): may silently skip weights
        net.load_state_dict(checkpoint["state_dict"], strict=False)
        net.eval()

        return pipeline, net, provider_enc, checkpoint, device
    except FileNotFoundError:
        print(f"Error: Models not found in {MODEL_DIR}")
        print(f"Run 'python3 train.py' first to train the models.")
        sys.exit(1)
47
+
48
+
49
def classify_text(text, pipeline, net, provider_enc, device):
    """Classify a single text and return provider results."""
    start = time.time()
    features = pipeline.transform([text])
    batch = torch.tensor(features.toarray(), dtype=torch.float32).to(device)
    print(f" (featurize: {time.time() - start:.2f}s)", end="")

    with torch.no_grad():
        logits = net(batch)

    probs = torch.softmax(logits.float(), dim=1)[0].cpu().numpy()

    # Five most probable providers, best first.
    ranked = np.argsort(probs)[::-1][:5]
    top_providers = [
        (provider_enc.inverse_transform([idx])[0], probs[idx] * 100)
        for idx in ranked
    ]

    elapsed = time.time() - start
    print(f" (total classify: {elapsed:.2f}s)")

    return {
        "provider": top_providers[0][0],
        "provider_confidence": top_providers[0][1],
        "top_providers": top_providers,
    }
76
+
77
+
78
def print_results(results):
    """Pretty-print classification results as a boxed list with bar charts."""
    print()
    print(" ┌───────────────────────────────────────────────┐")
    print(
        f" │ Provider: {results['provider']} ({results['provider_confidence']:.1f}%)"
    )
    for name, conf in results["top_providers"]:
        # NaN confidences render as an empty bar rather than crashing.
        c = conf if not np.isnan(conf) else 0.0
        filled = int(c / 5)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" │ {name:.<25s} {c:5.1f}% {bar}")

    print(" └───────────────────────────────────────────────┘")
    print()
92
+
93
+
94
def correct_provider(
    net,
    X_t,
    correct_provider_name,
    provider_enc,
    optimizer,
    device,
):
    """Do a backward pass to correct the provider on a single example.

    Args:
        net: AIFinderNet fine-tuned in place.
        X_t: feature tensor for the example(s), shape (batch, input_dim).
        correct_provider_name: ground-truth provider label from the user.
        provider_enc: LabelEncoder mapping provider names to class indices.
        optimizer: shared optimizer so moment state persists across corrections.
        device: torch device the tensors/model live on.

    Returns:
        True if a gradient step was taken, False if the label was unknown.
    """
    try:
        prov_idx = provider_enc.transform([correct_provider_name])[0]
    except ValueError as e:
        print(f" (label not in encoder: {e})")
        return False

    y_prov = torch.tensor([prov_idx], dtype=torch.long).to(device)

    # Remember the caller's mode so we can restore it afterwards.
    was_training = net.training
    net.train()

    # Disable batchnorm for single-sample training
    # (running stats cannot be updated from a batch of one).
    if X_t.shape[0] <= 1:
        for module in net.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

    optimizer.zero_grad(set_to_none=True)
    prov_criterion = nn.CrossEntropyLoss()

    prov_logits = net(X_t)
    loss = prov_criterion(prov_logits, y_prov)
    loss.backward()
    # Clip so one outlier correction cannot destabilize the weights.
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    optimizer.step()

    # Restore the caller's train/eval mode.
    if was_training:
        net.train()
    else:
        net.eval()

    print(f" ✓ Corrected → {correct_provider_name} (loss={loss.item():.4f})")
    return True
136
+
137
+
138
def prompt_correction(known_providers):
    """Ask user for the correct provider; return its name, or None to skip."""
    print(" Wrong? Enter correct provider number (or Enter to skip):")
    for i, name in enumerate(known_providers, 1):
        print(f" {i:>2d}. {name}")
    try:
        prov_choice = input(" Provider > ").strip()
    except EOFError:
        return None
    if not prov_choice:
        return None

    # Renamed from `correct_provider` to avoid shadowing the module-level function.
    chosen = None
    try:
        idx = int(prov_choice) - 1
    except ValueError:
        # Non-numeric input: case-insensitive substring match, accepted
        # only when it is unambiguous.
        matches = [m for m in known_providers if prov_choice.lower() in m.lower()]
        if len(matches) == 1:
            chosen = matches[0]
    else:
        if 0 <= idx < len(known_providers):
            chosen = known_providers[idx]

    if not chosen:
        print(" (invalid choice, skipping)")
        return None

    return chosen
165
+
166
+
167
def main():
    """Interactive REPL: classify pasted text and learn from user corrections."""
    print()
    print(" ╔═══════════════════════════════════════╗")
    print(" ║ AIFinder - AI Response Classifier ║")
    print(" ╚═══════════════════════════════════════╝")
    print()

    print(" Loading models...")
    t0 = time.time()
    pipeline, net, provider_enc, checkpoint, device = load_models()
    print(f" Models loaded in {time.time() - t0:.1f}s.")

    # Prepare online learning components.
    # One shared optimizer so AdamW moment state persists across corrections.
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4)
    known_providers = sorted(provider_enc.classes_.tolist())
    corrections_made = 0

    print()
    print(" Paste text to classify (submit with TWO empty lines).")
    print(" Type 'quit' to exit.\n")

    last_X_t = None

    while True:
        print(" ─── Paste text below ───")
        lines = []
        empty_count = 0
        # Collect pasted lines until two consecutive blank lines (or EOF).
        while True:
            try:
                line = input()
            except EOFError:
                break
            if line.strip() == "":
                empty_count += 1
                if empty_count >= 2:
                    break
                lines.append(line)
            else:
                empty_count = 0
                if line.strip().lower() == "quit":
                    # Persist accumulated corrections before exiting.
                    if corrections_made > 0:
                        print(
                            f" Saving {corrections_made} correction(s) to checkpoint..."
                        )
                        checkpoint["state_dict"] = net.state_dict()
                        torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
                        print(" ✓ Saved.")
                    print(" Goodbye!")
                    return
                lines.append(line)

        text = "\n".join(lines).strip()
        if not text:
            print(" (empty input, try again)")
            continue

        if len(text) < 20:
            print(" (text too short, need at least 20 chars)")
            continue

        results = classify_text(text, pipeline, net, provider_enc, device)
        print_results(results)

        # Re-featurize once so the correction step reuses the same tensor.
        X = pipeline.transform([text])
        last_X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

        correct_prov = prompt_correction(known_providers)
        if correct_prov:
            ok = correct_provider(
                net,
                last_X_t,
                correct_prov,
                provider_enc,
                optimizer,
                device,
            )
            if ok:
                corrections_made += 1
245
+
246
+
247
# Script entry point: start the interactive classification REPL.
if __name__ == "__main__":
    main()
config.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
AIFinder Configuration
Dataset registry, label mappings, and feature parameters.
"""

import os

# --- Paths ---
# All paths are resolved relative to this file so the package is relocatable.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")

# --- Dataset Registry ---
# Each entry: (hf_dataset_id, provider, model_name, optional_kwargs)
# optional_kwargs: subset name, split, etc.
# "max_samples" caps the number of rows drawn from large datasets.
DATASET_REGISTRY = [
    # Anthropic
    ("TeichAI/claude-4.5-opus-high-reasoning-250x", "Anthropic", "Claude 4.5 Opus", {}),
    ("TeichAI/claude-sonnet-4.5-high-reasoning-250x", "Anthropic", "Claude Sonnet 4.5", {}),
    ("Roman1111111/claude-opus-4.6-10000x", "Anthropic", "Claude Opus 4.6", {"max_samples": 1500}),

    # OpenAI
    ("TeichAI/gpt-5.2-high-reasoning-250x", "OpenAI", "GPT-5.2", {}),
    ("TeichAI/gpt-5.1-high-reasoning-1000x", "OpenAI", "GPT-5.1", {}),
    ("TeichAI/gpt-5.1-codex-max-1000x", "OpenAI", "GPT-5.1 Codex Max", {}),
    ("TeichAI/gpt-5-codex-250x", "OpenAI", "GPT-5 Codex", {}),
    ("TeichAI/gpt-5-codex-1000x", "OpenAI", "GPT-5 Codex", {}),

    # Google
    ("TeichAI/gemini-3-pro-preview-high-reasoning-1000x", "Google", "Gemini 3 Pro", {}),
    ("TeichAI/gemini-3-pro-preview-high-reasoning-250x", "Google", "Gemini 3 Pro", {}),
    ("TeichAI/gemini-2.5-flash-11000x", "Google", "Gemini 2.5 Flash", {"max_samples": 1500}),
    ("TeichAI/Gemini-3-Flash-Preview-VIBE", "Google", "Gemini 3 Flash", {}),
    ("TeichAI/gemini-3-flash-preview-1000x", "Google", "Gemini 3 Flash", {}),
    ("TeichAI/gemini-3-flash-preview-complex-1000x", "Google", "Gemini 3 Flash", {}),

    # xAI
    ("TeichAI/brainstorm-v3.1-grok-4-fast-200x", "xAI", "Grok 4 Fast", {}),
    ("TeichAI/sherlock-thinking-alpha-11000x", "xAI", "Grok 4.1 Fast", {"max_samples": 1500}),
    ("TeichAI/sherlock-dash-alpha-1000x", "xAI", "Grok 4.1 Fast", {}),
    ("TeichAI/sherlock-think-alpha-1000x", "xAI", "Grok 4.1 Fast", {}),
    ("TeichAI/grok-code-fast-1-1000x", "xAI", "Grok Code Fast 1", {}),

    # MoonshotAI
    ("TeichAI/kimi-k2-thinking-250x", "MoonshotAI", "Kimi K2", {}),
    ("TeichAI/kimi-k2-thinking-1000x", "MoonshotAI", "Kimi K2", {}),

    # Mistral
    ("TeichAI/mistral-small-creative-500x", "Mistral", "Mistral Small", {}),

    # MiniMax
    ("TeichAI/MiniMax-M2.1-Code-SFT", "MiniMax", "MiniMax M2.1", {}),
    ("TeichAI/convo-v1", "MiniMax", "MiniMax M2.1", {}),

    # StepFun
    ("TeichAI/Step-3.5-Flash-2600x", "StepFun", "Step 3.5 Flash", {"max_samples": 1500}),

    # Zhipu
    ("TeichAI/Pony-Alpha-15k", "Zhipu", "GLM-5", {"max_samples": 1500}),

    # DeepSeek (TeichAI)
    ("TeichAI/deepseek-v3.2-speciale-1000x", "DeepSeek", "DeepSeek V3.2 Speciale", {}),
    ("TeichAI/deepseek-v3.2-speciale-openr1-math-3k", "DeepSeek", "DeepSeek V3.2 Speciale", {"max_samples": 1500}),
]

# DeepSeek (a-m-team) — different format, handled separately
DEEPSEEK_AM_DATASETS = [
    ("a-m-team/AM-DeepSeek-R1-Distilled-1.4M", "DeepSeek", "DeepSeek R1", {"name": "am_0.9M_sample_1k", "max_samples": 1000}),
]

# --- All providers and models ---
PROVIDERS = [
    "Anthropic", "OpenAI", "Google", "xAI", "MoonshotAI",
    "Mistral", "MiniMax", "StepFun", "Zhipu", "DeepSeek"
]

# --- Feature parameters ---
# Passed verbatim to sklearn TfidfVectorizer (word-level n-grams).
TFIDF_WORD_PARAMS = {
    "analyzer": "word",
    "ngram_range": (1, 2),
    "max_features": 20000,
    "sublinear_tf": True,
    "min_df": 3,
}

# char_wb: character n-grams restricted to within word boundaries.
TFIDF_CHAR_PARAMS = {
    "analyzer": "char_wb",
    "ngram_range": (3, 5),
    "max_features": 20000,
    "sublinear_tf": True,
    "min_df": 3,
}

# --- Train/test split ---
TEST_SIZE = 0.2
RANDOM_STATE = 42

# --- Neural Network ---
HIDDEN_DIM = 1024
EMBED_DIM = 256
DROPOUT = 0.3
BATCH_SIZE = 2048
EPOCHS = 50
EARLY_STOP_PATIENCE = 8
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
+ WEIGHT_DECAY = 1e-4
data_loader.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Data Loader
3
+ Downloads and parses HuggingFace datasets, extracts assistant responses,
4
+ and labels them with is_ai, provider, and model.
5
+ """
6
+
7
import json
import re
import time

from datasets import load_dataset
from tqdm import tqdm

from config import (
    DATASET_REGISTRY,
    DEEPSEEK_AM_DATASETS,
)
16
+
17
+
18
+ def _parse_msg(msg):
19
+ """Parse a message that may be a dict or a JSON string."""
20
+ if isinstance(msg, dict):
21
+ return msg
22
+ if isinstance(msg, str):
23
+ try:
24
+ import json
25
+
26
+ parsed = json.loads(msg)
27
+ if isinstance(parsed, dict):
28
+ return parsed
29
+ except (json.JSONDecodeError, ValueError):
30
+ pass
31
+ return {}
32
+
33
+
34
def _extract_assistant_texts_from_conversations(rows):
    """Extract assistant message content from conversation datasets.
    These have a 'conversations' or 'messages' column with list of
    {role, content} dicts (or JSON strings encoding such dicts).
    """
    def _missing(value):
        # Treat None and zero-length containers as absent.
        return value is None or (hasattr(value, "__len__") and len(value) == 0)

    texts = []
    for row in rows:
        convos = row.get("conversations")
        if _missing(convos):
            convos = row.get("messages")
        if _missing(convos):
            convos = []
        parts = [
            parsed.get("content", "")
            for parsed in map(_parse_msg, convos)
            if parsed.get("role", "") in ("assistant", "gpt", "model")
            and parsed.get("content", "")
        ]
        if parts:
            texts.append("\n\n".join(parts))
    return texts
56
+
57
+
58
+ def _extract_from_am_dataset(row):
59
+ """Extract assistant text from a-m-team format (messages list with role/content)."""
60
+ messages = row.get("messages") or row.get("conversations") or []
61
+ parts = []
62
+ for msg in messages:
63
+ role = msg.get("role", "") if isinstance(msg, dict) else ""
64
+ content = msg.get("content", "") if isinstance(msg, dict) else ""
65
+ if role == "assistant" and content:
66
+ parts.append(content)
67
+ return "\n\n".join(parts) if parts else ""
68
+
69
+
70
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset and return (texts, providers, models).

    Args:
        dataset_id: HuggingFace dataset id.
        provider: provider label applied to every extracted sample.
        model_name: model label applied to every extracted sample.
        kwargs: optional settings — "name" (HF config name) and
            "max_samples" (seeded random subsample cap).

    Returns:
        Three parallel lists; all empty when the dataset cannot be loaded
        or yields no usable text.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
        rows = list(ds)
    except Exception as e:
        # Fallback: load from auto-converted parquet via HF API
        try:
            import pandas as pd

            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []

    if max_samples and len(rows) > max_samples:
        import random

        # Fixed seed so subsampling is reproducible across runs.
        random.seed(42)
        rows = random.sample(rows, max_samples)

    texts = _extract_assistant_texts_from_conversations(rows)

    # Filter out empty/too-short texts
    filtered = [(t, provider, model_name) for t in texts if len(t) > 50]
    if not filtered:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []

    t, p, m = zip(*filtered)
    return list(t), list(p), list(m)
108
+
109
+
110
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load a-m-team DeepSeek dataset.

    Falls back to streaming mode when the standard load fails; in both
    paths the row count is capped at kwargs["max_samples"] if given.

    Returns:
        Three parallel lists (texts, providers, models).
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]

    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
    except Exception as e1:
        # Try without name kwarg as fallback
        # NOTE(review): e1 is discarded; only the fallback error is reported.
        try:
            ds = load_dataset(dataset_id, split="train", streaming=True)
            rows = []
            for row in ds:
                rows.append(row)
                # Stop streaming as soon as the cap is reached.
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = list(ds)
        if max_samples and len(rows) > max_samples:
            rows = rows[:max_samples]

    texts = []
    for row in rows:
        text = _extract_from_am_dataset(row)
        # Drop empty/near-empty extractions.
        if len(text) > 50:
            texts.append(text)

    providers = [provider] * len(texts)
    models = [model_name] * len(texts)
    return texts, providers, models
145
+
146
+
147
def load_all_data():
    """Load all datasets and return combined lists.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human)
    """
    all_texts = []
    all_providers = []
    all_models = []

    # TeichAI datasets
    print("Loading TeichAI datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DATASET_REGISTRY, desc="TeichAI"
    ):
        t0 = time.time()
        texts, providers, models = load_teichai_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # DeepSeek a-m-team datasets
    print("\nLoading DeepSeek (a-m-team) datasets...")
    for dataset_id, provider, model_name, kwargs in tqdm(
        DEEPSEEK_AM_DATASETS, desc="DeepSeek-AM"
    ):
        t0 = time.time()
        texts, providers, models = load_am_deepseek_dataset(
            dataset_id, provider, model_name, kwargs
        )
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")

    # Build is_ai labels (all AI)
    # Every dataset here contains AI output, so every label is 1.
    is_ai = [1] * len(all_texts)

    print(f"\n=== Total: {len(all_texts)} samples ===")
    # Print per-provider counts
    from collections import Counter

    prov_counts = Counter(all_providers)
    for p, c in sorted(prov_counts.items(), key=lambda x: -x[1]):
        print(f" {p}: {c}")

    return all_texts, all_providers, all_models, is_ai
202
+
203
+
204
if __name__ == "__main__":
    # Smoke-run: download everything and print per-provider counts.
    texts, providers, models, is_ai = load_all_data()
features.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Feature Extraction
3
+ TF-IDF pipeline + stylometric features.
4
+ Supports CoT-aware and no-CoT text preprocessing.
5
+ """
6
+
7
+ import re
8
+ import numpy as np
9
+ from scipy.sparse import hstack, csr_matrix
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
+ from sklearn.preprocessing import MaxAbsScaler
12
+ from sklearn.base import BaseEstimator, TransformerMixin
13
+
14
+ from config import TFIDF_WORD_PARAMS, TFIDF_CHAR_PARAMS
15
+
16
+
17
+ # --- Text Preprocessing ---
18
+
19
def strip_cot(text):
    """Remove <think>...</think> blocks from text."""
    without_think = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    return without_think.strip()
22
+
23
+
24
def has_cot(text):
    """Check if text contains <think>...</think> blocks."""
    match = re.search(r"<think>.*?</think>", text, flags=re.DOTALL)
    return match is not None
27
+
28
+
29
def cot_ratio(text):
    """Ratio of thinking text to total text length."""
    blocks = re.findall(r"<think>(.*?)</think>", text, flags=re.DOTALL)
    # No think blocks (or empty input) means a ratio of zero.
    if not blocks or len(text) == 0:
        return 0.0
    total_think = sum(map(len, blocks))
    return total_think / len(text)
36
+
37
+
38
+ # --- Stylometric Features ---
39
+
40
class StylometricFeatures(BaseEstimator, TransformerMixin):
    """Extract stylometric features from text.

    Produces a fixed-order feature vector per document (see _extract's
    return). The column order is load-bearing: downstream models are
    trained against these positions, so it must never change.
    """

    def fit(self, X, y=None):
        # Stateless transformer — nothing to learn.
        return self

    def transform(self, X):
        """Transform an iterable of texts into a sparse float32 feature matrix."""
        features = []
        for text in X:
            features.append(self._extract(text))
        return csr_matrix(np.array(features, dtype=np.float32))

    def _extract(self, text):
        """Return the per-document feature list (fixed order — see return)."""
        sentences = re.split(r'[.!?]+', text)
        sentences = [s.strip() for s in sentences if s.strip()]
        words = text.split()

        # max(..., 1) guards every density below against division by zero.
        n_chars = max(len(text), 1)
        n_words = max(len(words), 1)
        n_sentences = max(len(sentences), 1)

        # Basic stats
        avg_word_len = np.mean([len(w) for w in words]) if words else 0
        avg_sent_len = n_words / n_sentences

        # Punctuation densities (normalized per character)
        n_commas = text.count(",") / n_chars
        n_semicolons = text.count(";") / n_chars
        n_colons = text.count(":") / n_chars
        n_exclaim = text.count("!") / n_chars
        n_question = text.count("?") / n_chars
        n_ellipsis = text.count("...") / n_chars
        n_dash = (text.count("—") + text.count("--")) / n_chars

        # Markdown elements (normalized per sentence)
        n_headers = len(re.findall(r'^#{1,6}\s', text, re.MULTILINE)) / n_sentences
        n_bold = len(re.findall(r'\*\*.*?\*\*', text)) / n_sentences
        n_italic = len(re.findall(r'(?<!\*)\*(?!\*).*?(?<!\*)\*(?!\*)', text)) / n_sentences
        n_code_blocks = len(re.findall(r'```', text)) / n_sentences
        n_inline_code = len(re.findall(r'`[^`]+`', text)) / n_sentences
        n_bullet = len(re.findall(r'^[\s]*[-*+]\s', text, re.MULTILINE)) / n_sentences
        n_numbered = len(re.findall(r'^\s*\d+[.)]\s', text, re.MULTILINE)) / n_sentences

        # Vocabulary richness
        unique_words = len(set(w.lower() for w in words))
        ttr = unique_words / n_words  # type-token ratio

        # Paragraph structure (blank-line-separated)
        paragraphs = text.split("\n\n")
        n_paragraphs = len([p for p in paragraphs if p.strip()])
        avg_para_len = n_words / max(n_paragraphs, 1)

        # Special patterns characteristic of assistant-style replies
        starts_with_certainly = 1.0 if re.match(r'^(Certainly|Of course|Sure|Absolutely|Great question)', text, re.IGNORECASE) else 0.0
        has_disclaimer = 1.0 if re.search(r"(I'm an AI|as an AI|language model|I cannot|I can't help)", text, re.IGNORECASE) else 0.0

        # CoT features (present even in no-CoT mode, just will be 0)
        has_think = 1.0 if has_cot(text) else 0.0
        think_ratio = cot_ratio(text)

        return [
            avg_word_len, avg_sent_len,
            n_commas, n_semicolons, n_colons, n_exclaim, n_question,
            n_ellipsis, n_dash,
            n_headers, n_bold, n_italic, n_code_blocks, n_inline_code,
            n_bullet, n_numbered,
            ttr, n_paragraphs, avg_para_len,
            starts_with_certainly, has_disclaimer,
            has_think, think_ratio,
            n_chars, n_words,
        ]
111
+
112
+
113
+ # --- Feature Pipeline ---
114
+
115
class FeaturePipeline:
    """Combined TF-IDF + stylometric feature pipeline.

    Feature layout is [word TF-IDF | char TF-IDF | stylometric], scaled
    with MaxAbsScaler (sparse-safe). transform() may only be called after
    fit_transform() so vocabularies and scaler statistics exist.
    """

    def __init__(self):
        self.word_tfidf = TfidfVectorizer(**TFIDF_WORD_PARAMS)
        self.char_tfidf = TfidfVectorizer(**TFIDF_CHAR_PARAMS)
        self.stylo = StylometricFeatures()
        self.scaler = MaxAbsScaler()

    def fit_transform(self, texts):
        """Fit and transform texts into feature matrix."""
        import time
        print(f" Input: {len(texts)} texts")

        # Strip <think> blocks for TF-IDF so n-grams learn style, not CoT
        texts_no_cot = [strip_cot(t) for t in texts]

        t0 = time.time()
        word_features = self.word_tfidf.fit_transform(texts_no_cot)
        print(f" word tfidf: {word_features.shape[1]} features ({time.time()-t0:.1f}s)")

        t0 = time.time()
        char_features = self.char_tfidf.fit_transform(texts_no_cot)
        print(f" char tfidf: {char_features.shape[1]} features ({time.time()-t0:.1f}s)")

        # Stylometric uses original text (has_think, think_ratio still work)
        t0 = time.time()
        stylo_features = self.stylo.fit_transform(texts)
        print(f" stylometric: {stylo_features.shape[1]} features ({time.time()-t0:.1f}s)")

        combined = hstack([word_features, char_features, stylo_features])
        combined = self.scaler.fit_transform(combined)
        print(f" Combined feature matrix: {combined.shape}")
        return combined

    def transform(self, texts):
        """Transform texts into feature matrix (after fitting).

        Column order must match fit_transform exactly.
        """
        texts_no_cot = [strip_cot(t) for t in texts]
        word_features = self.word_tfidf.transform(texts_no_cot)
        char_features = self.char_tfidf.transform(texts_no_cot)
        stylo_features = self.stylo.transform(texts)
        combined = hstack([word_features, char_features, stylo_features])
        return self.scaler.transform(combined)
model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Neural Network
3
+ Single-headed MLP: predicts provider only.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
class AIFinderNet(nn.Module):
    """Single-headed classifier: predicts provider only.

    Attribute names (backbone, provider_head) and the Sequential layer
    indices are kept stable for state_dict compatibility.
    """

    def __init__(
        self,
        input_dim,
        num_providers,
        hidden_dim=1024,
        embed_dim=256,
        dropout=0.3,
    ):
        super().__init__()
        # Two Linear -> BatchNorm -> ReLU -> Dropout stages compress the
        # sparse feature vector into a dense embedding.
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.backbone = nn.Sequential(*stages)
        self.provider_head = nn.Linear(embed_dim, num_providers)

    def forward(self, x):
        """Map features (batch, input_dim) to provider logits (batch, num_providers)."""
        embedding = self.backbone(x)
        return self.provider_head(embedding)
models/aifinder_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87706447863f7cae3a6295d06ecbfb35333b2f05f670d5b47133a76757b6377f
3
+ size 165033273
models/classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b0f36a1dd01e6375df5980017b582ff62469b59e5cc9d37b349fc5c48aa5734
3
+ size 165211381
models/feature_pipeline.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb73af66efcc5a3be022451e0e5ed6871d4df6cf0522e9f8d338f9079a57c267
3
+ size 2094058
models/model_enc.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d7a8a7bc3087ebbd1634bc9663e9f9fe4701c4872e3bcac7fe37671eaa93f79
3
+ size 1999
models/provider_enc.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7552c92dfc1d08686d9c6f360321e8e45df52e78f2b2eb450ccf117f29aaf62d
3
+ size 727
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets>=4.0
2
+ scikit-learn>=1.5
3
+ numpy>=1.26
4
+ scipy>=1.12
5
+ joblib>=1.3
6
+ tqdm>=4.60
7
+ torch>=2.0
8
+ gradio>=5.0
9
+ pandas>=2.0
10
+ huggingface_hub>=0.23.0
11
+ flask>=3.0
12
+ flask-cors>=4.0
static/index.html ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>AIFinder - Identify AI Responses</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600&family=Outfit:wght@300;400;500;600;700&display=swap');
9
+
10
+ * {
11
+ margin: 0;
12
+ padding: 0;
13
+ box-sizing: border-box;
14
+ }
15
+
16
+ :root {
17
+ --bg-primary: #0d0d0d;
18
+ --bg-secondary: #171717;
19
+ --bg-tertiary: #1f1f1f;
20
+ --bg-elevated: #262626;
21
+ --text-primary: #f5f5f5;
22
+ --text-secondary: #a3a3a3;
23
+ --text-muted: #737373;
24
+ --accent: #e85d04;
25
+ --accent-hover: #f48c06;
26
+ --accent-muted: #9c4300;
27
+ --success: #22c55e;
28
+ --success-muted: #166534;
29
+ --border: #333333;
30
+ --border-light: #404040;
31
+ }
32
+
33
+ body {
34
+ font-family: 'Outfit', -apple-system, sans-serif;
35
+ background: var(--bg-primary);
36
+ color: var(--text-primary);
37
+ min-height: 100vh;
38
+ line-height: 1.6;
39
+ }
40
+
41
+ .container {
42
+ max-width: 900px;
43
+ margin: 0 auto;
44
+ padding: 2rem 1.5rem;
45
+ }
46
+
47
+ header {
48
+ text-align: center;
49
+ margin-bottom: 3rem;
50
+ padding-top: 1rem;
51
+ }
52
+
53
+ .logo {
54
+ font-size: 2.5rem;
55
+ font-weight: 700;
56
+ letter-spacing: -0.05em;
57
+ margin-bottom: 0.5rem;
58
+ }
59
+
60
+ .logo span {
61
+ color: var(--accent);
62
+ }
63
+
64
+ .tagline {
65
+ color: var(--text-secondary);
66
+ font-size: 1rem;
67
+ font-weight: 300;
68
+ }
69
+
70
+ .card {
71
+ background: var(--bg-secondary);
72
+ border: 1px solid var(--border);
73
+ border-radius: 12px;
74
+ padding: 1.5rem;
75
+ margin-bottom: 1.5rem;
76
+ transition: border-color 0.2s ease;
77
+ }
78
+
79
+ .card:focus-within {
80
+ border-color: var(--border-light);
81
+ }
82
+
83
+ .card-label {
84
+ font-size: 0.75rem;
85
+ text-transform: uppercase;
86
+ letter-spacing: 0.1em;
87
+ color: var(--text-muted);
88
+ margin-bottom: 0.75rem;
89
+ font-weight: 500;
90
+ }
91
+
92
+ textarea {
93
+ width: 100%;
94
+ background: var(--bg-tertiary);
95
+ border: 1px solid var(--border);
96
+ border-radius: 8px;
97
+ padding: 1rem;
98
+ color: var(--text-primary);
99
+ font-family: 'JetBrains Mono', monospace;
100
+ font-size: 0.875rem;
101
+ resize: vertical;
102
+ min-height: 180px;
103
+ transition: border-color 0.2s ease;
104
+ }
105
+
106
+ textarea:focus {
107
+ outline: none;
108
+ border-color: var(--accent-muted);
109
+ }
110
+
111
+ textarea::placeholder {
112
+ color: var(--text-muted);
113
+ }
114
+
115
+ .btn {
116
+ display: inline-flex;
117
+ align-items: center;
118
+ justify-content: center;
119
+ gap: 0.5rem;
120
+ padding: 0.75rem 1.5rem;
121
+ border-radius: 8px;
122
+ font-family: 'Outfit', sans-serif;
123
+ font-size: 0.9rem;
124
+ font-weight: 500;
125
+ cursor: pointer;
126
+ transition: all 0.2s ease;
127
+ border: none;
128
+ }
129
+
130
+ .btn-primary {
131
+ background: var(--accent);
132
+ color: white;
133
+ }
134
+
135
+ .btn-primary:hover:not(:disabled) {
136
+ background: var(--accent-hover);
137
+ }
138
+
139
+ .btn-primary:disabled {
140
+ opacity: 0.5;
141
+ cursor: not-allowed;
142
+ }
143
+
144
+ .btn-secondary {
145
+ background: var(--bg-tertiary);
146
+ color: var(--text-primary);
147
+ border: 1px solid var(--border);
148
+ }
149
+
150
+ .btn-secondary:hover:not(:disabled) {
151
+ background: var(--bg-elevated);
152
+ border-color: var(--border-light);
153
+ }
154
+
155
+ .btn-group {
156
+ display: flex;
157
+ gap: 0.75rem;
158
+ flex-wrap: wrap;
159
+ }
160
+
161
+ .results {
162
+ display: none;
163
+ }
164
+
165
+ .results.visible {
166
+ display: block;
167
+ animation: fadeIn 0.3s ease;
168
+ }
169
+
170
+ @keyframes fadeIn {
171
+ from { opacity: 0; transform: translateY(10px); }
172
+ to { opacity: 1; transform: translateY(0); }
173
+ }
174
+
175
+ .result-main {
176
+ display: flex;
177
+ align-items: center;
178
+ justify-content: space-between;
179
+ padding: 1.25rem;
180
+ background: var(--bg-tertiary);
181
+ border-radius: 8px;
182
+ margin-bottom: 1rem;
183
+ }
184
+
185
+ .result-provider {
186
+ font-size: 1.5rem;
187
+ font-weight: 600;
188
+ }
189
+
190
+ .result-confidence {
191
+ font-size: 1.25rem;
192
+ font-weight: 500;
193
+ color: var(--accent);
194
+ }
195
+
196
+ .result-bar {
197
+ height: 8px;
198
+ background: var(--bg-elevated);
199
+ border-radius: 4px;
200
+ margin-bottom: 1rem;
201
+ overflow: hidden;
202
+ }
203
+
204
+ .result-bar-fill {
205
+ height: 100%;
206
+ background: var(--accent);
207
+ border-radius: 4px;
208
+ transition: width 0.5s ease;
209
+ }
210
+
211
+ .result-list {
212
+ list-style: none;
213
+ }
214
+
215
+ .result-item {
216
+ display: flex;
217
+ align-items: center;
218
+ justify-content: space-between;
219
+ padding: 0.75rem 0;
220
+ border-bottom: 1px solid var(--border);
221
+ }
222
+
223
+ .result-item:last-child {
224
+ border-bottom: none;
225
+ }
226
+
227
+ .result-name {
228
+ font-weight: 500;
229
+ }
230
+
231
+ .result-percent {
232
+ font-family: 'JetBrains Mono', monospace;
233
+ color: var(--text-secondary);
234
+ font-size: 0.875rem;
235
+ }
236
+
237
+ .correction {
238
+ display: none;
239
+ margin-top: 1.5rem;
240
+ padding-top: 1.5rem;
241
+ border-top: 1px solid var(--border);
242
+ }
243
+
244
+ .correction.visible {
245
+ display: block;
246
+ animation: fadeIn 0.3s ease;
247
+ }
248
+
249
+ .correction-title {
250
+ font-size: 0.875rem;
251
+ font-weight: 500;
252
+ margin-bottom: 0.75rem;
253
+ color: var(--text-secondary);
254
+ }
255
+
256
+ select {
257
+ width: 100%;
258
+ padding: 0.75rem 1rem;
259
+ background: var(--bg-tertiary);
260
+ border: 1px solid var(--border);
261
+ border-radius: 8px;
262
+ color: var(--text-primary);
263
+ font-family: 'Outfit', sans-serif;
264
+ font-size: 0.9rem;
265
+ margin-bottom: 0.75rem;
266
+ cursor: pointer;
267
+ }
268
+
269
+ select:focus {
270
+ outline: none;
271
+ border-color: var(--accent-muted);
272
+ }
273
+
274
+ .stats {
275
+ display: flex;
276
+ gap: 1.5rem;
277
+ margin-bottom: 1.5rem;
278
+ flex-wrap: wrap;
279
+ }
280
+
281
+ .stat {
282
+ background: var(--bg-secondary);
283
+ border: 1px solid var(--border);
284
+ border-radius: 8px;
285
+ padding: 1rem 1.25rem;
286
+ flex: 1;
287
+ min-width: 120px;
288
+ }
289
+
290
+ .stat-value {
291
+ font-size: 1.5rem;
292
+ font-weight: 600;
293
+ color: var(--accent);
294
+ }
295
+
296
+ .stat-label {
297
+ font-size: 0.75rem;
298
+ color: var(--text-muted);
299
+ text-transform: uppercase;
300
+ letter-spacing: 0.05em;
301
+ }
302
+
303
+ .actions {
304
+ display: flex;
305
+ gap: 0.75rem;
306
+ margin-top: 1rem;
307
+ }
308
+
309
+ .toast {
310
+ position: fixed;
311
+ bottom: 2rem;
312
+ right: 2rem;
313
+ background: var(--bg-elevated);
314
+ border: 1px solid var(--border);
315
+ border-radius: 8px;
316
+ padding: 1rem 1.5rem;
317
+ color: var(--text-primary);
318
+ font-size: 0.9rem;
319
+ opacity: 0;
320
+ transform: translateY(20px);
321
+ transition: all 0.3s ease;
322
+ z-index: 1000;
323
+ }
324
+
325
+ .toast.visible {
326
+ opacity: 1;
327
+ transform: translateY(0);
328
+ }
329
+
330
+ .toast.success {
331
+ border-color: var(--success-muted);
332
+ }
333
+
334
+ .footer {
335
+ text-align: center;
336
+ margin-top: 3rem;
337
+ padding: 1.5rem;
338
+ color: var(--text-muted);
339
+ font-size: 0.8rem;
340
+ }
341
+
342
+ .footer a {
343
+ color: var(--text-secondary);
344
+ text-decoration: none;
345
+ }
346
+
347
+ .footer a:hover {
348
+ color: var(--accent);
349
+ }
350
+
351
+ .loading {
352
+ display: inline-block;
353
+ width: 16px;
354
+ height: 16px;
355
+ border: 2px solid var(--text-muted);
356
+ border-top-color: var(--accent);
357
+ border-radius: 50%;
358
+ animation: spin 0.8s linear infinite;
359
+ }
360
+
361
+ @keyframes spin {
362
+ to { transform: rotate(360deg); }
363
+ }
364
+
365
+ .status-indicator {
366
+ display: inline-flex;
367
+ align-items: center;
368
+ gap: 0.5rem;
369
+ font-size: 0.8rem;
370
+ color: var(--text-muted);
371
+ margin-bottom: 1rem;
372
+ }
373
+
374
+ .status-dot {
375
+ width: 8px;
376
+ height: 8px;
377
+ border-radius: 50%;
378
+ background: var(--success);
379
+ }
380
+
381
+ .status-dot.loading {
382
+ background: var(--accent);
383
+ animation: pulse 1s ease infinite;
384
+ }
385
+
386
+ @keyframes pulse {
387
+ 0%, 100% { opacity: 1; }
388
+ 50% { opacity: 0.5; }
389
+ }
390
+
391
+ .empty-state {
392
+ text-align: center;
393
+ padding: 3rem 1rem;
394
+ color: var(--text-muted);
395
+ }
396
+
397
+ .empty-state-icon {
398
+ font-size: 3rem;
399
+ margin-bottom: 1rem;
400
+ opacity: 0.5;
401
+ }
402
+
403
+ @media (max-width: 600px) {
404
+ .container {
405
+ padding: 1rem;
406
+ }
407
+
408
+ .logo {
409
+ font-size: 2rem;
410
+ }
411
+
412
+ .btn-group {
413
+ flex-direction: column;
414
+ }
415
+
416
+ .btn {
417
+ width: 100%;
418
+ }
419
+
420
+ .result-main {
421
+ flex-direction: column;
422
+ gap: 0.5rem;
423
+ text-align: center;
424
+ }
425
+ }
426
+ </style>
427
+ </head>
428
+ <body>
429
+ <div class="container">
430
+ <header>
431
+ <div class="logo">AI<span>Finder</span></div>
432
+ <p class="tagline">Identify which AI provider generated a response</p>
433
+ </header>
434
+
435
+ <div class="status-indicator">
436
+ <span class="status-dot" id="statusDot"></span>
437
+ <span id="statusText">Connecting to API...</span>
438
+ </div>
439
+
440
+ <div class="card">
441
+ <div class="card-label">Paste AI Response</div>
442
+ <textarea id="inputText" placeholder="Paste an AI response here to identify which provider generated it..."></textarea>
443
+ </div>
444
+
445
+ <div class="btn-group">
446
+ <button class="btn btn-primary" id="classifyBtn" disabled>
447
+ <span id="classifyBtnText">Classify</span>
448
+ </button>
449
+ <button class="btn btn-secondary" id="clearBtn">Clear</button>
450
+ </div>
451
+
452
+ <div class="results" id="results">
453
+ <div class="card">
454
+ <div class="card-label">Result</div>
455
+ <div class="result-main">
456
+ <span class="result-provider" id="resultProvider">-</span>
457
+ <span class="result-confidence" id="resultConfidence">-</span>
458
+ </div>
459
+ <div class="result-bar">
460
+ <div class="result-bar-fill" id="resultBar" style="width: 0%"></div>
461
+ </div>
462
+ <ul class="result-list" id="resultList"></ul>
463
+ </div>
464
+
465
+ <div class="correction" id="correction">
466
+ <div class="correction-title">Wrong? Correct the provider to train the model:</div>
467
+ <select id="providerSelect"></select>
468
+ <button class="btn btn-primary" id="trainBtn">Train & Save</button>
469
+ </div>
470
+ </div>
471
+
472
+ <div class="stats" id="stats" style="display: none;">
473
+ <div class="stat">
474
+ <div class="stat-value" id="correctionsCount">0</div>
475
+ <div class="stat-label">Corrections</div>
476
+ </div>
477
+ <div class="stat">
478
+ <div class="stat-value" id="sessionCount">0</div>
479
+ <div class="stat-label">Session</div>
480
+ </div>
481
+ </div>
482
+
483
+ <div class="actions" id="actions" style="display: none;">
484
+ <button class="btn btn-secondary" id="exportBtn">Export Trained Model</button>
485
+ <button class="btn btn-secondary" id="resetBtn">Reset Training</button>
486
+ </div>
487
+
488
+ <div class="footer">
489
+ <p>AIFinder &mdash; Train on corrections to improve accuracy</p>
490
+ <p style="margin-top: 0.5rem;">
491
+                Want to contribute? Try it out and post your results to the
492
+                <a href="https://huggingface.co/spaces" target="_blank">HuggingFace Spaces Community</a>
493
+                if you want it merged!
494
+ </p>
495
+ </div>
496
+ </div>
497
+
498
+ <div class="toast" id="toast"></div>
499
+
500
+ <script>
501
// API base URL: talk to the dev server on port 7860 when served from
// localhost; otherwise use same-origin relative paths.
const API_BASE = window.location.hostname === 'localhost' || window.location.hostname === '127.0.0.1'
    ? 'http://localhost:7860'
    : '';

// Mutable client-side state.
let providers = [];
let correctionsCount = 0;
let sessionCorrections = 0;

// Cached element handles, one per interactive node in the page.
const byId = (id) => document.getElementById(id);
const inputText = byId('inputText');
const classifyBtn = byId('classifyBtn');
const classifyBtnText = byId('classifyBtnText');
const clearBtn = byId('clearBtn');
const results = byId('results');
const resultProvider = byId('resultProvider');
const resultConfidence = byId('resultConfidence');
const resultBar = byId('resultBar');
const resultList = byId('resultList');
const correction = byId('correction');
const providerSelect = byId('providerSelect');
const trainBtn = byId('trainBtn');
const stats = byId('stats');
const correctionsCountEl = byId('correctionsCount');
const sessionCountEl = byId('sessionCount');
const actions = byId('actions');
const exportBtn = byId('exportBtn');
const resetBtn = byId('resetBtn');
const toast = byId('toast');
const statusDot = byId('statusDot');
const statusText = byId('statusText');
530
+
531
// Flash a transient notification in the toast element for three seconds.
// A `type` of 'success' adds the success accent class; anything else is neutral.
function showToast(message, type = 'info') {
    const classes = ['toast', 'visible'];
    if (type === 'success') {
        classes.push('success');
    }
    toast.textContent = message;
    toast.className = classes.join(' ');
    setTimeout(() => toast.classList.remove('visible'), 3000);
}
538
+
539
// Poll the backend until the models report loaded, then enable the UI.
// Unreachable API => retry every 2s; reachable but still loading => every 1s.
async function checkStatus() {
    try {
        const response = await fetch(`${API_BASE}/api/status`);
        const status = await response.json();
        if (!status.loaded) {
            // Models still warming up — poll again shortly.
            setTimeout(checkStatus, 1000);
            return;
        }
        statusDot.classList.remove('loading');
        statusText.textContent = `Ready (${status.device})`;
        classifyBtn.disabled = false;
        loadProviders();
        loadStats();
    } catch (err) {
        // API unreachable — show connecting state and retry at a slower cadence.
        statusDot.classList.add('loading');
        statusText.textContent = 'Connecting to API...';
        setTimeout(checkStatus, 2000);
    }
}
558
+
559
// Fetch the list of known providers and populate the correction dropdown.
// Options are built via the Option constructor instead of innerHTML string
// interpolation, so a provider name can never be parsed as markup
// (closes a latent HTML-injection vector in the original code).
async function loadProviders() {
    const res = await fetch(`${API_BASE}/api/providers`);
    const data = await res.json();
    providers = data.providers;

    providerSelect.innerHTML = '';
    for (const p of providers) {
        // new Option(text, value) sets both the label and the submitted value.
        providerSelect.appendChild(new Option(p, p));
    }
}
568
+
569
// Restore the persisted correction counter from localStorage and, when any
// corrections exist, reveal the stats panel and the export/reset actions.
function loadStats() {
    const stored = localStorage.getItem('aifinder_corrections');
    if (stored !== null && stored !== '') {
        correctionsCount = parseInt(stored, 10);
        correctionsCountEl.textContent = correctionsCount;
        stats.style.display = 'flex';
        actions.style.display = 'flex';
    }
    sessionCountEl.textContent = sessionCorrections;
}
579
+
580
// Persist the all-time correction counter across page reloads.
function saveStats() {
    localStorage.setItem('aifinder_corrections', String(correctionsCount));
}
583
+
584
// Send the textarea contents to the classify endpoint and render the result.
// Rejects inputs under 20 characters; shows a spinner while in flight.
async function classify() {
    const text = inputText.value.trim();
    if (text.length < 20) {
        showToast('Text must be at least 20 characters');
        return;
    }

    classifyBtn.disabled = true;
    classifyBtnText.innerHTML = '<span class="loading"></span>';

    try {
        const request = {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ text })
        };
        const res = await fetch(`${API_BASE}/api/classify`, request);
        if (!res.ok) {
            throw new Error('Classification failed');
        }
        showResults(await res.json());
    } catch (err) {
        showToast('Error: ' + err.message);
    } finally {
        // Always restore the button, success or failure.
        classifyBtn.disabled = false;
        classifyBtnText.textContent = 'Classify';
    }
}
614
+
615
// Render a classification response: top provider, confidence bar, ranked
// list, and the correction form. The ranked list is built from DOM nodes
// with textContent instead of innerHTML interpolation so provider names can
// never be interpreted as HTML (closes a latent injection vector).
function showResults(data) {
    resultProvider.textContent = data.provider;
    resultConfidence.textContent = data.confidence.toFixed(1) + '%';
    resultBar.style.width = data.confidence + '%';

    resultList.innerHTML = '';
    for (const p of data.top_providers) {
        const item = document.createElement('li');
        item.className = 'result-item';

        const name = document.createElement('span');
        name.className = 'result-name';
        name.textContent = p.name;

        const percent = document.createElement('span');
        percent.className = 'result-percent';
        percent.textContent = p.confidence.toFixed(1) + '%';

        item.append(name, percent);
        resultList.appendChild(item);
    }

    // Pre-select the prediction so "correcting" defaults to confirming it.
    providerSelect.value = data.provider;

    results.classList.add('visible');
    correction.classList.add('visible');

    if (correctionsCount > 0 || sessionCorrections > 0) {
        stats.style.display = 'flex';
        actions.style.display = 'flex';
    }
}
637
+
638
// Submit the current text with a user-supplied correct provider label to the
// online-training endpoint, bump counters, and re-classify to show the
// updated prediction.
async function train() {
    const text = inputText.value.trim();
    const correctProvider = providerSelect.value;

    trainBtn.disabled = true;
    trainBtn.innerHTML = '<span class="loading"></span>';

    try {
        const request = {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ text, correct_provider: correctProvider })
        };
        const res = await fetch(`${API_BASE}/api/correct`, request);
        if (!res.ok) {
            throw new Error('Training failed');
        }

        const data = await res.json();
        correctionsCount += 1;
        sessionCorrections += 1;
        saveStats();
        correctionsCountEl.textContent = correctionsCount;
        sessionCountEl.textContent = sessionCorrections;

        showToast(`Trained! Loss: ${data.loss.toFixed(4)}`, 'success');

        stats.style.display = 'flex';
        actions.style.display = 'flex';

        // Re-run classification so the UI reflects the just-updated model.
        classify();
    } catch (err) {
        showToast('Error: ' + err.message);
    } finally {
        trainBtn.disabled = false;
        trainBtn.textContent = 'Train & Save';
    }
}
676
+
677
// Ask the backend to snapshot the trained model to disk, then trigger a
// browser download of the saved checkpoint via a synthetic anchor click.
async function exportModel() {
    exportBtn.disabled = true;
    exportBtn.innerHTML = '<span class="loading"></span>';

    try {
        const request = {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ filename: 'aifinder_trained.pt' })
        };
        const res = await fetch(`${API_BASE}/api/save`, request);
        if (!res.ok) {
            throw new Error('Save failed');
        }

        const data = await res.json();
        const anchor = document.createElement('a');
        anchor.href = `${API_BASE}/models/${data.filename}`;
        anchor.download = data.filename;
        anchor.click();

        showToast('Model exported!', 'success');
    } catch (err) {
        showToast('Error: ' + err.message);
    } finally {
        exportBtn.disabled = false;
        exportBtn.textContent = 'Export Trained Model';
    }
}
707
+
708
// Clear all correction counters (in memory and localStorage) after an
// explicit user confirmation, and hide the stats/actions panels.
function resetTraining() {
    const confirmed = confirm('Reset all training data? This cannot be undone.');
    if (!confirmed) {
        return;
    }

    correctionsCount = 0;
    sessionCorrections = 0;
    localStorage.removeItem('aifinder_corrections');
    correctionsCountEl.textContent = '0';
    sessionCountEl.textContent = '0';
    stats.style.display = 'none';
    actions.style.display = 'none';
    showToast('Training data reset');
}
722
+
723
// Wire up UI events and start the app.
classifyBtn.addEventListener('click', classify);
clearBtn.addEventListener('click', function () {
    inputText.value = '';
    results.classList.remove('visible');
    correction.classList.remove('visible');
});
trainBtn.addEventListener('click', train);
exportBtn.addEventListener('click', exportModel);
resetBtn.addEventListener('click', resetTraining);

// Ctrl+Enter inside the textarea is a shortcut for Classify.
inputText.addEventListener('keydown', function (e) {
    if (e.ctrlKey && e.key === 'Enter') {
        classify();
    }
});

// Begin polling the backend for readiness.
checkStatus();
740
+ </script>
741
+ </body>
742
+ </html>