Spaces:
Sleeping
Sleeping
File size: 9,283 Bytes
9cf599c 30b81fd 9cf599c 30b81fd 01d0daa a023a85 9cf599c a023a85 9cf599c 30b81fd 01d0daa 9cf599c a023a85 90d4f4d a023a85 461c5a6 a023a85 461c5a6 a023a85 461c5a6 a023a85 9cf599c 30b81fd a023a85 9cf599c 30b81fd 01d0daa 9cf599c 30b81fd 461c5a6 a023a85 30b81fd 01d0daa 9cf599c 30b81fd 01d0daa 9cf599c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | """
Hugging Face Spaces API wrapper
Provides direct JSON responses without job queues
Integrated with feedback learning pipeline for continuous model improvement
"""
from flask import Flask, request, jsonify
from inference_core import run_pipeline_for_image, download_image_from_url, upload_to_cloudinary, model, device
from scripts.feedback_learning_pipeline import initialize_feedback_pipeline, run_feedback_training
from scripts.model_versioning import initialize_model_tracker
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime
import os
import atexit
# Flask application object; every route in this module is registered on it via @app.route.
app = Flask(__name__)
# Initialize feedback learning pipeline.
# Built once at import time from the `model` / `device` objects imported from
# inference_core (presumably the same model run_pipeline_for_image uses — confirm).
# Shared by automated_training_check and the /feedback/* routes.
feedback_pipeline = initialize_feedback_pipeline(model, device)
# Initialize model versioning tracker (also built at import time); backs the /model/* routes.
model_tracker = initialize_model_tracker()
# ===== Automated Training Scheduler =====
def automated_training_check():
    """
    Check for new feedback and run a training cycle if enough has accumulated.

    This is the job registered with the background scheduler below. It asks
    ``feedback_pipeline`` for its stats. If the stats report
    ``ready_for_retraining``, it calls ``run_feedback_training`` and prints
    the outcome.

    All errors are caught and printed (including the stack trace), so a
    failed cycle never propagates out of the job.

    Returns:
        None. Progress and results are reported via ``print``.
    """
    try:
        print(f"\n[Automated Training] Running scheduled check at {datetime.now()}")

        stats = feedback_pipeline.get_feedback_stats()
        # Feedback rows not yet consumed by a training cycle. Both branches
        # below report it, so compute it once.
        unprocessed = stats.get("total_feedback_in_db", 0) - stats.get("total_processed", 0)

        if not stats.get("ready_for_retraining", False):
            print(f"[Automated Training] Not enough feedback for training ({unprocessed} new samples)")
            return

        print(f"[Automated Training] Found {unprocessed} unprocessed feedback samples")
        print("[Automated Training] Starting training cycle...")
        results = run_feedback_training(feedback_pipeline)

        # Guard against a non-dict result so the .get() calls below cannot fail.
        if not isinstance(results, dict):
            print(f"[Automated Training] Error: Expected dict, got {type(results)}: {results}")
            return

        if results.get("status") == "success":
            print("[Automated Training] ✓ Training completed successfully")
            print(f"[Automated Training] Processed {results.get('corrections_processed')} corrections")
        else:
            print(f"[Automated Training] Training status: {results.get('status')}")
    except Exception as e:
        # Best-effort job: report and carry on, but keep the traceback so the
        # failure can be diagnosed from the logs (the message alone often can't).
        import traceback

        print(f"[Automated Training] Error during automated check: {e}")
        traceback.print_exc()
# Background scheduler that runs automated_training_check on a fixed interval.
# daemon=True runs the scheduler thread in daemonic mode.
scheduler = BackgroundScheduler(daemon=True)

# Interval job: once per day. Adjust `days=` (or use hours=/minutes=/seconds=)
# to change the cadence. With an interval trigger, the first run happens one
# interval after start(). replace_existing=True replaces any job already
# registered under the same id instead of raising.
scheduler.add_job(
    automated_training_check,
    "interval",
    days=1,
    id="automated_training",
    name="Automated Feedback Training Check",
    replace_existing=True,
)

scheduler.start()
print("[Automated Training] Scheduler started - checking for new feedback every 1 day")

