CompactAI committed on
Commit
0051294
·
verified ·
1 Parent(s): 838785b

Upload 15 files

Browse files
Files changed (6) hide show
  1. app.py +78 -0
  2. example_api.py +89 -0
  3. models/aifinder_trained.pt +2 -2
  4. requirements.txt +1 -0
  5. static/index.html +626 -43
  6. train.py +305 -0
app.py CHANGED
@@ -1,9 +1,14 @@
1
  """
2
  AIFinder API Server
3
  Serves classification and training endpoints for the frontend.
 
 
 
 
4
  """
5
 
6
  import os
 
7
  import sys
8
  import json
9
  import joblib
@@ -12,6 +17,8 @@ import torch
12
  import torch.nn as nn
13
  from flask import Flask, request, jsonify, send_from_directory
14
  from flask_cors import CORS
 
 
15
 
16
  from config import MODEL_DIR
17
  from model import AIFinderNet
@@ -19,6 +26,9 @@ from features import FeaturePipeline
19
 
20
  app = Flask(__name__, static_folder="static", static_url_path="")
21
  CORS(app)
 
 
 
22
 
23
  pipeline = None
24
  provider_enc = None
@@ -96,6 +106,74 @@ def classify():
96
  )
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  @app.route("/api/correct", methods=["POST"])
100
  def correct():
101
  """Train on a corrected example."""
 
1
  """
2
  AIFinder API Server
3
  Serves classification and training endpoints for the frontend.
4
+
5
+ Public API:
6
+ POST /v1/classify — classify text, returns top-N provider predictions.
7
+ No API key required. Rate-limited to 60 requests/minute per IP.
8
  """
9
 
10
  import os
11
+ import re
12
  import sys
13
  import json
14
  import joblib
 
17
  import torch.nn as nn
18
  from flask import Flask, request, jsonify, send_from_directory
19
  from flask_cors import CORS
20
+ from flask_limiter import Limiter
21
+ from flask_limiter.util import get_remote_address
22
 
23
  from config import MODEL_DIR
24
  from model import AIFinderNet
 
26
 
27
  app = Flask(__name__, static_folder="static", static_url_path="")
28
  CORS(app)
29
+ limiter = Limiter(get_remote_address, app=app, default_limits=[])
30
+
31
+ DEFAULT_TOP_N = 5
32
 
33
  pipeline = None
34
  provider_enc = None
 
106
  )
107
 
108
 
109
+ def _strip_think_tags(text):
110
+ """Remove <think>…</think> (and <thinking>…</thinking>) blocks from input."""
111
+ text = re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>", "", text, flags=re.DOTALL)
112
+ return text.strip()
113
+
114
+
115
@app.route("/v1/classify", methods=["POST"])
@limiter.limit("60/minute")
def v1_classify():
    """Public API — classify text and return top-N provider predictions.

    Request JSON:
        text (str): The text to classify. Any <think>/<thinking> tags will be
            stripped automatically before classification.
        top_n (int): Number of results to return (default: 5).

    Response JSON:
        provider (str): Best-matching provider name.
        confidence (float): Confidence % for the top provider.
        top_providers (list): List of {name, confidence} dicts.

    Rate limit: 60 requests per minute per IP. No API key required.
    """
    data = request.get_json(silent=True)
    if not data or "text" not in data:
        return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400

    raw_text = data["text"]
    # Fix: a non-string 'text' (e.g. a number or a list) previously crashed
    # inside re.sub with a TypeError, surfacing as an opaque HTTP 500.
    # Reject it up front as a client error instead.
    if not isinstance(raw_text, str):
        return jsonify({"error": "'text' must be a string."}), 400

    text = _strip_think_tags(raw_text)
    top_n = data.get("top_n", DEFAULT_TOP_N)

    # Fix: bool is a subclass of int, so top_n=true used to pass the int check.
    # Anything that is not a positive (non-bool) integer falls back to the default.
    if isinstance(top_n, bool) or not isinstance(top_n, int) or top_n < 1:
        top_n = DEFAULT_TOP_N

    if len(text) < 20:
        return jsonify({"error": "Text too short (minimum 20 characters after stripping think tags)."}), 400

    # Vectorize and run a single forward pass (inference only — no gradients).
    X = pipeline.transform([text])
    X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

    with torch.no_grad():
        prov_logits = net(X_t)

    prov_proba = torch.softmax(prov_logits.float(), dim=1)[0].cpu().numpy()

    # Rank providers by descending probability and keep the best top_n.
    top_idxs = np.argsort(prov_proba)[::-1][:top_n]
    top_providers = [
        {
            "name": provider_enc.inverse_transform([i])[0],
            "confidence": round(float(prov_proba[i] * 100), 2),
        }
        for i in top_idxs
    ]

    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
175
+
176
+
177
  @app.route("/api/correct", methods=["POST"])
178
  def correct():
179
  """Train on a corrected example."""
