# Senum2001
# Change automated training interval from 2 minutes to 1 day
# 461c5a6
"""
Hugging Face Spaces API wrapper
Provides direct JSON responses without job queues
Integrated with feedback learning pipeline for continuous model improvement
"""
from flask import Flask, request, jsonify
from inference_core import run_pipeline_for_image, download_image_from_url, upload_to_cloudinary, model, device
from scripts.feedback_learning_pipeline import initialize_feedback_pipeline, run_feedback_training
from scripts.model_versioning import initialize_model_tracker
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime
import os
import atexit
# Flask application object; every route handler in this module is registered on it.
# NOTE: the three objects below are created at import time, so importing this
# module runs the pipeline and tracker initializers.
app = Flask(__name__)
# Feedback learning pipeline, built from the model and device imported from
# inference_core. Used by the scheduled training check and the /feedback/* routes.
feedback_pipeline = initialize_feedback_pipeline(model, device)
# Model versioning tracker; backs the /model/* routes.
model_tracker = initialize_model_tracker()
# ===== Automated Training Scheduler =====
def automated_training_check():
    """
    Scheduled job: inspect feedback statistics and retrain when the pipeline
    reports it is ready.

    Progress and outcome are reported with print(). Any exception raised while
    checking or training is caught and printed so it never propagates out of
    the job.
    """
    try:
        print(f"\n[Automated Training] Running scheduled check at {datetime.now()}")

        feedback_state = feedback_pipeline.get_feedback_stats()
        ready = feedback_state.get("ready_for_retraining", False)
        # Feedback rows stored minus rows already consumed by earlier cycles.
        pending = feedback_state.get("total_feedback_in_db", 0) - feedback_state.get("total_processed", 0)

        if not ready:
            print(f"[Automated Training] Not enough feedback for training ({pending} new samples)")
            return

        print(f"[Automated Training] Found {pending} unprocessed feedback samples")
        print("[Automated Training] Starting training cycle...")

        outcome = run_feedback_training(feedback_pipeline)

        # The status lookups below need a mapping; report anything else and stop.
        if not isinstance(outcome, dict):
            print(f"[Automated Training] Error: Expected dict, got {type(outcome)}: {outcome}")
            return

        status = outcome.get("status")
        if status != "success":
            print(f"[Automated Training] Training status: {status}")
            return

        print("[Automated Training] ✓ Training completed successfully")
        print(f"[Automated Training] Processed {outcome.get('corrections_processed')} corrections")
    except Exception as exc:
        print(f"[Automated Training] Error during automated check: {exc}")
# Background scheduler that runs the automated training check.
scheduler = BackgroundScheduler(daemon=True)

# Register the check as an interval job firing once per day.
# The interval can be changed via the days/hours/minutes/seconds keywords.
scheduler.add_job(
    automated_training_check,
    trigger="interval",
    days=1,
    id="automated_training",
    name="Automated Feedback Training Check",
    replace_existing=True,
)

scheduler.start()
print("[Automated Training] Scheduler started - checking for new feedback every 1 day")

