Upload 15 files
Browse files- app.py +78 -0
- example_api.py +89 -0
- models/aifinder_trained.pt +2 -2
- requirements.txt +1 -0
- static/index.html +626 -43
- train.py +305 -0
app.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
| 1 |
"""
|
| 2 |
AIFinder API Server
|
| 3 |
Serves classification and training endpoints for the frontend.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
|
|
|
| 7 |
import sys
|
| 8 |
import json
|
| 9 |
import joblib
|
|
@@ -12,6 +17,8 @@ import torch
|
|
| 12 |
import torch.nn as nn
|
| 13 |
from flask import Flask, request, jsonify, send_from_directory
|
| 14 |
from flask_cors import CORS
|
|
|
|
|
|
|
| 15 |
|
| 16 |
from config import MODEL_DIR
|
| 17 |
from model import AIFinderNet
|
|
@@ -19,6 +26,9 @@ from features import FeaturePipeline
|
|
| 19 |
|
| 20 |
app = Flask(__name__, static_folder="static", static_url_path="")
|
| 21 |
CORS(app)
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
pipeline = None
|
| 24 |
provider_enc = None
|
|
@@ -96,6 +106,74 @@ def classify():
|
|
| 96 |
)
|
| 97 |
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
@app.route("/api/correct", methods=["POST"])
|
| 100 |
def correct():
|
| 101 |
"""Train on a corrected example."""
|
|
|
|
| 1 |
"""
|
| 2 |
AIFinder API Server
|
| 3 |
Serves classification and training endpoints for the frontend.
|
| 4 |
+
|
| 5 |
+
Public API:
|
| 6 |
+
POST /v1/classify — classify text, returns top-N provider predictions.
|
| 7 |
+
No API key required. Rate-limited to 60 requests/minute per IP.
|
| 8 |
"""
|
| 9 |
|
| 10 |
import os
|
| 11 |
+
import re
|
| 12 |
import sys
|
| 13 |
import json
|
| 14 |
import joblib
|
|
|
|
| 17 |
import torch.nn as nn
|
| 18 |
from flask import Flask, request, jsonify, send_from_directory
|
| 19 |
from flask_cors import CORS
|
| 20 |
+
from flask_limiter import Limiter
|
| 21 |
+
from flask_limiter.util import get_remote_address
|
| 22 |
|
| 23 |
from config import MODEL_DIR
|
| 24 |
from model import AIFinderNet
|
|
|
|
| 26 |
|
| 27 |
app = Flask(__name__, static_folder="static", static_url_path="")
|
| 28 |
CORS(app)
|
| 29 |
+
limiter = Limiter(get_remote_address, app=app, default_limits=[])
|
| 30 |
+
|
| 31 |
+
DEFAULT_TOP_N = 5
|
| 32 |
|
| 33 |
pipeline = None
|
| 34 |
provider_enc = None
|
|
|
|
| 106 |
)
|
| 107 |
|
| 108 |
|
| 109 |
+
def _strip_think_tags(text):
    """Remove <think>…</think> (and <thinking>…</thinking>) blocks from input."""
    # Non-greedy match across newlines so each tag pair is removed independently.
    pattern = r"<think(?:ing)?>.*?</think(?:ing)?>"
    cleaned = re.sub(pattern, "", text, flags=re.DOTALL)
    return cleaned.strip()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@app.route("/v1/classify", methods=["POST"])
@limiter.limit("60/minute")
def v1_classify():
    """Public API — classify text and return top-N provider predictions.

    Request JSON:
        text (str): The text to classify. Any <think>/<thinking> tags will be
            stripped automatically before classification.
        top_n (int): Number of results to return (default: 5).

    Response JSON:
        provider (str): Best-matching provider name.
        confidence (float): Confidence % for the top provider.
        top_providers (list): List of {name, confidence} dicts.

    Rate limit: 60 requests per minute per IP. No API key required.

    NOTE: If the text you are classifying was produced by a model that emits
    <think> or <thinking> blocks, you should strip those tags BEFORE
    sending the text. This endpoint does it for you automatically, but
    doing it on your side avoids wasting bytes on the wire.
    """
    # The model globals start out as None and are populated at startup
    # elsewhere in this module; answer 503 instead of crashing with a 500
    # if a request races the load.
    if pipeline is None or provider_enc is None:
        return jsonify({"error": "Model not loaded yet; try again shortly."}), 503

    data = request.get_json(silent=True)
    if not data or "text" not in data:
        return jsonify({"error": "Request body must be JSON with a 'text' field."}), 400

    raw_text = data["text"]
    # Reject non-string payloads explicitly; otherwise re.sub raises
    # TypeError and the client sees an opaque 500.
    if not isinstance(raw_text, str):
        return jsonify({"error": "'text' must be a string."}), 400

    text = _strip_think_tags(raw_text)

    top_n = data.get("top_n", DEFAULT_TOP_N)
    # bool is a subclass of int — exclude it so {"top_n": true} falls back
    # to the default instead of silently meaning top_n=1.
    if isinstance(top_n, bool) or not isinstance(top_n, int) or top_n < 1:
        top_n = DEFAULT_TOP_N

    if len(text) < 20:
        return jsonify({"error": "Text too short (minimum 20 characters after stripping think tags)."}), 400

    X = pipeline.transform([text])
    X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

    with torch.no_grad():
        prov_logits = net(X_t)

    prov_proba = torch.softmax(prov_logits.float(), dim=1)[0].cpu().numpy()

    # Slicing with top_n larger than the class count is safe: argsort just
    # returns every class.
    top_idxs = np.argsort(prov_proba)[::-1][:top_n]
    # Decode all labels in one inverse_transform call rather than one call
    # per index inside the loop.
    names = provider_enc.inverse_transform(top_idxs)
    top_providers = [
        {
            "name": str(name),
            "confidence": round(float(prov_proba[i] * 100), 2),
        }
        for i, name in zip(top_idxs, names)
    ]

    return jsonify(
        {
            "provider": top_providers[0]["name"],
            "confidence": top_providers[0]["confidence"],
            "top_providers": top_providers,
        }
    )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
@app.route("/api/correct", methods=["POST"])
|
| 178 |
def correct():
|
| 179 |
"""Train on a corrected example."""
|
example_api.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Example: Call the AIFinder public API to classify text.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python example_api.py
|
| 7 |
+
|
| 8 |
+
Requirements:
|
| 9 |
+
pip install requests
|
| 10 |
+
|
| 11 |
+
IMPORTANT — Strip <think>/<thinking> tags!
|
| 12 |
+
Many reasoning models wrap chain-of-thought in <think>…</think> or
|
| 13 |
+
<thinking>…</thinking> blocks. These tags confuse the classifier because
|
| 14 |
+
they are NOT part of the model's actual output style. The API strips them
|
| 15 |
+
automatically, but you should also strip them on your side to avoid sending
|
| 16 |
+
unnecessary data.
|
| 17 |
+
|
| 18 |
+
API details:
|
| 19 |
+
POST /v1/classify
|
| 20 |
+
No API key required. Rate limit: 60 requests/minute per IP.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import re
|
| 24 |
+
import json
|
| 25 |
+
import requests
|
| 26 |
+
|
| 27 |
+
# Change this to your server URL (local or HuggingFace Space).
|
| 28 |
+
API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify"
|
| 29 |
+
|
| 30 |
+
# Number of top results to return (default on server is 5).
|
| 31 |
+
TOP_N = 5
|
| 32 |
+
|
| 33 |
+
EXAMPLE_TEXT = """\
|
| 34 |
+
I'd be happy to help you understand how neural networks work!
|
| 35 |
+
|
| 36 |
+
Neural networks are computational models inspired by the human brain. They consist of layers of interconnected nodes (neurons) that process information. Here's a breakdown:
|
| 37 |
+
|
| 38 |
+
1. **Input Layer**: Receives the raw data
|
| 39 |
+
2. **Hidden Layers**: Process and transform the data through weighted connections
|
| 40 |
+
3. **Output Layer**: Produces the final prediction
|
| 41 |
+
|
| 42 |
+
Each connection has a weight, and each neuron has a bias. During training, the network adjusts these weights using backpropagation to minimize the difference between predicted and actual outputs.
|
| 43 |
+
|
| 44 |
+
The key insight is that by stacking multiple layers, neural networks can learn increasingly abstract representations of data, enabling them to solve complex tasks like image recognition, natural language processing, and more.
|
| 45 |
+
|
| 46 |
+
Would you like me to dive deeper into any specific aspect?
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def strip_think_tags(text: str) -> str:
    """Remove <think>…</think> and <thinking>…</thinking> blocks."""
    # DOTALL lets the non-greedy body span newlines; trailing strip drops
    # whitespace left behind where a block was removed at either end.
    pattern = re.compile(r"<think(?:ing)?>.*?</think(?:ing)?>", re.DOTALL)
    return pattern.sub("", text).strip()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def classify(text: str, top_n: int = TOP_N) -> dict:
    """Send text to the AIFinder API and return the JSON response."""
    # Strip <think>/<thinking> blocks client-side; the server does this too,
    # but it avoids sending bytes that will be discarded anyway.
    payload = {"text": strip_think_tags(text), "top_n": top_n}
    response = requests.post(API_URL, json=payload, timeout=30)
    response.raise_for_status()
    return response.json()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def main():
    """Classify the bundled example text and pretty-print the result."""
    print("AIFinder API Example")
    print("=" * 50)
    print(f"Endpoint : {API_URL}")
    print(f"Top-N : {TOP_N}")
    print()

    result = classify(EXAMPLE_TEXT)

    print(f"Best match : {result['provider']} ({result['confidence']:.1f}%)")
    print()
    print("Top providers:")
    for entry in result["top_providers"]:
        # 20-char bar, one full cell per 5 confidence points.
        filled = int(entry["confidence"] / 5)
        bar = "█" * filled + "░" * (20 - filled)
        print(f"  {entry['name']:.<25s} {entry['confidence']:5.1f}% {bar}")