example_api.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example: Call the AIFinder public API to classify text.
4
+
5
+ Usage:
6
+ python example_api.py
7
+
8
+ Requirements:
9
+ pip install requests
10
+
11
+ IMPORTANT — Strip <think>/<thinking> tags!
12
+ Many reasoning models wrap chain-of-thought in <think>…</think> or
13
+ <thinking>…</thinking> blocks. These tags confuse the classifier because
14
+ they are NOT part of the model's actual output style. The API strips them
15
+ automatically, but you should also strip them on your side to avoid sending
16
+ unnecessary data.
17
+
18
+ API details:
19
+ POST /v1/classify
20
+ No API key required. Rate limit: 60 requests/minute per IP.
21
+ """
22
+
23
+ import re
24
+ import json
25
+ import requests
26
+
27
+ # Change this to your server URL (local or HuggingFace Space).
28
+ API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify"
29
+
30
+ # Number of top results to return (default on server is 5).
31
+ TOP_N = 5
32
+
33
+ EXAMPLE_TEXT = """\
34
+ I'd be happy to help you understand how neural networks work!
35
+
36
+ Neural networks are computational models inspired by the human brain. They consist of layers of interconnected nodes (neurons) that process information. Here's a breakdown:
37
+
38
+ 1. **Input Layer**: Receives the raw data
39
+ 2. **Hidden Layers**: Process and transform the data through weighted connections
40
+ 3. **Output Layer**: Produces the final prediction
41
+
42
+ Each connection has a weight, and each neuron has a bias. During training, the network adjusts these weights using backpropagation to minimize the difference between predicted and actual outputs.
43
+
44
+ The key insight is that by stacking multiple layers, neural networks can learn increasingly abstract representations of data, enabling them to solve complex tasks like image recognition, natural language processing, and more.
45
+
46
+ Would you like me to dive deeper into any specific aspect?
47
+ """
48
+
49
+
50
def strip_think_tags(text: str) -> str:
    """Strip <think>…</think> / <thinking>…</thinking> blocks, then trim whitespace."""
    tag_pattern = re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL)
    cleaned = tag_pattern.sub("", text)
    return cleaned.strip()
55
+
56
+
57
def classify(text: str, top_n: int = TOP_N) -> dict:
    """POST *text* to the AIFinder API and return the decoded JSON result.

    Raises requests.HTTPError on a non-2xx response.
    """
    payload = {"text": strip_think_tags(text), "top_n": top_n}
    response = requests.post(API_URL, json=payload, timeout=30)
    response.raise_for_status()
    return response.json()
67
+
68
+
69
def main():
    """Demo: classify EXAMPLE_TEXT via the public API and print a ranked bar chart."""
    print("AIFinder API Example")
    print("=" * 50)
    print(f"Endpoint : {API_URL}")
    print(f"Top-N : {TOP_N}")
    print()

    result = classify(EXAMPLE_TEXT)

    print(f"Best match : {result['provider']} ({result['confidence']:.1f}%)")
    print()
    print("Top providers:")
    for entry in result["top_providers"]:
        # One bar segment per 5% of confidence, padded to 20 cells total.
        filled = int(entry["confidence"] / 5)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {entry['name']:.<25s} {entry['confidence']:5.1f}% {bar}")
86
+
87
+
88
+ if __name__ == "__main__":
89
+ main()
models/aifinder_trained.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:87706447863f7cae3a6295d06ecbfb35333b2f05f670d5b47133a76757b6377f
3
- size 165033273
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daeb1c62c7ca34fd45b697fb0240608555aa1617e45baa8d81620a69272bbaaf
3
+ size 165033081
requirements.txt CHANGED
@@ -10,3 +10,4 @@ pandas>=2.0
10
  huggingface_hub>=0.23.0
11
  flask>=3.0
12
  flask-cors>=4.0
 
 
10
  huggingface_hub>=0.23.0
11
  flask>=3.0
12
  flask-cors>=4.0
13
+ flask-limiter>=3.0
static/index.html CHANGED
@@ -399,7 +399,260 @@
399
  margin-bottom: 1rem;
400
  opacity: 0.5;
401
  }
402
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  @media (max-width: 600px) {
404
  .container {
405
  padding: 1rem;
@@ -432,59 +685,322 @@
432
  <p class="tagline">Identify which AI provider generated a response</p>
433
  </header>
434
 
435
- <div class="status-indicator">
436
- <span class="status-dot" id="statusDot"></span>
437
- <span id="statusText">Connecting to API...</span>
438
  </div>
439
-
440
- <div class="card">
441
- <div class="card-label">Paste AI Response</div>
442
- <textarea id="inputText" placeholder="Paste an AI response here to identify which provider generated it..."></textarea>
443
- </div>
444
-
445
- <div class="btn-group">
446
- <button class="btn btn-primary" id="classifyBtn" disabled>
447
- <span id="classifyBtnText">Classify</span>
448
- </button>
449
- <button class="btn btn-secondary" id="clearBtn">Clear</button>
450
- </div>
451
-
452
- <div class="results" id="results">
453
  <div class="card">
454
- <div class="card-label">Result</div>
455
- <div class="result-main">
456
- <span class="result-provider" id="resultProvider">-</span>
457
- <span class="result-confidence" id="resultConfidence">-</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  </div>
459
- <div class="result-bar">
460
- <div class="result-bar-fill" id="resultBar" style="width: 0%"></div>
 
461
  </div>
462
- <ul class="result-list" id="resultList"></ul>
463
  </div>
464
 
465
- <div class="correction" id="correction">
466
- <div class="correction-title">Wrong? Correct the provider to train the model:</div>
467
- <select id="providerSelect"></select>
468
- <button class="btn btn-primary" id="trainBtn">Train & Save</button>
469
  </div>
470
  </div>
471
-
472
- <div class="stats" id="stats" style="display: none;">
473
- <div class="stat">
474
- <div class="stat-value" id="correctionsCount">0</div>
475
- <div class="stat-label">Corrections</div>
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  </div>
477
- <div class="stat">
478
- <div class="stat-value" id="sessionCount">0</div>
479
- <div class="stat-label">Session</div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  </div>
481
  </div>
482
-
483
- <div class="actions" id="actions" style="display: none;">
484
- <button class="btn btn-secondary" id="exportBtn">Export Trained Model</button>
485
- <button class="btn btn-secondary" id="resetBtn">Reset Training</button>
486
- </div>
487
-
488
  <div class="footer">
489
  <p>AIFinder &mdash; Train on corrections to improve accuracy</p>
490
  <p style="margin-top: 0.5rem;">
@@ -736,6 +1252,73 @@
736
  }
737
  });
738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739
  checkStatus();
740
  </script>
741
  </body>
 
399
  margin-bottom: 1rem;
400
  opacity: 0.5;
401
  }
402
+
403
+ /* ── Tabs ── */
404
+ .tabs {
405
+ display: flex;
406
+ gap: 0;
407
+ margin-bottom: 2rem;
408
+ border-bottom: 1px solid var(--border);
409
+ }
410
+
411
+ .tab {
412
+ padding: 0.75rem 1.5rem;
413
+ font-family: 'Outfit', sans-serif;
414
+ font-size: 0.9rem;
415
+ font-weight: 500;
416
+ color: var(--text-muted);
417
+ background: none;
418
+ border: none;
419
+ border-bottom: 2px solid transparent;
420
+ cursor: pointer;
421
+ transition: all 0.2s ease;
422
+ }
423
+
424
+ .tab:hover {
425
+ color: var(--text-secondary);
426
+ }
427
+
428
+ .tab.active {
429
+ color: var(--accent);
430
+ border-bottom-color: var(--accent);
431
+ }
432
+
433
+ .tab-content {
434
+ display: none;
435
+ }
436
+
437
+ .tab-content.active {
438
+ display: block;
439
+ animation: fadeIn 0.3s ease;
440
+ }
441
+
442
+ /* ── API Docs ── */
443
+ .docs-section {
444
+ margin-bottom: 2rem;
445
+ }
446
+
447
+ .docs-section h2 {
448
+ font-size: 1.25rem;
449
+ font-weight: 600;
450
+ margin-bottom: 0.75rem;
451
+ color: var(--text-primary);
452
+ }
453
+
454
+ .docs-section h3 {
455
+ font-size: 1rem;
456
+ font-weight: 500;
457
+ margin-top: 1.25rem;
458
+ margin-bottom: 0.5rem;
459
+ color: var(--text-secondary);
460
+ }
461
+
462
+ .docs-section p {
463
+ color: var(--text-secondary);
464
+ font-size: 0.9rem;
465
+ margin-bottom: 0.75rem;
466
+ line-height: 1.7;
467
+ }
468
+
469
+ .docs-endpoint {
470
+ display: inline-flex;
471
+ align-items: center;
472
+ gap: 0.5rem;
473
+ background: var(--bg-tertiary);
474
+ border: 1px solid var(--border);
475
+ border-radius: 6px;
476
+ padding: 0.5rem 1rem;
477
+ margin-bottom: 1rem;
478
+ font-family: 'JetBrains Mono', monospace;
479
+ font-size: 0.85rem;
480
+ }
481
+
482
+ .docs-method {
483
+ color: var(--success);
484
+ font-weight: 600;
485
+ }
486
+
487
+ .docs-path {
488
+ color: var(--text-primary);
489
+ }
490
+
491
+ .docs-badge {
492
+ display: inline-block;
493
+ font-size: 0.7rem;
494
+ font-weight: 600;
495
+ text-transform: uppercase;
496
+ letter-spacing: 0.05em;
497
+ padding: 0.2rem 0.6rem;
498
+ border-radius: 4px;
499
+ margin-left: 0.5rem;
500
+ }
501
+
502
+ .docs-badge.free {
503
+ background: var(--success-muted);
504
+ color: var(--success);
505
+ }
506
+
507
+ .docs-badge.limit {
508
+ background: var(--accent-muted);
509
+ color: var(--accent-hover);
510
+ }
511
+
512
+ .docs-code-block {
513
+ position: relative;
514
+ background: var(--bg-tertiary);
515
+ border: 1px solid var(--border);
516
+ border-radius: 8px;
517
+ margin-bottom: 1rem;
518
+ overflow: hidden;
519
+ }
520
+
521
+ .docs-code-header {
522
+ display: flex;
523
+ align-items: center;
524
+ justify-content: space-between;
525
+ padding: 0.5rem 1rem;
526
+ background: var(--bg-elevated);
527
+ border-bottom: 1px solid var(--border);
528
+ font-size: 0.75rem;
529
+ color: var(--text-muted);
530
+ text-transform: uppercase;
531
+ letter-spacing: 0.05em;
532
+ }
533
+
534
+ .docs-copy-btn {
535
+ background: none;
536
+ border: 1px solid var(--border);
537
+ border-radius: 4px;
538
+ color: var(--text-muted);
539
+ font-size: 0.7rem;
540
+ padding: 0.2rem 0.5rem;
541
+ cursor: pointer;
542
+ font-family: 'Outfit', sans-serif;
543
+ transition: all 0.2s ease;
544
+ }
545
+
546
+ .docs-copy-btn:hover {
547
+ color: var(--text-primary);
548
+ border-color: var(--border-light);
549
+ }
550
+
551
+ .docs-code-block pre {
552
+ padding: 1rem;
553
+ overflow-x: auto;
554
+ font-family: 'JetBrains Mono', monospace;
555
+ font-size: 0.8rem;
556
+ line-height: 1.6;
557
+ color: var(--text-primary);
558
+ margin: 0;
559
+ }
560
+
561
+ .docs-table {
562
+ width: 100%;
563
+ border-collapse: collapse;
564
+ font-size: 0.85rem;
565
+ margin-bottom: 1rem;
566
+ }
567
+
568
+ .docs-table th {
569
+ text-align: left;
570
+ padding: 0.6rem 0.75rem;
571
+ background: var(--bg-elevated);
572
+ color: var(--text-secondary);
573
+ font-weight: 500;
574
+ border-bottom: 1px solid var(--border);
575
+ font-size: 0.75rem;
576
+ text-transform: uppercase;
577
+ letter-spacing: 0.05em;
578
+ }
579
+
580
+ .docs-table td {
581
+ padding: 0.6rem 0.75rem;
582
+ border-bottom: 1px solid var(--border);
583
+ color: var(--text-secondary);
584
+ }
585
+
586
+ .docs-table tr:last-child td {
587
+ border-bottom: none;
588
+ }
589
+
590
+ .docs-table code {
591
+ font-family: 'JetBrains Mono', monospace;
592
+ font-size: 0.8rem;
593
+ background: var(--bg-tertiary);
594
+ padding: 0.15rem 0.4rem;
595
+ border-radius: 3px;
596
+ color: var(--accent-hover);
597
+ }
598
+
599
+ .docs-warning {
600
+ background: rgba(232, 93, 4, 0.08);
601
+ border: 1px solid var(--accent-muted);
602
+ border-radius: 8px;
603
+ padding: 1rem 1.25rem;
604
+ margin-bottom: 1rem;
605
+ font-size: 0.85rem;
606
+ color: var(--text-secondary);
607
+ line-height: 1.7;
608
+ }
609
+
610
+ .docs-warning strong {
611
+ color: var(--accent-hover);
612
+ }
613
+
614
+ .docs-inline-code {
615
+ font-family: 'JetBrains Mono', monospace;
616
+ font-size: 0.8rem;
617
+ background: var(--bg-tertiary);
618
+ padding: 0.15rem 0.4rem;
619
+ border-radius: 3px;
620
+ color: var(--accent-hover);
621
+ }
622
+
623
+ .docs-try-it {
624
+ background: var(--bg-tertiary);
625
+ border: 1px solid var(--border);
626
+ border-radius: 8px;
627
+ padding: 1.25rem;
628
+ margin-top: 1rem;
629
+ }
630
+
631
+ .docs-try-it textarea {
632
+ min-height: 100px;
633
+ margin-bottom: 0.75rem;
634
+ }
635
+
636
+ .docs-try-output {
637
+ background: var(--bg-primary);
638
+ border: 1px solid var(--border);
639
+ border-radius: 6px;
640
+ padding: 1rem;
641
+ font-family: 'JetBrains Mono', monospace;
642
+ font-size: 0.8rem;
643
+ color: var(--text-secondary);
644
+ white-space: pre-wrap;
645
+ word-break: break-word;
646
+ max-height: 300px;
647
+ overflow-y: auto;
648
+ display: none;
649
+ }
650
+
651
+ .docs-try-output.visible {
652
+ display: block;
653
+ animation: fadeIn 0.3s ease;
654
+ }
655
+
656
  @media (max-width: 600px) {
657
  .container {
658
  padding: 1rem;
 
685
  <p class="tagline">Identify which AI provider generated a response</p>
686
  </header>
687
 
688
+ <div class="tabs">
689
+ <button class="tab active" data-tab="classify">Classify</button>
690
+ <button class="tab" data-tab="docs">API Docs</button>
691
  </div>
692
+
693
+ <!-- ═══ Classify Tab ═══ -->
694
+ <div class="tab-content active" id="tab-classify">
695
+ <div class="status-indicator">
696
+ <span class="status-dot" id="statusDot"></span>
697
+ <span id="statusText">Connecting to API...</span>
698
+ </div>
699
+
 
 
 
 
 
 
700
  <div class="card">
701
+ <div class="card-label">Paste AI Response</div>
702
+ <textarea id="inputText" placeholder="Paste an AI response here to identify which provider generated it..."></textarea>
703
+ </div>
704
+
705
+ <div class="btn-group">
706
+ <button class="btn btn-primary" id="classifyBtn" disabled>
707
+ <span id="classifyBtnText">Classify</span>
708
+ </button>
709
+ <button class="btn btn-secondary" id="clearBtn">Clear</button>
710
+ </div>
711
+
712
+ <div class="results" id="results">
713
+ <div class="card">
714
+ <div class="card-label">Result</div>
715
+ <div class="result-main">
716
+ <span class="result-provider" id="resultProvider">-</span>
717
+ <span class="result-confidence" id="resultConfidence">-</span>
718
+ </div>
719
+ <div class="result-bar">
720
+ <div class="result-bar-fill" id="resultBar" style="width: 0%"></div>
721
+ </div>
722
+ <ul class="result-list" id="resultList"></ul>
723
+ </div>
724
+
725
+ <div class="correction" id="correction">
726
+ <div class="correction-title">Wrong? Correct the provider to train the model:</div>
727
+ <select id="providerSelect"></select>
728
+ <button class="btn btn-primary" id="trainBtn">Train & Save</button>
729
+ </div>
730
+ </div>
731
+
732
+ <div class="stats" id="stats" style="display: none;">
733
+ <div class="stat">
734
+ <div class="stat-value" id="correctionsCount">0</div>
735
+ <div class="stat-label">Corrections</div>
736
  </div>
737
+ <div class="stat">
738
+ <div class="stat-value" id="sessionCount">0</div>
739
+ <div class="stat-label">Session</div>
740
  </div>
 
741
  </div>
742
 
743
+ <div class="actions" id="actions" style="display: none;">
744
+ <button class="btn btn-secondary" id="exportBtn">Export Trained Model</button>
745
+ <button class="btn btn-secondary" id="resetBtn">Reset Training</button>
 
746
  </div>
747
  </div>
748
+
749
+ <!-- ═══ API Docs Tab ═══ -->
750
+ <div class="tab-content" id="tab-docs">
751
+
752
+ <div class="docs-section">
753
+ <h2>Public Classification API</h2>
754
+ <p>
755
+ AIFinder exposes a free, public endpoint for programmatic classification.
756
+ No API key required.
757
+ </p>
758
+ <div>
759
+ <div class="docs-endpoint">
760
+ <span class="docs-method">POST</span>
761
+ <span class="docs-path">/v1/classify</span>
762
+ </div>
763
+ <span class="docs-badge free">No API Key</span>
764
+ <span class="docs-badge limit">60 req/min</span>
765
+ </div>
766
  </div>
767
+
768
+ <!-- ── Request ── -->
769
+ <div class="docs-section">
770
+ <h2>Request</h2>
771
+ <p>Send a JSON body with <span class="docs-inline-code">Content-Type: application/json</span>.</p>
772
+
773
+ <table class="docs-table">
774
+ <thead>
775
+ <tr><th>Field</th><th>Type</th><th>Required</th><th>Description</th></tr>
776
+ </thead>
777
+ <tbody>
778
+ <tr>
779
+ <td><code>text</code></td>
780
+ <td>string</td>
781
+ <td>Yes</td>
782
+ <td>The AI-generated text to classify (min 20 chars)</td>
783
+ </tr>
784
+ <tr>
785
+ <td><code>top_n</code></td>
786
+ <td>integer</td>
787
+ <td>No</td>
788
+ <td>Number of results to return (default: <strong>5</strong>)</td>
789
+ </tr>
790
+ </tbody>
791
+ </table>
792
+
793
+ <div class="docs-warning">
794
+ <strong>⚠️ Strip thought tags!</strong><br>
795
+ Many reasoning models wrap chain-of-thought in
796
+ <span class="docs-inline-code">&lt;think&gt;…&lt;/think&gt;</span> or
797
+ <span class="docs-inline-code">&lt;thinking&gt;…&lt;/thinking&gt;</span> blocks.
798
+ These confuse the classifier. The API strips them automatically, but you should
799
+ remove them on your side too to save bandwidth.
800
+ </div>
801
+ </div>
802
+
803
+ <!-- ── Response ── -->
804
+ <div class="docs-section">
805
+ <h2>Response</h2>
806
+ <div class="docs-code-block">
807
+ <div class="docs-code-header">
808
+ <span>JSON</span>
809
+ <button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
810
+ </div>
811
+ <pre>{
812
+ "provider": "Anthropic",
813
+ "confidence": 87.42,
814
+ "top_providers": [
815
+ { "name": "Anthropic", "confidence": 87.42 },
816
+ { "name": "OpenAI", "confidence": 6.15 },
817
+ { "name": "Google", "confidence": 3.28 },
818
+ { "name": "xAI", "confidence": 1.74 },
819
+ { "name": "DeepSeek", "confidence": 0.89 }
820
+ ]
821
+ }</pre>
822
+ </div>
823
+
824
+ <table class="docs-table">
825
+ <thead>
826
+ <tr><th>Field</th><th>Type</th><th>Description</th></tr>
827
+ </thead>
828
+ <tbody>
829
+ <tr>
830
+ <td><code>provider</code></td>
831
+ <td>string</td>
832
+ <td>Best-matching provider name</td>
833
+ </tr>
834
+ <tr>
835
+ <td><code>confidence</code></td>
836
+ <td>float</td>
837
+ <td>Confidence % for the top provider</td>
838
+ </tr>
839
+ <tr>
840
+ <td><code>top_providers</code></td>
841
+ <td>array</td>
842
+ <td>Ranked list of <code>{ name, confidence }</code> objects</td>
843
+ </tr>
844
+ </tbody>
845
+ </table>
846
+ </div>
847
+
848
+ <!-- ── Errors ── -->
849
+ <div class="docs-section">
850
+ <h2>Errors</h2>
851
+ <table class="docs-table">
852
+ <thead>
853
+ <tr><th>Status</th><th>Meaning</th></tr>
854
+ </thead>
855
+ <tbody>
856
+ <tr><td><code>400</code></td><td>Missing <code>text</code> field or text shorter than 20 characters</td></tr>
857
+ <tr><td><code>429</code></td><td>Rate limit exceeded (60 requests/minute per IP)</td></tr>
858
+ </tbody>
859
+ </table>
860
+ </div>
861
+
862
+ <!-- ── Code Examples ── -->
863
+ <div class="docs-section">
864
+ <h2>Code Examples</h2>
865
+
866
+ <h3>cURL</h3>
867
+ <div class="docs-code-block">
868
+ <div class="docs-code-header">
869
+ <span>Bash</span>
870
+ <button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
871
+ </div>
872
+ <pre>curl -X POST https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify \
873
+ -H "Content-Type: application/json" \
874
+ -d '{
875
+ "text": "I would be happy to help you with that! Here is a detailed explanation of how neural networks work...",
876
+ "top_n": 5
877
+ }'</pre>
878
+ </div>
879
+
880
+ <h3>Python</h3>
881
+ <div class="docs-code-block">
882
+ <div class="docs-code-header">
883
+ <span>Python</span>
884
+ <button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
885
+ </div>
886
+ <pre>import re
887
+ import requests
888
+
889
+ API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify"
890
+
891
+ def strip_think_tags(text):
892
+ """Remove &lt;think&gt;/&lt;thinking&gt; blocks before classifying."""
893
+ return re.sub(r"&lt;think(?:ing)?&gt;.*?&lt;/think(?:ing)?&gt;",
894
+ "", text, flags=re.DOTALL).strip()
895
+
896
+ text = """I'd be happy to help! Neural networks are
897
+ computational models inspired by the human brain..."""
898
+
899
+ # Strip thought tags first (the API does this too,
900
+ # but saves bandwidth to do it client-side)
901
+ cleaned = strip_think_tags(text)
902
+
903
+ response = requests.post(API_URL, json={
904
+ "text": cleaned,
905
+ "top_n": 5
906
+ })
907
+
908
+ data = response.json()
909
+ print(f"Provider: {data['provider']} ({data['confidence']:.1f}%)")
910
+ for p in data["top_providers"]:
911
+ print(f" {p['name']:&lt;20s} {p['confidence']:5.1f}%")</pre>
912
+ </div>
913
+
914
+ <h3>JavaScript (fetch)</h3>
915
+ <div class="docs-code-block">
916
+ <div class="docs-code-header">
917
+ <span>JavaScript</span>
918
+ <button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
919
+ </div>
920
+ <pre>const API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify";
921
+
922
+ function stripThinkTags(text) {
923
+ return text.replace(/&lt;think(?:ing)?&gt;[\s\S]*?&lt;\/think(?:ing)?&gt;/g, "").trim();
924
+ }
925
+
926
+ async function classify(text, topN = 5) {
927
+ const cleaned = stripThinkTags(text);
928
+ const res = await fetch(API_URL, {
929
+ method: "POST",
930
+ headers: { "Content-Type": "application/json" },
931
+ body: JSON.stringify({ text: cleaned, top_n: topN })
932
+ });
933
+ return res.json();
934
+ }
935
+
936
+ // Usage
937
+ classify("I'd be happy to help you understand...")
938
+ .then(data =&gt; {
939
+ console.log(`Provider: ${data.provider} (${data.confidence}%)`);
940
+ data.top_providers.forEach(p =&gt;
941
+ console.log(` ${p.name}: ${p.confidence}%`)
942
+ );
943
+ });</pre>
944
+ </div>
945
+
946
+ <h3>Node.js</h3>
947
+ <div class="docs-code-block">
948
+ <div class="docs-code-header">
949
+ <span>JavaScript (Node)</span>
950
+ <button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
951
+ </div>
952
+ <pre>const API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify";
953
+
954
+ async function classify(text, topN = 5) {
955
+ const cleaned = text
956
+ .replace(/&lt;think(?:ing)?&gt;[\s\S]*?&lt;\/think(?:ing)?&gt;/g, "")
957
+ .trim();
958
+
959
+ const res = await fetch(API_URL, {
960
+ method: "POST",
961
+ headers: { "Content-Type": "application/json" },
962
+ body: JSON.stringify({ text: cleaned, top_n: topN })
963
+ });
964
+
965
+ if (!res.ok) {
966
+ const err = await res.json();
967
+ throw new Error(err.error || `HTTP ${res.status}`);
968
+ }
969
+ return res.json();
970
+ }
971
+
972
+ // Example
973
+ (async () =&gt; {
974
+ const result = await classify(
975
+ "Let me think about this step by step...",
976
+ 3
977
+ );
978
+ console.log(result);
979
+ })();</pre>
980
+ </div>
981
+ </div>
982
+
983
+ <!-- ── Try It ── -->
984
+ <div class="docs-section">
985
+ <h2>Try It</h2>
986
+ <p>Test the API right here — paste any AI-generated text and hit Send.</p>
987
+ <div class="docs-try-it">
988
+ <textarea id="docsTestInput" placeholder="Paste AI-generated text here..."></textarea>
989
+ <div class="btn-group">
990
+ <button class="btn btn-primary" id="docsTestBtn">Send Request</button>
991
+ </div>
992
+ <div class="docs-try-output" id="docsTestOutput"></div>
993
+ </div>
994
+ </div>
995
+
996
+ <!-- ── Providers ── -->
997
+ <div class="docs-section">
998
+ <h2>Supported Providers</h2>
999
+ <p>The classifier currently supports these providers:</p>
1000
+ <div id="docsProviderList" style="display: flex; flex-wrap: wrap; gap: 0.5rem; margin-top: 0.5rem;"></div>
1001
  </div>
1002
  </div>
1003
+
 
 
 
 
 
1004
  <div class="footer">
1005
  <p>AIFinder &mdash; Train on corrections to improve accuracy</p>
1006
  <p style="margin-top: 0.5rem;">
 
1252
  }
1253
  });
1254
 
1255
// ── Tab switching ──
// Clicking a tab deactivates all tabs/panels, then activates the clicked
// pair (button + its matching #tab-<name> content panel).
document.querySelectorAll('.tab').forEach(tabBtn => {
  tabBtn.addEventListener('click', () => {
    for (const t of document.querySelectorAll('.tab')) t.classList.remove('active');
    for (const c of document.querySelectorAll('.tab-content')) c.classList.remove('active');
    tabBtn.classList.add('active');
    document.getElementById('tab-' + tabBtn.dataset.tab).classList.add('active');
  });
});
1264
+
1265
+ // ── Copy button for code blocks ──
1266
// Copy the sibling <pre> contents to the clipboard and flash "Copied!" on the button.
function copyCode(btn) {
  const codeText = btn.closest('.docs-code-block').querySelector('pre').textContent;
  navigator.clipboard.writeText(codeText).then(() => {
    btn.textContent = 'Copied!';
    setTimeout(() => (btn.textContent = 'Copy'), 1500);
  });
}
1273
+
1274
+ // ── Docs: populate provider badges ──
1275
// Render one inline-code badge per supported provider into the docs tab.
// No-ops until the provider list has loaded (assumes outer `providers` array).
function populateDocsProviders() {
  const container = document.getElementById('docsProviderList');
  if (!container || !providers.length) return;
  const badges = providers.map(p =>
    `<span class="docs-inline-code" style="padding:0.3rem 0.75rem;">${p}</span>`
  );
  container.innerHTML = badges.join('');
}
1282
+
1283
+ // ── Docs: "Try It" live tester ──
1284
+ const docsTestBtn = document.getElementById('docsTestBtn');
1285
+ const docsTestInput = document.getElementById('docsTestInput');
1286
+ const docsTestOutput = document.getElementById('docsTestOutput');
1287
+
1288
+ if (docsTestBtn) {
1289
+ docsTestBtn.addEventListener('click', async () => {
1290
+ const text = docsTestInput.value.trim();
1291
+ if (text.length < 20) {
1292
+ docsTestOutput.textContent = '{"error": "Text too short (minimum 20 characters)"}';
1293
+ docsTestOutput.classList.add('visible');
1294
+ return;
1295
+ }
1296
+ docsTestBtn.disabled = true;
1297
+ docsTestBtn.innerHTML = '<span class="loading"></span>';
1298
+ try {
1299
+ const res = await fetch(`${API_BASE}/v1/classify`, {
1300
+ method: 'POST',
1301
+ headers: { 'Content-Type': 'application/json' },
1302
+ body: JSON.stringify({ text, top_n: 5 })
1303
+ });
1304
+ const data = await res.json();
1305
+ docsTestOutput.textContent = JSON.stringify(data, null, 2);
1306
+ } catch (e) {
1307
+ docsTestOutput.textContent = `{"error": "${e.message}"}`;
1308
+ }
1309
+ docsTestOutput.classList.add('visible');
1310
+ docsTestBtn.disabled = false;
1311
+ docsTestBtn.textContent = 'Send Request';
1312
+ });
1313
+ }
1314
+
1315
+ // Hook provider list population into the existing load flow
1316
+ const _origLoadProviders = loadProviders;
1317
+ loadProviders = async function() {
1318
+ await _origLoadProviders();
1319
+ populateDocsProviders();
1320
+ };
1321
+
1322
  checkStatus();
1323
  </script>
1324
  </body>
train.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AIFinder Training Script
3
+ Loads data, trains a two-headed GPU classifier, reports metrics, and saves the model.
4
+
5
+ Usage: python3 train.py
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import time
11
+ import joblib
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.utils.data import TensorDataset, DataLoader
16
+ from sklearn.model_selection import train_test_split
17
+ from sklearn.metrics import classification_report
18
+ from sklearn.preprocessing import LabelEncoder
19
+ from sklearn.utils.class_weight import compute_class_weight
20
+
21
+ from config import (
22
+ MODEL_DIR,
23
+ TEST_SIZE,
24
+ RANDOM_STATE,
25
+ HIDDEN_DIM,
26
+ EMBED_DIM,
27
+ DROPOUT,
28
+ BATCH_SIZE,
29
+ EPOCHS,
30
+ LEARNING_RATE,
31
+ WEIGHT_DECAY,
32
+ EARLY_STOP_PATIENCE,
33
+ )
34
+ from data_loader import load_all_data
35
+ from features import FeaturePipeline
36
+ from model import AIFinderNet
37
+
38
+
39
+ def _log(msg, t0=None):
40
+ """Print a timestamped log message, optionally with elapsed time."""
41
+ ts = time.strftime("%H:%M:%S")
42
+ if t0 is not None:
43
+ elapsed = time.time() - t0
44
+ print(f" [{ts}] {msg} ({elapsed:.1f}s)")
45
+ else:
46
+ print(f" [{ts}] {msg}")
47
+
48
+
49
def main() -> None:
    """Run the complete provider-classification training workflow.

    Steps: device detection -> data load -> label encoding -> stratified
    train/test split -> feature extraction (fit on train only) ->
    class-weighted training with OneCycleLR and optional CUDA mixed
    precision -> early stopping on validation loss -> classification
    report -> artifact export to ``MODEL_DIR`` (classifier.pt,
    feature_pipeline.joblib, provider_enc.joblib).

    Exits with status 1 if fewer than 100 samples are loaded.
    """
    t_start = time.time()

    print("=" * 60)
    print("AIFinder Training - Provider Classification")
    print("=" * 60)

    # ── GPU check ──────────────────────────────────────────────
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        _log(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    else:
        device = torch.device("cpu")
        _log("No GPU available, using CPU")

    # ── Load data ──────────────────────────────────────────────
    _log("Starting data load...")
    t0 = time.time()
    # models and _is_ai are unused here: this script trains only the
    # provider classification head.
    texts, providers, models, _is_ai = load_all_data()
    _log("Data load complete", t0)

    if len(texts) < 100:
        print("ERROR: Not enough data loaded. Check dataset access.")
        sys.exit(1)

    # ── Encode labels ──────────────────────────────────────────
    _log("Encoding labels...")
    t0 = time.time()
    provider_enc = LabelEncoder()
    provider_labels = provider_enc.fit_transform(providers)
    num_providers = len(provider_enc.classes_)
    _log(f"Labels encoded — {num_providers} providers", t0)

    # ── Train/test split ───────────────────────────────────────
    _log("Splitting train/test...")
    t0 = time.time()
    # Split on indices (not texts) so both the text lists and the label
    # array can be sliced with the same index sets.
    indices = np.arange(len(texts))
    train_idx, test_idx = train_test_split(
        indices,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=provider_labels,  # keep per-provider class balance in both splits
    )
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]
    _log(f"Split: {len(train_texts)} train / {len(test_texts)} test", t0)

    # ── Build features ─────────────────────────────────────────
    # Fit the pipeline on the train split only to avoid test-set leakage.
    _log("Building feature pipeline (fit on train)...")
    t0 = time.time()
    pipeline = FeaturePipeline()
    X_train = pipeline.fit_transform(train_texts)
    _log(f"Train features: {X_train.shape}", t0)

    _log("Transforming test set...")
    t0 = time.time()
    X_test = pipeline.transform(test_texts)
    _log(f"Test features: {X_test.shape}", t0)

    input_dim = X_train.shape[1]

    # ── Move to device ─────────────────────────────────────────
    _log(f"Moving data to {device}...")
    t0 = time.time()
    # .toarray() densifies the feature matrices (presumably scipy sparse —
    # TODO confirm against FeaturePipeline); the full dataset then stays
    # resident on the device for the whole run.
    X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32).to(device)
    X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32).to(device)
    y_prov_train = torch.tensor(provider_labels[train_idx], dtype=torch.long).to(device)
    y_prov_test = torch.tensor(provider_labels[test_idx], dtype=torch.long).to(device)
    if device.type == "cuda":
        mem_used = torch.cuda.memory_allocated() / 1024**3
        _log(f"GPU memory used: {mem_used:.2f} GB", t0)
    else:
        _log(f"Data on {device}", t0)

    # ── DataLoaders ────────────────────────────────────────────
    # Cap the batch size on CPU to keep per-step latency reasonable.
    batch_size = min(BATCH_SIZE, 512) if device.type == "cpu" else BATCH_SIZE
    train_ds = TensorDataset(X_train_t, y_prov_train)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = TensorDataset(X_test_t, y_prov_test)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # ── Model ──────────────────────────────────────────────────
    _log("Building model...")
    net = AIFinderNet(
        input_dim=input_dim,
        num_providers=num_providers,
        hidden_dim=HIDDEN_DIM,
        embed_dim=EMBED_DIM,
        dropout=DROPOUT,
    ).to(device)
    n_params = sum(p.numel() for p in net.parameters())
    _log(f"Model: {n_params:,} parameters")

    # ── Class-weighted loss ────────────────────────────────────
    # "balanced" weights counteract provider-frequency imbalance so rare
    # providers still contribute meaningfully to the loss.
    prov_weights = compute_class_weight(
        "balanced", classes=np.arange(num_providers), y=provider_labels[train_idx]
    )
    prov_criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(prov_weights, dtype=torch.float32).to(device)
    )

    # ── Optimizer + scheduler ──────────────────────────────────
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    # OneCycleLR is stepped once per batch below, so steps_per_epoch must
    # equal the train loader length.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
    )
    # Automatic mixed precision only applies on CUDA.
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler() if use_amp else None

    # ── Training loop ──────────────────────────────────────────
    _log(
        f"Training for {EPOCHS} epochs, batch_size={batch_size}, "
        f"early_stop_patience={EARLY_STOP_PATIENCE}..."
    )
    t0 = time.time()

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    for epoch in range(EPOCHS):
        # ── Train phase ───────────────────────────────────────
        net.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, batch_prov in train_loader:
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast(device_type="cuda"):
                    prov_logits = net(batch_X)
                    loss = prov_criterion(prov_logits, batch_prov)
                scaler.scale(loss).backward()
                # Unscale before clipping so the max_norm threshold is
                # applied to the true (unscaled) gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()
            # NOTE(review): per-batch scheduler step for both AMP and CPU
            # paths — matches steps_per_epoch=len(train_loader) above.
            scheduler.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_train_loss = epoch_loss / n_batches

        # ── Validation phase ──────────────────────────────────
        net.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_X, batch_prov in val_loader:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches

        # ── Early stopping check ──────────────────────────────
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Clone tensors so the snapshot is not mutated by later
            # optimizer steps (state_dict values are live references).
            best_state = {k: v.clone() for k, v in net.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        # ── Logging ───────────────────────────────────────────
        # Log the first epoch and then every fifth epoch.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            marker = " *" if patience_counter == 0 else ""  # "*" marks a new best
            _log(
                f"Epoch {epoch + 1:>3d}/{EPOCHS} "
                f"train={avg_train_loss:.4f} "
                f"val={avg_val_loss:.4f} "
                f"lr={lr:.2e}{marker}"
            )

        if patience_counter >= EARLY_STOP_PATIENCE:
            _log(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val_loss={best_val_loss:.4f})"
            )
            break

    # Restore best weights
    if best_state is not None:
        net.load_state_dict(best_state)
        _log(f"Restored best weights (val_loss={best_val_loss:.4f})")

    _log("Training complete", t0)

    # ── Evaluate ───────────────────────────────────────────────
    # NOTE(review): evaluation forwards the entire test set in one pass;
    # assumes it fits in device memory (it already resides there).
    _log("Evaluating...")
    net.eval()
    with torch.no_grad():
        prov_logits = net(X_test_t)

    prov_preds = prov_logits.argmax(dim=1).cpu().numpy()
    prov_true = y_prov_test.cpu().numpy()

    print("\n === Provider Classification ===")
    print(
        classification_report(
            prov_true,
            prov_preds,
            target_names=provider_enc.classes_,
            zero_division=0,
        )
    )

    # ── Save ───────────────────────────────────────────────────
    _log(f"Saving to {MODEL_DIR}/ ...")
    t0 = time.time()
    os.makedirs(MODEL_DIR, exist_ok=True)

    # The checkpoint carries the constructor hyperparameters alongside the
    # weights so inference code can rebuild an identical AIFinderNet.
    checkpoint = {
        "input_dim": input_dim,
        "num_providers": num_providers,
        "hidden_dim": HIDDEN_DIM,
        "embed_dim": EMBED_DIM,
        "dropout": DROPOUT,
        "state_dict": net.state_dict(),
    }
    torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
    _log(" Saved classifier.pt")

    # The fitted pipeline and label encoder must be saved with the model:
    # inference needs the identical vocabulary and class ordering.
    joblib.dump(pipeline, os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
    _log(" Saved feature_pipeline.joblib")
    joblib.dump(provider_enc, os.path.join(MODEL_DIR, "provider_enc.joblib"))
    _log(" Saved provider_enc.joblib")

    _log("All artifacts saved", t0)

    elapsed = time.time() - t_start
    if device.type == "cuda":
        mem_peak = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s (peak GPU mem: {mem_peak:.2f} GB)")
        print(f"{'=' * 60}")
    else:
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s")
        print(f"{'=' * 60}")
302
+
303
+
304
# Script entry point: run training only when executed directly,
# never on import.
if __name__ == "__main__":
    main()