# Stop the scheduler when the interpreter exits.
atexit.register(scheduler.shutdown)
@app.route("/", methods=["GET"])
def home():
    """Return a JSON document describing the service and its endpoints."""
    endpoints = {
        "/health": "GET - Health check",
        "/infer": "POST - Run inference on image URL",
        "/feedback/stats": "GET - Get feedback statistics and training status",
        "/feedback/train": "POST - Manually trigger feedback training cycle",
        "/model/current": "GET - Get current model version and parameters",
        "/model/versions": "GET - Get model version history",
        "/model/training-history": "GET - Get training cycle history",
        "/model/compare": "POST - Compare model versions",
    }
    example_request = {
        "method": "POST",
        "url": "/infer",
        "body": {"image_url": "https://example.com/image.jpg"},
    }
    feedback_info = {
        "description": "User corrections are automatically fetched from Supabase",
        "automated_training": "Checks for new feedback every 1 day and trains automatically",
        "training_threshold": "10+ new feedback samples triggers training",
        "manual_training": "POST /feedback/train to trigger immediately",
    }
    versioning_info = {
        "description": "Model versions and training history tracked automatically",
        "view_current": "GET /model/current to see active model parameters",
        "view_history": "GET /model/versions to see all versions",
    }

    return jsonify(
        {
            "service": "Anomaly Detection API with Feedback Learning",
            "version": "2.1",
            "endpoints": endpoints,
            "example_request": example_request,
            "feedback_info": feedback_info,
            "versioning_info": versioning_info,
        }
    )
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers 200 with a fixed status payload."""
    payload = {"status": "healthy"}
    return jsonify(payload), 200
@app.route("/feedback/stats", methods=["GET"])
def feedback_stats():
    """
    Report the feedback pipeline's statistics.

    Returns the statistics as JSON with 200, or {"error": ...} with 500 if
    fetching or serializing them fails.
    """
    try:
        return jsonify(feedback_pipeline.get_feedback_stats()), 200
    except Exception as exc:
        failure = {"error": str(exc)}
        return jsonify(failure), 500
@app.route("/feedback/train", methods=["POST"])
def trigger_training():
    """
    Run one feedback training cycle on demand.

    Returns whatever run_feedback_training produces as JSON with 200, or
    {"error": ...} with 500 if the cycle or serialization raises.
    """
    try:
        outcome = run_feedback_training(feedback_pipeline)
        response = jsonify(outcome)
        return response, 200
    except Exception as exc:
        failure = {"error": str(exc)}
        return jsonify(failure), 500
@app.route("/model/current", methods=["GET"])
def get_current_model():
    """
    Report the state of the currently active model as recorded by the tracker.

    Returns the tracker's state as JSON with 200, or {"error": ...} with 500
    on failure.
    """
    try:
        return jsonify(model_tracker.get_current_model_state()), 200
    except Exception as exc:
        failure = {"error": str(exc)}
        return jsonify(failure), 500
@app.route("/model/versions", methods=["GET"])
def get_model_versions():
    """
    Get model version history.

    Query params:
        limit: maximum number of versions to return (default: 20).

    Returns:
        200 with {"total": <count>, "versions": [...]} on success.
        400 with {"error": ...} when ``limit`` is not an integer.
        500 with {"error": ...} when the tracker lookup or serialization fails.
    """
    raw_limit = request.args.get("limit", 20)
    # Parse outside the catch-all below: a malformed query parameter is a
    # client error and must be reported as 400, not as a server failure.
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        return jsonify({"error": f"Invalid limit {raw_limit!r}: expected an integer"}), 400

    try:
        versions = model_tracker.get_version_history(limit=limit)
        return jsonify({
            "total": len(versions),
            "versions": versions
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/model/training-history", methods=["GET"])
def get_training_history():
    """
    Get training cycle history.

    Query params:
        limit: maximum number of training cycles to return (default: 20).

    Returns:
        200 with {"total": <count>, "training_cycles": [...]} on success.
        400 with {"error": ...} when ``limit`` is not an integer.
        500 with {"error": ...} when the tracker lookup or serialization fails.
    """
    raw_limit = request.args.get("limit", 20)
    # Parse outside the catch-all below: a malformed query parameter is a
    # client error and must be reported as 400, not as a server failure.
    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        return jsonify({"error": f"Invalid limit {raw_limit!r}: expected an integer"}), 400

    try:
        history = model_tracker.get_training_history(limit=limit)
        return jsonify({
            "total": len(history),
            "training_cycles": history
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/model/compare", methods=["POST"])
def compare_versions():
    """
    Compare multiple model versions.

    Request JSON: {"version_ids": ["id1", "id2", ...]}

    Returns:
        200 with {"comparison": ..., "version_count": <n>} on success.
        400 with {"error": ...} when the body is not a JSON object, lacks
            ``version_ids``, or ``version_ids`` is not a list.
        500 with {"error": ...} when the tracker comparison or serialization fails.
    """
    # silent=True yields None for a missing, non-JSON or malformed body instead
    # of raising an HTTP error that the catch-all below would turn into a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "version_ids" not in data:
        return jsonify({"error": "Missing version_ids"}), 400

    version_ids = data["version_ids"]
    # len() below needs a sized collection; reject scalars and objects up front.
    if not isinstance(version_ids, list):
        return jsonify({"error": "version_ids must be a list"}), 400

    try:
        comparison = model_tracker.generate_comparison_table(version_ids)
        return jsonify({
            "comparison": comparison,
            "version_count": len(version_ids)
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def _upload_if_present(path):
    """Upload *path* to the "pipeline_outputs" Cloudinary folder; None if path is falsy."""
    return upload_to_cloudinary(path, folder="pipeline_outputs") if path else None


@app.route("/infer", methods=["POST"])
def infer():
    """
    Inference endpoint - returns a direct JSON response (no job queue wrapper).

    Request JSON: {"image_url": "https://..."}

    Returns:
        200 (implicit) with {"label", "boxed_url", "mask_url", "filtered_url",
            "boxes"} on success; a URL is None when the pipeline produced no
            file for that output.
        400 with {"error": ...} when the body is not a JSON object, lacks
            ``image_url``, or ``image_url`` is not a non-empty string.
        500 with {"error": ...} when download, inference, upload or
            serialization fails.
    """
    # silent=True yields None for a missing, non-JSON or malformed body instead
    # of raising an HTTP error that the catch-all below would turn into a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "image_url" not in data:
        return jsonify({"error": "Missing image_url"}), 400

    image_url = data["image_url"]
    if not isinstance(image_url, str) or not image_url.strip():
        return jsonify({"error": "image_url must be a non-empty string"}), 400

    local_path = None
    try:
        # NOTE(review): image_url is caller-supplied and fetched server-side;
        # confirm download_image_from_url restricts schemes/hosts (SSRF).
        local_path = download_image_from_url(image_url)

        # Run pipeline
        results = run_pipeline_for_image(local_path)

        # Upload outputs (same order as before: boxed, mask, filtered)
        boxed_url = _upload_if_present(results["boxed_path"])
        mask_url = _upload_if_present(results["mask_path"])
        filtered_url = _upload_if_present(results["filtered_path"])

        return jsonify({
            "label": results["label"],
            "boxed_url": boxed_url,
            "mask_url": mask_url,
            "filtered_url": filtered_url,
            "boxes": results.get("boxes", [])
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Remove the downloaded image on every path, including failures, so
        # errors in the pipeline or uploads do not leak temporary files.
        if local_path:
            try:
                os.remove(local_path)
            except FileNotFoundError:
                pass  # Already gone; nothing to clean up.
            except OSError as cleanup_error:
                # Cleanup is best-effort: report it but keep the response.
                print(f"[Infer] Could not remove temporary file {local_path}: {cleanup_error}")
if __name__ == "__main__":
    # Port 7860 is used for Hugging Face Spaces; the PORT environment
    # variable overrides it when set.
    port = int(os.environ.get("PORT", 7860))
    run_options = {"host": "0.0.0.0", "port": port, "debug": False}
    app.run(**run_options)