if __name__ == "__main__":
    main()
|
models/aifinder_trained.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:daeb1c62c7ca34fd45b697fb0240608555aa1617e45baa8d81620a69272bbaaf
|
| 3 |
+
size 165033081
|
requirements.txt
CHANGED
|
@@ -10,3 +10,4 @@ pandas>=2.0
|
|
| 10 |
huggingface_hub>=0.23.0
|
| 11 |
flask>=3.0
|
| 12 |
flask-cors>=4.0
|
|
|
|
|
|
| 10 |
huggingface_hub>=0.23.0
|
| 11 |
flask>=3.0
|
| 12 |
flask-cors>=4.0
|
| 13 |
+
flask-limiter>=3.0
|
static/index.html
CHANGED
|
@@ -399,7 +399,260 @@
|
|
| 399 |
margin-bottom: 1rem;
|
| 400 |
opacity: 0.5;
|
| 401 |
}
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
@media (max-width: 600px) {
|
| 404 |
.container {
|
| 405 |
padding: 1rem;
|
|
@@ -432,59 +685,322 @@
|
|
| 432 |
<p class="tagline">Identify which AI provider generated a response</p>
|
| 433 |
</header>
|
| 434 |
|
| 435 |
-
<div class="
|
| 436 |
-
<
|
| 437 |
-
<
|
| 438 |
</div>
|
| 439 |
-
|
| 440 |
-
<
|
| 441 |
-
|
| 442 |
-
<
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
<span id="classifyBtnText">Classify</span>
|
| 448 |
-
</button>
|
| 449 |
-
<button class="btn btn-secondary" id="clearBtn">Clear</button>
|
| 450 |
-
</div>
|
| 451 |
-
|
| 452 |
-
<div class="results" id="results">
|
| 453 |
<div class="card">
|
| 454 |
-
<div class="card-label">
|
| 455 |
-
<
|
| 456 |
-
|
| 457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
</div>
|
| 459 |
-
<div class="
|
| 460 |
-
<div class="
|
|
|
|
| 461 |
</div>
|
| 462 |
-
<ul class="result-list" id="resultList"></ul>
|
| 463 |
</div>
|
| 464 |
|
| 465 |
-
<div class="
|
| 466 |
-
<
|
| 467 |
-
<
|
| 468 |
-
<button class="btn btn-primary" id="trainBtn">Train & Save</button>
|
| 469 |
</div>
|
| 470 |
</div>
|
| 471 |
-
|
| 472 |
-
<
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
</div>
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
</div>
|
| 481 |
</div>
|
| 482 |
-
|
| 483 |
-
<div class="actions" id="actions" style="display: none;">
|
| 484 |
-
<button class="btn btn-secondary" id="exportBtn">Export Trained Model</button>
|
| 485 |
-
<button class="btn btn-secondary" id="resetBtn">Reset Training</button>
|
| 486 |
-
</div>
|
| 487 |
-
|
| 488 |
<div class="footer">
|
| 489 |
<p>AIFinder — Train on corrections to improve accuracy</p>
|
| 490 |
<p style="margin-top: 0.5rem;">
|
|
@@ -736,6 +1252,73 @@
|
|
| 736 |
}
|
| 737 |
});
|
| 738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 739 |
checkStatus();
|
| 740 |
</script>
|
| 741 |
</body>
|
|
|
|
| 399 |
margin-bottom: 1rem;
|
| 400 |
opacity: 0.5;
|
| 401 |
}
|
| 402 |
+
|
| 403 |
+
/* ── Tabs ── */
|
| 404 |
+
.tabs {
|
| 405 |
+
display: flex;
|
| 406 |
+
gap: 0;
|
| 407 |
+
margin-bottom: 2rem;
|
| 408 |
+
border-bottom: 1px solid var(--border);
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
.tab {
|
| 412 |
+
padding: 0.75rem 1.5rem;
|
| 413 |
+
font-family: 'Outfit', sans-serif;
|
| 414 |
+
font-size: 0.9rem;
|
| 415 |
+
font-weight: 500;
|
| 416 |
+
color: var(--text-muted);
|
| 417 |
+
background: none;
|
| 418 |
+
border: none;
|
| 419 |
+
border-bottom: 2px solid transparent;
|
| 420 |
+
cursor: pointer;
|
| 421 |
+
transition: all 0.2s ease;
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
.tab:hover {
|
| 425 |
+
color: var(--text-secondary);
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
.tab.active {
|
| 429 |
+
color: var(--accent);
|
| 430 |
+
border-bottom-color: var(--accent);
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
.tab-content {
|
| 434 |
+
display: none;
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
.tab-content.active {
|
| 438 |
+
display: block;
|
| 439 |
+
animation: fadeIn 0.3s ease;
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
/* ── API Docs ── */
|
| 443 |
+
.docs-section {
|
| 444 |
+
margin-bottom: 2rem;
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
.docs-section h2 {
|
| 448 |
+
font-size: 1.25rem;
|
| 449 |
+
font-weight: 600;
|
| 450 |
+
margin-bottom: 0.75rem;
|
| 451 |
+
color: var(--text-primary);
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
.docs-section h3 {
|
| 455 |
+
font-size: 1rem;
|
| 456 |
+
font-weight: 500;
|
| 457 |
+
margin-top: 1.25rem;
|
| 458 |
+
margin-bottom: 0.5rem;
|
| 459 |
+
color: var(--text-secondary);
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
.docs-section p {
|
| 463 |
+
color: var(--text-secondary);
|
| 464 |
+
font-size: 0.9rem;
|
| 465 |
+
margin-bottom: 0.75rem;
|
| 466 |
+
line-height: 1.7;
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
.docs-endpoint {
|
| 470 |
+
display: inline-flex;
|
| 471 |
+
align-items: center;
|
| 472 |
+
gap: 0.5rem;
|
| 473 |
+
background: var(--bg-tertiary);
|
| 474 |
+
border: 1px solid var(--border);
|
| 475 |
+
border-radius: 6px;
|
| 476 |
+
padding: 0.5rem 1rem;
|
| 477 |
+
margin-bottom: 1rem;
|
| 478 |
+
font-family: 'JetBrains Mono', monospace;
|
| 479 |
+
font-size: 0.85rem;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
.docs-method {
|
| 483 |
+
color: var(--success);
|
| 484 |
+
font-weight: 600;
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
.docs-path {
|
| 488 |
+
color: var(--text-primary);
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
.docs-badge {
|
| 492 |
+
display: inline-block;
|
| 493 |
+
font-size: 0.7rem;
|
| 494 |
+
font-weight: 600;
|
| 495 |
+
text-transform: uppercase;
|
| 496 |
+
letter-spacing: 0.05em;
|
| 497 |
+
padding: 0.2rem 0.6rem;
|
| 498 |
+
border-radius: 4px;
|
| 499 |
+
margin-left: 0.5rem;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
.docs-badge.free {
|
| 503 |
+
background: var(--success-muted);
|
| 504 |
+
color: var(--success);
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
.docs-badge.limit {
|
| 508 |
+
background: var(--accent-muted);
|
| 509 |
+
color: var(--accent-hover);
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
.docs-code-block {
|
| 513 |
+
position: relative;
|
| 514 |
+
background: var(--bg-tertiary);
|
| 515 |
+
border: 1px solid var(--border);
|
| 516 |
+
border-radius: 8px;
|
| 517 |
+
margin-bottom: 1rem;
|
| 518 |
+
overflow: hidden;
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
.docs-code-header {
|
| 522 |
+
display: flex;
|
| 523 |
+
align-items: center;
|
| 524 |
+
justify-content: space-between;
|
| 525 |
+
padding: 0.5rem 1rem;
|
| 526 |
+
background: var(--bg-elevated);
|
| 527 |
+
border-bottom: 1px solid var(--border);
|
| 528 |
+
font-size: 0.75rem;
|
| 529 |
+
color: var(--text-muted);
|
| 530 |
+
text-transform: uppercase;
|
| 531 |
+
letter-spacing: 0.05em;
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
.docs-copy-btn {
|
| 535 |
+
background: none;
|
| 536 |
+
border: 1px solid var(--border);
|
| 537 |
+
border-radius: 4px;
|
| 538 |
+
color: var(--text-muted);
|
| 539 |
+
font-size: 0.7rem;
|
| 540 |
+
padding: 0.2rem 0.5rem;
|
| 541 |
+
cursor: pointer;
|
| 542 |
+
font-family: 'Outfit', sans-serif;
|
| 543 |
+
transition: all 0.2s ease;
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
.docs-copy-btn:hover {
|
| 547 |
+
color: var(--text-primary);
|
| 548 |
+
border-color: var(--border-light);
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
.docs-code-block pre {
|
| 552 |
+
padding: 1rem;
|
| 553 |
+
overflow-x: auto;
|
| 554 |
+
font-family: 'JetBrains Mono', monospace;
|
| 555 |
+
font-size: 0.8rem;
|
| 556 |
+
line-height: 1.6;
|
| 557 |
+
color: var(--text-primary);
|
| 558 |
+
margin: 0;
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
.docs-table {
|
| 562 |
+
width: 100%;
|
| 563 |
+
border-collapse: collapse;
|
| 564 |
+
font-size: 0.85rem;
|
| 565 |
+
margin-bottom: 1rem;
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
.docs-table th {
|
| 569 |
+
text-align: left;
|
| 570 |
+
padding: 0.6rem 0.75rem;
|
| 571 |
+
background: var(--bg-elevated);
|
| 572 |
+
color: var(--text-secondary);
|
| 573 |
+
font-weight: 500;
|
| 574 |
+
border-bottom: 1px solid var(--border);
|
| 575 |
+
font-size: 0.75rem;
|
| 576 |
+
text-transform: uppercase;
|
| 577 |
+
letter-spacing: 0.05em;
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
.docs-table td {
|
| 581 |
+
padding: 0.6rem 0.75rem;
|
| 582 |
+
border-bottom: 1px solid var(--border);
|
| 583 |
+
color: var(--text-secondary);
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
.docs-table tr:last-child td {
|
| 587 |
+
border-bottom: none;
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
.docs-table code {
|
| 591 |
+
font-family: 'JetBrains Mono', monospace;
|
| 592 |
+
font-size: 0.8rem;
|
| 593 |
+
background: var(--bg-tertiary);
|
| 594 |
+
padding: 0.15rem 0.4rem;
|
| 595 |
+
border-radius: 3px;
|
| 596 |
+
color: var(--accent-hover);
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
.docs-warning {
|
| 600 |
+
background: rgba(232, 93, 4, 0.08);
|
| 601 |
+
border: 1px solid var(--accent-muted);
|
| 602 |
+
border-radius: 8px;
|
| 603 |
+
padding: 1rem 1.25rem;
|
| 604 |
+
margin-bottom: 1rem;
|
| 605 |
+
font-size: 0.85rem;
|
| 606 |
+
color: var(--text-secondary);
|
| 607 |
+
line-height: 1.7;
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
.docs-warning strong {
|
| 611 |
+
color: var(--accent-hover);
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
.docs-inline-code {
|
| 615 |
+
font-family: 'JetBrains Mono', monospace;
|
| 616 |
+
font-size: 0.8rem;
|
| 617 |
+
background: var(--bg-tertiary);
|
| 618 |
+
padding: 0.15rem 0.4rem;
|
| 619 |
+
border-radius: 3px;
|
| 620 |
+
color: var(--accent-hover);
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
.docs-try-it {
|
| 624 |
+
background: var(--bg-tertiary);
|
| 625 |
+
border: 1px solid var(--border);
|
| 626 |
+
border-radius: 8px;
|
| 627 |
+
padding: 1.25rem;
|
| 628 |
+
margin-top: 1rem;
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
.docs-try-it textarea {
|
| 632 |
+
min-height: 100px;
|
| 633 |
+
margin-bottom: 0.75rem;
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
.docs-try-output {
|
| 637 |
+
background: var(--bg-primary);
|
| 638 |
+
border: 1px solid var(--border);
|
| 639 |
+
border-radius: 6px;
|
| 640 |
+
padding: 1rem;
|
| 641 |
+
font-family: 'JetBrains Mono', monospace;
|
| 642 |
+
font-size: 0.8rem;
|
| 643 |
+
color: var(--text-secondary);
|
| 644 |
+
white-space: pre-wrap;
|
| 645 |
+
word-break: break-word;
|
| 646 |
+
max-height: 300px;
|
| 647 |
+
overflow-y: auto;
|
| 648 |
+
display: none;
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
.docs-try-output.visible {
|
| 652 |
+
display: block;
|
| 653 |
+
animation: fadeIn 0.3s ease;
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
@media (max-width: 600px) {
|
| 657 |
.container {
|
| 658 |
padding: 1rem;
|
|
|
|
| 685 |
<p class="tagline">Identify which AI provider generated a response</p>
|
| 686 |
</header>
|
| 687 |
|
| 688 |
+
<div class="tabs">
|
| 689 |
+
<button class="tab active" data-tab="classify">Classify</button>
|
| 690 |
+
<button class="tab" data-tab="docs">API Docs</button>
|
| 691 |
</div>
|
| 692 |
+
|
| 693 |
+
<!-- ═══ Classify Tab ═══ -->
|
| 694 |
+
<div class="tab-content active" id="tab-classify">
|
| 695 |
+
<div class="status-indicator">
|
| 696 |
+
<span class="status-dot" id="statusDot"></span>
|
| 697 |
+
<span id="statusText">Connecting to API...</span>
|
| 698 |
+
</div>
|
| 699 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
<div class="card">
|
| 701 |
+
<div class="card-label">Paste AI Response</div>
|
| 702 |
+
<textarea id="inputText" placeholder="Paste an AI response here to identify which provider generated it..."></textarea>
|
| 703 |
+
</div>
|
| 704 |
+
|
| 705 |
+
<div class="btn-group">
|
| 706 |
+
<button class="btn btn-primary" id="classifyBtn" disabled>
|
| 707 |
+
<span id="classifyBtnText">Classify</span>
|
| 708 |
+
</button>
|
| 709 |
+
<button class="btn btn-secondary" id="clearBtn">Clear</button>
|
| 710 |
+
</div>
|
| 711 |
+
|
| 712 |
+
<div class="results" id="results">
|
| 713 |
+
<div class="card">
|
| 714 |
+
<div class="card-label">Result</div>
|
| 715 |
+
<div class="result-main">
|
| 716 |
+
<span class="result-provider" id="resultProvider">-</span>
|
| 717 |
+
<span class="result-confidence" id="resultConfidence">-</span>
|
| 718 |
+
</div>
|
| 719 |
+
<div class="result-bar">
|
| 720 |
+
<div class="result-bar-fill" id="resultBar" style="width: 0%"></div>
|
| 721 |
+
</div>
|
| 722 |
+
<ul class="result-list" id="resultList"></ul>
|
| 723 |
+
</div>
|
| 724 |
+
|
| 725 |
+
<div class="correction" id="correction">
|
| 726 |
+
<div class="correction-title">Wrong? Correct the provider to train the model:</div>
|
| 727 |
+
<select id="providerSelect"></select>
|
| 728 |
+
<button class="btn btn-primary" id="trainBtn">Train & Save</button>
|
| 729 |
+
</div>
|
| 730 |
+
</div>
|
| 731 |
+
|
| 732 |
+
<div class="stats" id="stats" style="display: none;">
|
| 733 |
+
<div class="stat">
|
| 734 |
+
<div class="stat-value" id="correctionsCount">0</div>
|
| 735 |
+
<div class="stat-label">Corrections</div>
|
| 736 |
</div>
|
| 737 |
+
<div class="stat">
|
| 738 |
+
<div class="stat-value" id="sessionCount">0</div>
|
| 739 |
+
<div class="stat-label">Session</div>
|
| 740 |
</div>
|
|
|
|
| 741 |
</div>
|
| 742 |
|
| 743 |
+
<div class="actions" id="actions" style="display: none;">
|
| 744 |
+
<button class="btn btn-secondary" id="exportBtn">Export Trained Model</button>
|
| 745 |
+
<button class="btn btn-secondary" id="resetBtn">Reset Training</button>
|
|
|
|
| 746 |
</div>
|
| 747 |
</div>
|
| 748 |
+
|
| 749 |
+
<!-- ═══ API Docs Tab ═══ -->
|
| 750 |
+
<div class="tab-content" id="tab-docs">
|
| 751 |
+
|
| 752 |
+
<div class="docs-section">
|
| 753 |
+
<h2>Public Classification API</h2>
|
| 754 |
+
<p>
|
| 755 |
+
AIFinder exposes a free, public endpoint for programmatic classification.
|
| 756 |
+
No API key required.
|
| 757 |
+
</p>
|
| 758 |
+
<div>
|
| 759 |
+
<div class="docs-endpoint">
|
| 760 |
+
<span class="docs-method">POST</span>
|
| 761 |
+
<span class="docs-path">/v1/classify</span>
|
| 762 |
+
</div>
|
| 763 |
+
<span class="docs-badge free">No API Key</span>
|
| 764 |
+
<span class="docs-badge limit">60 req/min</span>
|
| 765 |
+
</div>
|
| 766 |
</div>
|
| 767 |
+
|
| 768 |
+
<!-- ── Request ── -->
|
| 769 |
+
<div class="docs-section">
|
| 770 |
+
<h2>Request</h2>
|
| 771 |
+
<p>Send a JSON body with <span class="docs-inline-code">Content-Type: application/json</span>.</p>
|
| 772 |
+
|
| 773 |
+
<table class="docs-table">
|
| 774 |
+
<thead>
|
| 775 |
+
<tr><th>Field</th><th>Type</th><th>Required</th><th>Description</th></tr>
|
| 776 |
+
</thead>
|
| 777 |
+
<tbody>
|
| 778 |
+
<tr>
|
| 779 |
+
<td><code>text</code></td>
|
| 780 |
+
<td>string</td>
|
| 781 |
+
<td>Yes</td>
|
| 782 |
+
<td>The AI-generated text to classify (min 20 chars)</td>
|
| 783 |
+
</tr>
|
| 784 |
+
<tr>
|
| 785 |
+
<td><code>top_n</code></td>
|
| 786 |
+
<td>integer</td>
|
| 787 |
+
<td>No</td>
|
| 788 |
+
<td>Number of results to return (default: <strong>5</strong>)</td>
|
| 789 |
+
</tr>
|
| 790 |
+
</tbody>
|
| 791 |
+
</table>
|
| 792 |
+
|
| 793 |
+
<div class="docs-warning">
|
| 794 |
+
<strong>⚠️ Strip thought tags!</strong><br>
|
| 795 |
+
Many reasoning models wrap chain-of-thought in
|
| 796 |
+
<span class="docs-inline-code"><think>…</think></span> or
|
| 797 |
+
<span class="docs-inline-code"><thinking>…</thinking></span> blocks.
|
| 798 |
+
These confuse the classifier. The API strips them automatically, but you should
|
| 799 |
+
remove them on your side too to save bandwidth.
|
| 800 |
+
</div>
|
| 801 |
+
</div>
|
| 802 |
+
|
| 803 |
+
<!-- ── Response ── -->
|
| 804 |
+
<div class="docs-section">
|
| 805 |
+
<h2>Response</h2>
|
| 806 |
+
<div class="docs-code-block">
|
| 807 |
+
<div class="docs-code-header">
|
| 808 |
+
<span>JSON</span>
|
| 809 |
+
<button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
|
| 810 |
+
</div>
|
| 811 |
+
<pre>{
|
| 812 |
+
"provider": "Anthropic",
|
| 813 |
+
"confidence": 87.42,
|
| 814 |
+
"top_providers": [
|
| 815 |
+
{ "name": "Anthropic", "confidence": 87.42 },
|
| 816 |
+
{ "name": "OpenAI", "confidence": 6.15 },
|
| 817 |
+
{ "name": "Google", "confidence": 3.28 },
|
| 818 |
+
{ "name": "xAI", "confidence": 1.74 },
|
| 819 |
+
{ "name": "DeepSeek", "confidence": 0.89 }
|
| 820 |
+
]
|
| 821 |
+
}</pre>
|
| 822 |
+
</div>
|
| 823 |
+
|
| 824 |
+
<table class="docs-table">
|
| 825 |
+
<thead>
|
| 826 |
+
<tr><th>Field</th><th>Type</th><th>Description</th></tr>
|
| 827 |
+
</thead>
|
| 828 |
+
<tbody>
|
| 829 |
+
<tr>
|
| 830 |
+
<td><code>provider</code></td>
|
| 831 |
+
<td>string</td>
|
| 832 |
+
<td>Best-matching provider name</td>
|
| 833 |
+
</tr>
|
| 834 |
+
<tr>
|
| 835 |
+
<td><code>confidence</code></td>
|
| 836 |
+
<td>float</td>
|
| 837 |
+
<td>Confidence % for the top provider</td>
|
| 838 |
+
</tr>
|
| 839 |
+
<tr>
|
| 840 |
+
<td><code>top_providers</code></td>
|
| 841 |
+
<td>array</td>
|
| 842 |
+
<td>Ranked list of <code>{ name, confidence }</code> objects</td>
|
| 843 |
+
</tr>
|
| 844 |
+
</tbody>
|
| 845 |
+
</table>
|
| 846 |
+
</div>
|
| 847 |
+
|
| 848 |
+
<!-- ── Errors ── -->
|
| 849 |
+
<div class="docs-section">
|
| 850 |
+
<h2>Errors</h2>
|
| 851 |
+
<table class="docs-table">
|
| 852 |
+
<thead>
|
| 853 |
+
<tr><th>Status</th><th>Meaning</th></tr>
|
| 854 |
+
</thead>
|
| 855 |
+
<tbody>
|
| 856 |
+
<tr><td><code>400</code></td><td>Missing <code>text</code> field or text shorter than 20 characters</td></tr>
|
| 857 |
+
<tr><td><code>429</code></td><td>Rate limit exceeded (60 requests/minute per IP)</td></tr>
|
| 858 |
+
</tbody>
|
| 859 |
+
</table>
|
| 860 |
+
</div>
|
| 861 |
+
|
| 862 |
+
<!-- ── Code Examples ── -->
|
| 863 |
+
<div class="docs-section">
|
| 864 |
+
<h2>Code Examples</h2>
|
| 865 |
+
|
| 866 |
+
<h3>cURL</h3>
|
| 867 |
+
<div class="docs-code-block">
|
| 868 |
+
<div class="docs-code-header">
|
| 869 |
+
<span>Bash</span>
|
| 870 |
+
<button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
|
| 871 |
+
</div>
|
| 872 |
+
<pre>curl -X POST https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify \
|
| 873 |
+
-H "Content-Type: application/json" \
|
| 874 |
+
-d '{
|
| 875 |
+
"text": "I would be happy to help you with that! Here is a detailed explanation of how neural networks work...",
|
| 876 |
+
"top_n": 5
|
| 877 |
+
}'</pre>
|
| 878 |
+
</div>
|
| 879 |
+
|
| 880 |
+
<h3>Python</h3>
|
| 881 |
+
<div class="docs-code-block">
|
| 882 |
+
<div class="docs-code-header">
|
| 883 |
+
<span>Python</span>
|
| 884 |
+
<button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
|
| 885 |
+
</div>
|
| 886 |
+
<pre>import re
|
| 887 |
+
import requests
|
| 888 |
+
|
| 889 |
+
API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify"
|
| 890 |
+
|
| 891 |
+
def strip_think_tags(text):
|
| 892 |
+
"""Remove <think>/<thinking> blocks before classifying."""
|
| 893 |
+
return re.sub(r"<think(?:ing)?>.*?</think(?:ing)?>",
|
| 894 |
+
"", text, flags=re.DOTALL).strip()
|
| 895 |
+
|
| 896 |
+
text = """I'd be happy to help! Neural networks are
|
| 897 |
+
computational models inspired by the human brain..."""
|
| 898 |
+
|
| 899 |
+
# Strip thought tags first (the API does this too,
|
| 900 |
+
# but saves bandwidth to do it client-side)
|
| 901 |
+
cleaned = strip_think_tags(text)
|
| 902 |
+
|
| 903 |
+
response = requests.post(API_URL, json={
|
| 904 |
+
"text": cleaned,
|
| 905 |
+
"top_n": 5
|
| 906 |
+
})
|
| 907 |
+
|
| 908 |
+
data = response.json()
|
| 909 |
+
print(f"Provider: {data['provider']} ({data['confidence']:.1f}%)")
|
| 910 |
+
for p in data["top_providers"]:
|
| 911 |
+
print(f" {p['name']:<20s} {p['confidence']:5.1f}%")</pre>
|
| 912 |
+
</div>
|
| 913 |
+
|
| 914 |
+
<h3>JavaScript (fetch)</h3>
|
| 915 |
+
<div class="docs-code-block">
|
| 916 |
+
<div class="docs-code-header">
|
| 917 |
+
<span>JavaScript</span>
|
| 918 |
+
<button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
|
| 919 |
+
</div>
|
| 920 |
+
<pre>const API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify";
|
| 921 |
+
|
| 922 |
+
function stripThinkTags(text) {
|
| 923 |
+
return text.replace(/<think(?:ing)?>[\s\S]*?<\/think(?:ing)?>/g, "").trim();
|
| 924 |
+
}
|
| 925 |
+
|
| 926 |
+
async function classify(text, topN = 5) {
|
| 927 |
+
const cleaned = stripThinkTags(text);
|
| 928 |
+
const res = await fetch(API_URL, {
|
| 929 |
+
method: "POST",
|
| 930 |
+
headers: { "Content-Type": "application/json" },
|
| 931 |
+
body: JSON.stringify({ text: cleaned, top_n: topN })
|
| 932 |
+
});
|
| 933 |
+
return res.json();
|
| 934 |
+
}
|
| 935 |
+
|
| 936 |
+
// Usage
|
| 937 |
+
classify("I'd be happy to help you understand...")
|
| 938 |
+
.then(data => {
|
| 939 |
+
console.log(`Provider: ${data.provider} (${data.confidence}%)`);
|
| 940 |
+
data.top_providers.forEach(p =>
|
| 941 |
+
console.log(` ${p.name}: ${p.confidence}%`)
|
| 942 |
+
);
|
| 943 |
+
});</pre>
|
| 944 |
+
</div>
|
| 945 |
+
|
| 946 |
+
<h3>Node.js</h3>
|
| 947 |
+
<div class="docs-code-block">
|
| 948 |
+
<div class="docs-code-header">
|
| 949 |
+
<span>JavaScript (Node)</span>
|
| 950 |
+
<button class="docs-copy-btn" onclick="copyCode(this)">Copy</button>
|
| 951 |
+
</div>
|
| 952 |
+
<pre>const API_URL = "https://huggingface.co/spaces/CompactAI/AIFinder/v1/classify";
|
| 953 |
+
|
| 954 |
+
async function classify(text, topN = 5) {
|
| 955 |
+
const cleaned = text
|
| 956 |
+
.replace(/<think(?:ing)?>[\s\S]*?<\/think(?:ing)?>/g, "")
|
| 957 |
+
.trim();
|
| 958 |
+
|
| 959 |
+
const res = await fetch(API_URL, {
|
| 960 |
+
method: "POST",
|
| 961 |
+
headers: { "Content-Type": "application/json" },
|
| 962 |
+
body: JSON.stringify({ text: cleaned, top_n: topN })
|
| 963 |
+
});
|
| 964 |
+
|
| 965 |
+
if (!res.ok) {
|
| 966 |
+
const err = await res.json();
|
| 967 |
+
throw new Error(err.error || `HTTP ${res.status}`);
|
| 968 |
+
}
|
| 969 |
+
return res.json();
|
| 970 |
+
}
|
| 971 |
+
|
| 972 |
+
// Example
|
| 973 |
+
(async () => {
|
| 974 |
+
const result = await classify(
|
| 975 |
+
"Let me think about this step by step...",
|
| 976 |
+
3
|
| 977 |
+
);
|
| 978 |
+
console.log(result);
|
| 979 |
+
})();</pre>
|
| 980 |
+
</div>
|
| 981 |
+
</div>
|
| 982 |
+
|
| 983 |
+
<!-- ── Try It ── -->
|
| 984 |
+
<div class="docs-section">
|
| 985 |
+
<h2>Try It</h2>
|
| 986 |
+
<p>Test the API right here — paste any AI-generated text and hit Send Request.</p>
|
| 987 |
+
<div class="docs-try-it">
|
| 988 |
+
<textarea id="docsTestInput" placeholder="Paste AI-generated text here..."></textarea>
|
| 989 |
+
<div class="btn-group">
|
| 990 |
+
<button class="btn btn-primary" id="docsTestBtn">Send Request</button>
|
| 991 |
+
</div>
|
| 992 |
+
<div class="docs-try-output" id="docsTestOutput"></div>
|
| 993 |
+
</div>
|
| 994 |
+
</div>
|
| 995 |
+
|
| 996 |
+
<!-- ── Providers ── -->
|
| 997 |
+
<div class="docs-section">
|
| 998 |
+
<h2>Supported Providers</h2>
|
| 999 |
+
<p>The classifier currently supports these providers:</p>
|
| 1000 |
+
<div id="docsProviderList" style="display: flex; flex-wrap: wrap; gap: 0.5rem; margin-top: 0.5rem;"></div>
|
| 1001 |
</div>
|
| 1002 |
</div>
|
| 1003 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
<div class="footer">
|
| 1005 |
<p>AIFinder — Train on corrections to improve accuracy</p>
|
| 1006 |
<p style="margin-top: 0.5rem;">
|
|
|
|
| 1252 |
}
|
| 1253 |
});
|
| 1254 |
|
| 1255 |
+
// ── Tab switching ──
// Clicking a tab deactivates every tab/panel, then activates the clicked
// tab and its matching `tab-<name>` content panel.
for (const tabEl of document.querySelectorAll('.tab')) {
  tabEl.addEventListener('click', () => {
    for (const t of document.querySelectorAll('.tab')) {
      t.classList.remove('active');
    }
    for (const c of document.querySelectorAll('.tab-content')) {
      c.classList.remove('active');
    }
    tabEl.classList.add('active');
    document.getElementById('tab-' + tabEl.dataset.tab).classList.add('active');
  });
}
|
| 1264 |
+
|
| 1265 |
+
// ── Copy button for code blocks ──
// Copies the nearest <pre> block's text to the clipboard and briefly
// flips the button label to confirm.
function copyCode(btn) {
  const codeText = btn.closest('.docs-code-block').querySelector('pre').textContent;
  navigator.clipboard.writeText(codeText).then(() => {
    btn.textContent = 'Copied!';
    setTimeout(() => { btn.textContent = 'Copy'; }, 1500);
  });
}
|
| 1273 |
+
|
| 1274 |
+
// ── Docs: populate provider badges ──
// Renders one inline-code badge per known provider into the docs section.
// No-op until both the container exists and the provider list has loaded.
function populateDocsProviders() {
  const list = document.getElementById('docsProviderList');
  if (!list || !providers.length) return;
  const badges = [];
  for (const p of providers) {
    badges.push(`<span class="docs-inline-code" style="padding:0.3rem 0.75rem;">${p}</span>`);
  }
  list.innerHTML = badges.join('');
}
|
| 1282 |
+
|
| 1283 |
+
// ── Docs: "Try It" live tester ──
// Sends the textarea contents to /v1/classify and pretty-prints the JSON
// response (or an error object) into the output panel.
const docsTestBtn = document.getElementById('docsTestBtn');
const docsTestInput = document.getElementById('docsTestInput');
const docsTestOutput = document.getElementById('docsTestOutput');

if (docsTestBtn) {
  docsTestBtn.addEventListener('click', async () => {
    const text = docsTestInput.value.trim();
    // Client-side guard mirroring the API's own minimum-length check.
    if (text.length < 20) {
      docsTestOutput.textContent = '{"error": "Text too short (minimum 20 characters)"}';
      docsTestOutput.classList.add('visible');
      return;
    }
    // Disable the button and show a spinner while the request is in flight.
    docsTestBtn.disabled = true;
    docsTestBtn.innerHTML = '<span class="loading"></span>';
    let output;
    try {
      const res = await fetch(`${API_BASE}/v1/classify`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ text, top_n: 5 })
      });
      const data = await res.json();
      output = JSON.stringify(data, null, 2);
    } catch (e) {
      output = `{"error": "${e.message}"}`;
    }
    docsTestOutput.textContent = output;
    docsTestOutput.classList.add('visible');
    docsTestBtn.disabled = false;
    docsTestBtn.textContent = 'Send Request';
  });
}
|
| 1314 |
+
|
| 1315 |
+
// Hook provider list population into the existing load flow.
// Wraps the original loadProviders so that the docs-page provider badges
// are refreshed every time the provider list is (re)fetched.
// NOTE(review): assumes `loadProviders` is a reassignable binding
// (function declaration / var / let) — reassignment throws if it were
// declared with `const`; confirm against the earlier definition.
const _origLoadProviders = loadProviders;
loadProviders = async function() {
  await _origLoadProviders();
  populateDocsProviders();
};
|
| 1321 |
+
|
| 1322 |
checkStatus();
|
| 1323 |
</script>
|
| 1324 |
</body>
|
train.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AIFinder Training Script
|
| 3 |
+
Loads data, trains a two-headed GPU classifier, reports metrics, and saves the model.
|
| 4 |
+
|
| 5 |
+
Usage: python3 train.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
import joblib
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 16 |
+
from sklearn.model_selection import train_test_split
|
| 17 |
+
from sklearn.metrics import classification_report
|
| 18 |
+
from sklearn.preprocessing import LabelEncoder
|
| 19 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 20 |
+
|
| 21 |
+
from config import (
|
| 22 |
+
MODEL_DIR,
|
| 23 |
+
TEST_SIZE,
|
| 24 |
+
RANDOM_STATE,
|
| 25 |
+
HIDDEN_DIM,
|
| 26 |
+
EMBED_DIM,
|
| 27 |
+
DROPOUT,
|
| 28 |
+
BATCH_SIZE,
|
| 29 |
+
EPOCHS,
|
| 30 |
+
LEARNING_RATE,
|
| 31 |
+
WEIGHT_DECAY,
|
| 32 |
+
EARLY_STOP_PATIENCE,
|
| 33 |
+
)
|
| 34 |
+
from data_loader import load_all_data
|
| 35 |
+
from features import FeaturePipeline
|
| 36 |
+
from model import AIFinderNet
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _log(msg, t0=None):
|
| 40 |
+
"""Print a timestamped log message, optionally with elapsed time."""
|
| 41 |
+
ts = time.strftime("%H:%M:%S")
|
| 42 |
+
if t0 is not None:
|
| 43 |
+
elapsed = time.time() - t0
|
| 44 |
+
print(f" [{ts}] {msg} ({elapsed:.1f}s)")
|
| 45 |
+
else:
|
| 46 |
+
print(f" [{ts}] {msg}")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def main():
    """Train the provider classifier end-to-end.

    Loads the dataset, fits the feature pipeline on the train split,
    trains AIFinderNet with class-weighted cross-entropy, AMP (on CUDA),
    OneCycleLR and early stopping, prints a classification report, and
    saves the model checkpoint plus the feature/label encoders to
    MODEL_DIR.
    """
    t_start = time.time()

    print("=" * 60)
    print("AIFinder Training - Provider Classification")
    print("=" * 60)

    # ── GPU check ──────────────────────────────────────────────
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        _log(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    else:
        device = torch.device("cpu")
        _log("No GPU available, using CPU")

    # ── Load data ──────────────────────────────────────────────
    _log("Starting data load...")
    t0 = time.time()
    # `models` and `_is_ai` are loaded but unused by this script.
    texts, providers, models, _is_ai = load_all_data()
    _log("Data load complete", t0)

    if len(texts) < 100:
        # Too few samples to train/stratify meaningfully; abort.
        print("ERROR: Not enough data loaded. Check dataset access.")
        sys.exit(1)

    # ── Encode labels ──────────────────────────────────────────
    _log("Encoding labels...")
    t0 = time.time()
    provider_enc = LabelEncoder()
    provider_labels = provider_enc.fit_transform(providers)
    num_providers = len(provider_enc.classes_)
    _log(f"Labels encoded — {num_providers} providers", t0)

    # ── Train/test split ───────────────────────────────────────
    # Split indices (not texts) so the same split can index both the
    # text list and the label array; stratify to keep class balance.
    _log("Splitting train/test...")
    t0 = time.time()
    indices = np.arange(len(texts))
    train_idx, test_idx = train_test_split(
        indices,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=provider_labels,
    )
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]
    _log(f"Split: {len(train_texts)} train / {len(test_texts)} test", t0)

    # ── Build features ─────────────────────────────────────────
    # Fit only on the train split to avoid test-set leakage.
    _log("Building feature pipeline (fit on train)...")
    t0 = time.time()
    pipeline = FeaturePipeline()
    X_train = pipeline.fit_transform(train_texts)
    _log(f"Train features: {X_train.shape}", t0)

    _log("Transforming test set...")
    t0 = time.time()
    X_test = pipeline.transform(test_texts)
    _log(f"Test features: {X_test.shape}", t0)

    input_dim = X_train.shape[1]

    # ── Move to device ─────────────────────────────────────────
    # NOTE(review): .toarray() densifies the sparse feature matrix —
    # memory cost is rows × input_dim × 4 bytes; confirm it fits.
    _log(f"Moving data to {device}...")
    t0 = time.time()
    X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32).to(device)
    X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32).to(device)
    y_prov_train = torch.tensor(provider_labels[train_idx], dtype=torch.long).to(device)
    y_prov_test = torch.tensor(provider_labels[test_idx], dtype=torch.long).to(device)
    if device.type == "cuda":
        mem_used = torch.cuda.memory_allocated() / 1024**3
        _log(f"GPU memory used: {mem_used:.2f} GB", t0)
    else:
        _log(f"Data on {device}", t0)

    # ── DataLoaders ────────────────────────────────────────────
    # Cap the batch size on CPU to keep per-step latency reasonable.
    batch_size = min(BATCH_SIZE, 512) if device.type == "cpu" else BATCH_SIZE
    train_ds = TensorDataset(X_train_t, y_prov_train)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_ds = TensorDataset(X_test_t, y_prov_test)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # ── Model ──────────────────────────────────────────────────
    _log("Building model...")
    net = AIFinderNet(
        input_dim=input_dim,
        num_providers=num_providers,
        hidden_dim=HIDDEN_DIM,
        embed_dim=EMBED_DIM,
        dropout=DROPOUT,
    ).to(device)
    n_params = sum(p.numel() for p in net.parameters())
    _log(f"Model: {n_params:,} parameters")

    # ── Class-weighted loss ────────────────────────────────────
    # Balance the loss by inverse class frequency of the TRAIN split.
    prov_weights = compute_class_weight(
        "balanced", classes=np.arange(num_providers), y=provider_labels[train_idx]
    )
    prov_criterion = nn.CrossEntropyLoss(
        weight=torch.tensor(prov_weights, dtype=torch.float32).to(device)
    )

    # ── Optimizer + scheduler ──────────────────────────────────
    optimizer = torch.optim.AdamW(
        net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    # OneCycleLR is stepped once per BATCH (see scheduler.step() below),
    # hence steps_per_epoch=len(train_loader).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=EPOCHS,
        steps_per_epoch=len(train_loader),
    )
    # Mixed precision only on CUDA; scaler guards against fp16 underflow.
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler() if use_amp else None

    # ── Training loop ──────────────────────────────────────────
    _log(
        f"Training for {EPOCHS} epochs, batch_size={batch_size}, "
        f"early_stop_patience={EARLY_STOP_PATIENCE}..."
    )
    t0 = time.time()

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    for epoch in range(EPOCHS):
        # ── Train phase ───────────────────────────────────────
        net.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, batch_prov in train_loader:
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast(device_type="cuda"):
                    prov_logits = net(batch_X)
                    loss = prov_criterion(prov_logits, batch_prov)
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured on
                # true (unscaled) gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            n_batches += 1

        avg_train_loss = epoch_loss / n_batches

        # ── Validation phase ──────────────────────────────────
        net.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_X, batch_prov in val_loader:
                prov_logits = net(batch_X)
                loss = prov_criterion(prov_logits, batch_prov)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = val_loss / val_batches

        # ── Early stopping check ──────────────────────────────
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Snapshot weights on CPU-agnostic clones so we can restore
            # the best epoch after training ends.
            best_state = {k: v.clone() for k, v in net.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        # ── Logging ───────────────────────────────────────────
        # Log the first epoch and then every 5th; '*' marks a new best.
        if (epoch + 1) % 5 == 0 or epoch == 0:
            lr = scheduler.get_last_lr()[0]
            marker = " *" if patience_counter == 0 else ""
            _log(
                f"Epoch {epoch + 1:>3d}/{EPOCHS} "
                f"train={avg_train_loss:.4f} "
                f"val={avg_val_loss:.4f} "
                f"lr={lr:.2e}{marker}"
            )

        if patience_counter >= EARLY_STOP_PATIENCE:
            _log(
                f"Early stopping at epoch {epoch + 1} "
                f"(best val_loss={best_val_loss:.4f})"
            )
            break

    # Restore best weights
    if best_state is not None:
        net.load_state_dict(best_state)
        _log(f"Restored best weights (val_loss={best_val_loss:.4f})")

    _log("Training complete", t0)

    # ── Evaluate ───────────────────────────────────────────────
    # Single full-batch forward pass over the test tensor.
    _log("Evaluating...")
    net.eval()
    with torch.no_grad():
        prov_logits = net(X_test_t)

    prov_preds = prov_logits.argmax(dim=1).cpu().numpy()
    prov_true = y_prov_test.cpu().numpy()

    print("\n === Provider Classification ===")
    print(
        classification_report(
            prov_true,
            prov_preds,
            target_names=provider_enc.classes_,
            zero_division=0,
        )
    )

    # ── Save ───────────────────────────────────────────────────
    _log(f"Saving to {MODEL_DIR}/ ...")
    t0 = time.time()
    os.makedirs(MODEL_DIR, exist_ok=True)

    # Checkpoint bundles the architecture hyperparameters so inference
    # code can rebuild the network before loading state_dict.
    checkpoint = {
        "input_dim": input_dim,
        "num_providers": num_providers,
        "hidden_dim": HIDDEN_DIM,
        "embed_dim": EMBED_DIM,
        "dropout": DROPOUT,
        "state_dict": net.state_dict(),
    }
    torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
    _log(" Saved classifier.pt")

    joblib.dump(pipeline, os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
    _log(" Saved feature_pipeline.joblib")
    joblib.dump(provider_enc, os.path.join(MODEL_DIR, "provider_enc.joblib"))
    _log(" Saved provider_enc.joblib")

    _log("All artifacts saved", t0)

    elapsed = time.time() - t_start
    if device.type == "cuda":
        mem_peak = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s (peak GPU mem: {mem_peak:.2f} GB)")
        print(f"{'=' * 60}")
    else:
        print(f"\n{'=' * 60}")
        print(f"Training complete in {elapsed:.1f}s")
        print(f"{'=' * 60}")


if __name__ == "__main__":
    main()
|