File size: 9,283 Bytes
9cf599c
 
 
30b81fd
9cf599c
 
30b81fd
 
01d0daa
a023a85
 
9cf599c
a023a85
9cf599c
 
 
30b81fd
 
 
01d0daa
 
 
9cf599c
a023a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90d4f4d
 
 
 
 
a023a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461c5a6
a023a85
 
 
 
461c5a6
a023a85
 
 
 
 
 
 
461c5a6
a023a85
 
 
 
 
9cf599c
 
 
 
30b81fd
a023a85
9cf599c
 
30b81fd
 
01d0daa
 
 
 
 
9cf599c
 
 
 
 
 
 
30b81fd
 
 
461c5a6
a023a85
30b81fd
01d0daa
 
 
 
 
9cf599c
 
 
 
 
 
 
 
 
 
30b81fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01d0daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf599c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
Hugging Face Spaces API wrapper
Provides direct JSON responses without job queues
Integrated with feedback learning pipeline for continuous model improvement
"""
from flask import Flask, request, jsonify
from inference_core import run_pipeline_for_image, download_image_from_url, upload_to_cloudinary, model, device
from scripts.feedback_learning_pipeline import initialize_feedback_pipeline, run_feedback_training
from scripts.model_versioning import initialize_model_tracker
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime
import os
import atexit

app = Flask(__name__)

# Initialize feedback learning pipeline
# Reuses the model and device already loaded by inference_core, so the
# pipeline trains/evaluates against the same weights the API serves.
feedback_pipeline = initialize_feedback_pipeline(model, device)

# Initialize model versioning tracker
# Used by the /model/* endpoints to report version and training history.
model_tracker = initialize_model_tracker()


# ===== Automated Training Scheduler =====
def automated_training_check():
    """
    Background task that checks for new feedback and triggers training automatically.

    Runs on the APScheduler interval configured at module level. Reads
    feedback statistics from the pipeline, and when the pipeline reports
    readiness, runs a training cycle via ``run_feedback_training``.

    All outcomes are logged to stdout. Exceptions are caught and logged so
    a failed check never kills the scheduler's background thread.
    """
    try:
        print(f"\n[Automated Training] Running scheduled check at {datetime.now()}")

        stats = feedback_pipeline.get_feedback_stats()

        # Unprocessed = everything stored minus what earlier cycles consumed.
        # Computed once here instead of duplicated in both branches (DRY fix;
        # the original repeated this exact expression in the if and the else).
        unprocessed = stats.get("total_feedback_in_db", 0) - stats.get("total_processed", 0)

        # Guard clause: nothing to do until the pipeline says it's ready.
        if not stats.get("ready_for_retraining", False):
            print(f"[Automated Training] Not enough feedback for training ({unprocessed} new samples)")
            return

        print(f"[Automated Training] Found {unprocessed} unprocessed feedback samples")
        print(f"[Automated Training] Starting training cycle...")

        # Trigger training
        results = run_feedback_training(feedback_pipeline)

        # Defensive check: the status handling below assumes a dict result.
        if not isinstance(results, dict):
            print(f"[Automated Training] Error: Expected dict, got {type(results)}: {results}")
            return

        if results.get("status") == "success":
            print(f"[Automated Training] ✓ Training completed successfully")
            print(f"[Automated Training] Processed {results.get('corrections_processed')} corrections")
        else:
            print(f"[Automated Training] Training status: {results.get('status')}")

    except Exception as e:
        print(f"[Automated Training] Error during automated check: {e}")


# Initialize background scheduler
# daemon=True: the scheduler thread will not block interpreter shutdown.
scheduler = BackgroundScheduler(daemon=True)

# Schedule training checks every 1 day
# You can adjust the interval: hours, minutes, seconds
scheduler.add_job(
    func=automated_training_check,
    trigger="interval",
    days=1,  # Check every day
    id='automated_training',
    name='Automated Feedback Training Check',
    # replace_existing avoids duplicate jobs if this module is re-imported
    # (e.g. by a dev-server reload) while a job with this id already exists.
    replace_existing=True
)

# Start the scheduler
scheduler.start()
print("[Automated Training] Scheduler started - checking for new feedback every 1 day")

# Shutdown scheduler gracefully when app exits
# (lambda wrapper defers attribute lookup until exit time)
atexit.register(lambda: scheduler.shutdown())


@app.route("/", methods=["GET"])
def home():
    """Home page with API documentation"""
    return jsonify({
        "service": "Anomaly Detection API with Feedback Learning",
        "version": "2.1",
        "endpoints": {
            "/health": "GET - Health check",
            "/infer": "POST - Run inference on image URL",
            "/feedback/stats": "GET - Get feedback statistics and training status",
            "/feedback/train": "POST - Manually trigger feedback training cycle",
            "/model/current": "GET - Get current model version and parameters",
            "/model/versions": "GET - Get model version history",
            "/model/training-history": "GET - Get training cycle history",
            "/model/compare": "POST - Compare model versions"
        },
        "example_request": {
            "method": "POST",
            "url": "/infer",
            "body": {
                "image_url": "https://example.com/image.jpg"
            }
        },
        "feedback_info": {
            "description": "User corrections are automatically fetched from Supabase",
            "automated_training": "Checks for new feedback every 1 day and trains automatically",
            "training_threshold": "10+ new feedback samples triggers training",
            "manual_training": "POST /feedback/train to trigger immediately"
        },
        "versioning_info": {
            "description": "Model versions and training history tracked automatically",
            "view_current": "GET /model/current to see active model parameters",
            "view_history": "GET /model/versions to see all versions"
        }
    })


@app.route("/health", methods=["GET"])
def health():
    """Health check endpoint"""
    return jsonify({"status": "healthy"}), 200


@app.route("/feedback/stats", methods=["GET"])
def feedback_stats():
    """
    Get feedback statistics and training status
    """
    try:
        stats = feedback_pipeline.get_feedback_stats()
        return jsonify(stats), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/feedback/train", methods=["POST"])
def trigger_training():
    """
    Manually trigger a feedback training cycle
    Fetches user corrections from Supabase and improves model
    """
    try:
        results = run_feedback_training(feedback_pipeline)
        return jsonify(results), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/model/current", methods=["GET"])
def get_current_model():
    """
    Get current active model version and parameters
    """
    try:
        current_state = model_tracker.get_current_model_state()
        return jsonify(current_state), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/model/versions", methods=["GET"])
def get_model_versions():
    """
    Get model version history
    Query params: limit (default: 20)
    """
    try:
        limit = int(request.args.get('limit', 20))
        versions = model_tracker.get_version_history(limit=limit)
        return jsonify({
            "total": len(versions),
            "versions": versions
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/model/training-history", methods=["GET"])
def get_training_history():
    """
    Get training cycle history
    Query params: limit (default: 20)
    """
    try:
        limit = int(request.args.get('limit', 20))
        history = model_tracker.get_training_history(limit=limit)
        return jsonify({
            "total": len(history),
            "training_cycles": history
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/model/compare", methods=["POST"])
def compare_versions():
    """
    Compare multiple model versions
    Request JSON: {"version_ids": ["id1", "id2", ...]}
    """
    try:
        data = request.get_json()
        if not data or "version_ids" not in data:
            return jsonify({"error": "Missing version_ids"}), 400
        
        version_ids = data["version_ids"]
        comparison = model_tracker.generate_comparison_table(version_ids)
        
        return jsonify({
            "comparison": comparison,
            "version_count": len(version_ids)
        }), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/infer", methods=["POST"])
def infer():
    """
    Inference endpoint - returns direct JSON response
    Request JSON: {"image_url": "https://..."}
    """
    try:
        data = request.get_json()
        if not data or "image_url" not in data:
            return jsonify({"error": "Missing image_url"}), 400
        
        image_url = data["image_url"]
        
        # Download image
        local_path = download_image_from_url(image_url)
        
        # Run pipeline
        results = run_pipeline_for_image(local_path)
        
        # Upload outputs
        boxed_url = upload_to_cloudinary(results["boxed_path"], folder="pipeline_outputs") if results["boxed_path"] else None
        mask_url = upload_to_cloudinary(results["mask_path"], folder="pipeline_outputs") if results["mask_path"] else None
        filtered_url = upload_to_cloudinary(results["filtered_path"], folder="pipeline_outputs") if results["filtered_path"] else None
        
        # Clean up
        if os.path.exists(local_path):
            os.remove(local_path)
        
        # Direct JSON response (no job queue wrapper)
        return jsonify({
            "label": results["label"],
            "boxed_url": boxed_url,
            "mask_url": mask_url,
            "filtered_url": filtered_url,
            "boxes": results.get("boxes", [])
        })
        
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # For Hugging Face Spaces, use port 7860
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)