# Stop the scheduler when the interpreter exits.
atexit.register(scheduler.shutdown)
@app.route("/", methods=["GET"])
def home():
    """
    Serve a self-describing JSON index of this API.

    The payload names the service and its version, and lists each route with
    its HTTP method. It also includes an example /infer request and short
    notes on feedback training and model versioning.
    """
    endpoints = {
        "/health": "GET - Health check",
        "/infer": "POST - Run inference on image URL",
        "/feedback/stats": "GET - Get feedback statistics and training status",
        "/feedback/train": "POST - Manually trigger feedback training cycle",
        "/model/current": "GET - Get current model version and parameters",
        "/model/versions": "GET - Get model version history",
        "/model/training-history": "GET - Get training cycle history",
        "/model/compare": "POST - Compare model versions",
    }

    example_request = {
        "method": "POST",
        "url": "/infer",
        "body": {"image_url": "https://example.com/image.jpg"},
    }

    feedback_info = {
        "description": "User corrections are automatically fetched from Supabase",
        "automated_training": "Checks for new feedback every 1 day and trains automatically",
        "training_threshold": "10+ new feedback samples triggers training",
        "manual_training": "POST /feedback/train to trigger immediately",
    }

    versioning_info = {
        "description": "Model versions and training history tracked automatically",
        "view_current": "GET /model/current to see active model parameters",
        "view_history": "GET /model/versions to see all versions",
    }

    index = {
        "service": "Anomaly Detection API with Feedback Learning",
        "version": "2.1",
        "endpoints": endpoints,
        "example_request": example_request,
        "feedback_info": feedback_info,
        "versioning_info": versioning_info,
    }
    return jsonify(index)
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers ``{"status": "healthy"}`` with HTTP 200."""
    payload = {"status": "healthy"}
    return jsonify(payload), 200
@app.route("/feedback/stats", methods=["GET"])
def feedback_stats():
    """
    Report the feedback pipeline's statistics and training status.

    Responds 200 with the JSON-serialized result of
    ``feedback_pipeline.get_feedback_stats()``. Responds 500 with
    ``{"error": <message>}`` if fetching or serializing the stats fails.
    """
    try:
        return jsonify(feedback_pipeline.get_feedback_stats()), 200
    except Exception as err:
        return jsonify({"error": str(err)}), 500
@app.route("/feedback/train", methods=["POST"])
def trigger_training():
    """
    Manually trigger a feedback training cycle.

    Runs ``run_feedback_training(feedback_pipeline)`` synchronously within
    the request. Per the API index, user corrections are fetched from
    Supabase.

    Returns:
        200 with the training results dict.
        500 with ``{"error": ...}`` if training raised, or if it returned
        something other than a dict. This mirrors the validation done in
        automated_training_check, which also treats a non-dict result as an
        error.
    """
    # NOTE(review): nothing prevents this from running concurrently with the
    # scheduled automated_training_check on the same feedback_pipeline.
    # Consider a shared lock.
    try:
        results = run_feedback_training(feedback_pipeline)
        if not isinstance(results, dict):
            # Previously a non-dict result (e.g. None) was reported as HTTP 200.
            return jsonify({
                "error": f"Training returned unexpected result type: {type(results).__name__}"
            }), 500
        return jsonify(results), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/model/current", methods=["GET"])
def get_current_model():
    """
    Report the active model version and its parameters.

    Responds 200 with the JSON-serialized result of
    ``model_tracker.get_current_model_state()``. Responds 500 with
    ``{"error": <message>}`` if fetching or serializing it fails.
    """
    try:
        state = model_tracker.get_current_model_state()
        response = jsonify(state)
    except Exception as err:
        return jsonify({"error": str(err)}), 500
    return response, 200
@app.route("/model/versions", methods=["GET"])
def get_model_versions():
    """
    Return the model version history.

    Query params:
        limit: maximum number of versions to request from the tracker
            (integer, default 20).

    Returns:
        200 with ``{"total": <number of versions returned>, "versions": [...]}``.
        Note that "total" counts the returned (limited) list, not all stored
        versions.
        400 if ``limit`` is not an integer.
        500 with ``{"error": ...}`` if the tracker fails.
    """
    raw_limit = request.args.get("limit", 20)
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        # A malformed query parameter is a client error, not a server fault
        # (it used to surface as a 500).
        return jsonify({"error": f"Invalid limit {raw_limit!r}: must be an integer"}), 400

    try:
        versions = model_tracker.get_version_history(limit=limit)
        return jsonify({
            "total": len(versions),
            "versions": versions,
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/model/training-history", methods=["GET"])
def get_training_history():
    """
    Return the training cycle history.

    Query params:
        limit: maximum number of cycles to request from the tracker
            (integer, default 20).

    Returns:
        200 with ``{"total": <number of cycles returned>, "training_cycles": [...]}``.
        Note that "total" counts the returned (limited) list.
        400 if ``limit`` is not an integer.
        500 with ``{"error": ...}`` if the tracker fails.
    """
    raw_limit = request.args.get("limit", 20)
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        # A malformed query parameter is a client error, not a server fault
        # (it used to surface as a 500).
        return jsonify({"error": f"Invalid limit {raw_limit!r}: must be an integer"}), 400

    try:
        history = model_tracker.get_training_history(limit=limit)
        return jsonify({
            "total": len(history),
            "training_cycles": history,
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/model/compare", methods=["POST"])
def compare_versions():
    """
    Compare several model versions side by side.

    Request JSON: {"version_ids": ["id1", "id2", ...]}

    Returns:
        200 with ``{"comparison": <tracker table>, "version_count": <number of ids>}``.
        400 if the body is not a JSON object containing ``version_ids``, or
        if ``version_ids`` is not a non-empty list.
        500 with ``{"error": ...}`` if the tracker fails.
    """
    # silent=True: a missing or malformed JSON body yields None (-> 400 below).
    # Otherwise it raises an HTTP error that the broad handler would have
    # turned into a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "version_ids" not in data:
        return jsonify({"error": "Missing version_ids"}), 400

    version_ids = data["version_ids"]
    # Without this check a bare string was accepted (len() counted its
    # characters), and a number crashed len() with a 500.
    if not isinstance(version_ids, list) or not version_ids:
        return jsonify({"error": "version_ids must be a non-empty list"}), 400

    try:
        comparison = model_tracker.generate_comparison_table(version_ids)
        return jsonify({
            "comparison": comparison,
            "version_count": len(version_ids),
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/infer", methods=["POST"])
def infer():
    """
    Run the inference pipeline on an image URL and return a direct JSON response.

    Request JSON: {"image_url": "https://..."}

    The image is downloaded to a local file and passed to
    ``run_pipeline_for_image``. Each produced output image (boxed / mask /
    filtered) is uploaded via ``upload_to_cloudinary``. The downloaded file
    is deleted on every path, success or failure.

    Returns:
        200 with ``{"label", "boxed_url", "mask_url", "filtered_url", "boxes"}``.
        An output URL is null when the pipeline produced no path for it.
        400 if the body is not a JSON object with a non-empty string ``image_url``.
        500 with ``{"error": ...}`` if download, inference, or upload fails.
    """
    # silent=True: a missing or malformed JSON body yields None (-> 400)
    # rather than raising an HTTP error that the broad handler would turn into a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "image_url" not in data:
        return jsonify({"error": "Missing image_url"}), 400

    image_url = data["image_url"]
    if not isinstance(image_url, str) or not image_url.strip():
        return jsonify({"error": "image_url must be a non-empty string"}), 400

    # NOTE(review): image_url is client-controlled and fetched server-side
    # (SSRF exposure). Consider restricting the allowed schemes/hosts in
    # download_image_from_url.
    local_path = None
    try:
        local_path = download_image_from_url(image_url)
        results = run_pipeline_for_image(local_path)

        # Upload each output the pipeline produced. A missing or empty path
        # means that output does not exist and is reported as null.
        # NOTE(review): the output files themselves are not deleted here;
        # confirm whether run_pipeline_for_image manages their lifetime.
        output_urls = {}
        for kind in ("boxed", "mask", "filtered"):
            output_path = results.get(f"{kind}_path")
            output_urls[kind] = (
                upload_to_cloudinary(output_path, folder="pipeline_outputs")
                if output_path
                else None
            )

        # Direct JSON response (no job queue wrapper)
        return jsonify({
            "label": results["label"],
            "boxed_url": output_urls["boxed"],
            "mask_url": output_urls["mask"],
            "filtered_url": output_urls["filtered"],
            "boxes": results.get("boxes", []),
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Always remove the downloaded image, so failed requests don't leak
        # one temp file each.
        if local_path and os.path.exists(local_path):
            try:
                os.remove(local_path)
            except OSError as exc:
                # Cleanup is best-effort; it must never mask the real response.
                app.logger.warning("Could not remove downloaded image %s: %s", local_path, exc)
if __name__ == "__main__":
    # Direct launch. Hugging Face Spaces serves on port 7860 by default;
    # a PORT environment variable takes precedence.
    port = int(os.getenv("PORT", "7860"))
    app.run(
        host="0.0.0.0",
        port=port,
        debug=False,
    )
|