"""
Feedback Learning Pipeline

Fetches user-corrected annotations from Supabase and uses them to improve model accuracy.
"""

import os
import json
import time
from datetime import datetime
from typing import List, Dict, Any, Optional

import numpy as np
import torch
from supabase import create_client, Client
from PIL import Image
import tempfile

# Import model versioning system
try:
    from scripts.model_versioning import ModelVersionTracker
    MODEL_VERSIONING_AVAILABLE = True
except ImportError:
    MODEL_VERSIONING_AVAILABLE = False
    print("[Warning] Model versioning not available")

# Supabase configuration.
# SECURITY: the service-role key bypasses Row Level Security, so it must never be
# committed to source control; it is read from the environment only.
# NOTE(review): a service-role key used to be hard-coded here as the default value.
# It remains in version history and should be rotated in the Supabase dashboard.
SUPABASE_URL = os.getenv("SUPABASE_URL", "https://xbcgrpqiibicestnhytt.supabase.co")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")

# Initialize Supabase client. It is None when no key is configured; code in this
# module obtains it through _get_supabase_client(), which raises a clear error then.
supabase: Optional[Client] = (
    create_client(SUPABASE_URL, SUPABASE_KEY) if SUPABASE_KEY else None
)
if supabase is None:
    print("[Warning] SUPABASE_SERVICE_ROLE_KEY is not set; Supabase access is disabled")

# Training state tracking
TRAINING_STATE_FILE = "feedback_training_state.json"
MIN_FEEDBACK_FOR_RETRAINING = 10  # Minimum feedback samples to trigger retraining


def _get_supabase_client() -> Client:
    """Return the module-level Supabase client.

    Raises:
        RuntimeError: if SUPABASE_SERVICE_ROLE_KEY was not set at import time.
    """
    if supabase is None:
        raise RuntimeError(
            "Supabase client is not configured: set the SUPABASE_SERVICE_ROLE_KEY "
            "environment variable"
        )
    return supabase


def _default_training_state() -> Dict[str, Any]:
    """Return a fresh training-state dictionary with every expected key present."""
    return {
        "last_training_time": None,
        # Despite the name, this cursor stores the `created_at` timestamp of the
        # newest processed feedback row (see fetch_new_feedback). The key name is
        # kept for compatibility with existing state files.
        "last_processed_feedback_id": None,
        "total_feedback_processed": 0,
        "training_runs": [],
    }


class FeedbackLearningPipeline:
    """Pipeline for continuous learning from user feedback.

    Each cycle fetches feedback rows newer than a persisted cursor, extracts
    (model prediction, user correction) pairs, summarizes disagreement patterns,
    writes a threshold recommendation to ``model_adjustments.json`` and records
    the run in ``TRAINING_STATE_FILE``.
    """

    def __init__(self, model, device):
        """
        Initialize the feedback learning pipeline

        Args:
            model: The PatchCore model instance
            device: PyTorch device (cpu or cuda)
        """
        self.model = model
        self.device = device
        self.training_state = self._load_training_state()

        # Initialize model versioning tracker
        if MODEL_VERSIONING_AVAILABLE:
            self.version_tracker = ModelVersionTracker()
        else:
            self.version_tracker = None

    def _load_training_state(self) -> Dict[str, Any]:
        """Load training state from disk, or return defaults if no file exists.

        Keys missing from an older or partial state file are filled in from the
        defaults, so later code can index them directly. A malformed file still
        raises (json.JSONDecodeError / TypeError) so the problem is visible rather
        than silently resetting the cursor and reprocessing old feedback.
        """
        if not os.path.exists(TRAINING_STATE_FILE):
            return _default_training_state()
        with open(TRAINING_STATE_FILE, 'r') as f:
            loaded = json.load(f)
        return {**_default_training_state(), **loaded}

    def _save_training_state(self):
        """Save training state to disk as indented JSON."""
        with open(TRAINING_STATE_FILE, 'w') as f:
            json.dump(self.training_state, f, indent=2)

    def fetch_new_feedback(self, limit: int = 100) -> List[Dict[str, Any]]:
        """
        Fetch new feedback logs from Supabase, oldest first.

        Only rows whose ``created_at`` is strictly greater than the stored cursor
        are returned. Errors (including a missing Supabase configuration) are
        logged and yield an empty list.

        Args:
            limit: Maximum number of feedback records to fetch

        Returns:
            List of feedback log dictionaries
        """
        try:
            query = _get_supabase_client().table('feedback_logs').select('*')

            # Only fetch feedback newer than last processed. The cursor holds a
            # created_at timestamp (see _default_training_state).
            # NOTE(review): with a strict `gt`, rows sharing the cursor timestamp
            # that fell beyond a previous batch's limit would be skipped.
            cursor = self.training_state.get("last_processed_feedback_id")
            if cursor:
                query = query.gt('created_at', cursor)

            response = query.order('created_at', desc=False).limit(limit).execute()
            feedback_logs = response.data if response.data else []

            print(f"[Feedback Pipeline] Fetched {len(feedback_logs)} new feedback records")
            return feedback_logs

        except Exception as e:
            print(f"[Feedback Pipeline] Error fetching feedback: {e}")
            return []

    @staticmethod
    def _parse_json_field(value: Any) -> Any:
        """Decode *value* if it is a JSON string (empty string -> {}); else return it."""
        if isinstance(value, str):
            return json.loads(value) if value else {}
        return value

    def extract_corrected_annotations(self, feedback_logs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Extract user-corrected annotations from feedback logs

        A log is kept only when both the model prediction and the user correction
        are non-empty JSON objects; malformed logs are reported and skipped.

        Args:
            feedback_logs: List of feedback log dictionaries

        Returns:
            List of correction samples with image_id, model_pred, and user_correction
        """
        corrections = []

        for log in feedback_logs:
            try:
                # Columns may arrive either already decoded or as JSON strings
                model_predicted = self._parse_json_field(log.get("model_predicted_anomalies", {}))
                user_corrected = self._parse_json_field(log.get("final_accepted_annotations", {}))
                annotator_metadata = self._parse_json_field(log.get("annotator_metadata", {}))

                # calculate_correction_patterns() reads both with dict.get(), so
                # only non-empty objects are usable. Rejecting anything else here
                # keeps one bad row from aborting the whole training cycle later.
                if not (isinstance(model_predicted, dict) and model_predicted
                        and isinstance(user_corrected, dict) and user_corrected):
                    continue

                corrections.append({
                    "feedback_id": log.get("id"),
                    "image_id": log.get("image_id"),
                    "model_predicted": model_predicted,
                    "user_corrected": user_corrected,
                    "annotator_metadata": annotator_metadata,
                    "created_at": log.get("created_at"),
                })

            except Exception as e:
                print(f"[Feedback Pipeline] Error processing feedback log {log.get('id')}: {e}")
                continue

        print(f"[Feedback Pipeline] Extracted {len(corrections)} valid corrections")
        return corrections

    def calculate_correction_patterns(self, corrections: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Analyze patterns in user corrections to understand model weaknesses

        Args:
            corrections: List of correction samples (as produced by
                extract_corrected_annotations; both annotation fields are dicts)

        Returns:
            Dictionary with analysis results: label changes, detection-count
            mismatches, and false positive / false negative counts derived from
            the "Critical"/"Normal" substrings of the labels.
        """
        patterns = {
            "total_corrections": len(corrections),
            "bbox_adjustments": [],
            "label_changes": [],
            "severity_changes": [],
            "false_positives": 0,
            "false_negatives": 0,
            "timestamp": datetime.now().isoformat()
        }

        for correction in corrections:
            model_pred = correction["model_predicted"]
            user_corr = correction["user_corrected"]

            # Analyze label changes. `or ""` guards against explicit null labels,
            # which would otherwise make the substring checks below raise.
            model_label = model_pred.get("label") or ""
            user_label = user_corr.get("label") or ""

            if model_label != user_label:
                patterns["label_changes"].append({
                    "from": model_label,
                    "to": user_label,
                    "image_id": correction["image_id"]
                })

            # Count false positives (model detected anomaly, user said normal)
            if "Critical" in model_label and "Normal" in user_label:
                patterns["false_positives"] += 1

            # Count false negatives (model said normal, user found anomaly)
            if "Normal" in model_label and "Critical" in user_label:
                patterns["false_negatives"] += 1

            # Analyze bounding box adjustments (only the detection count is compared)
            model_detections = model_pred.get("detections") or []
            user_detections = user_corr.get("detections") or []

            if len(model_detections) != len(user_detections):
                patterns["bbox_adjustments"].append({
                    "image_id": correction["image_id"],
                    "model_count": len(model_detections),
                    "user_count": len(user_detections)
                })

        print(f"[Feedback Pipeline] Pattern Analysis:")
        print(f" - Label Changes: {len(patterns['label_changes'])}")
        print(f" - False Positives: {patterns['false_positives']}")
        print(f" - False Negatives: {patterns['false_negatives']}")
        print(f" - BBox Adjustments: {len(patterns['bbox_adjustments'])}")

        return patterns

    def apply_model_adjustments(self, patterns: Dict[str, Any]):
        """
        Record threshold-adjustment guidance derived from user feedback.

        Computes false positive / false negative rates over all corrections and
        writes them, with a recommendation, to ``model_adjustments.json``. The
        model itself is not modified here. Errors are logged, not raised.

        Args:
            patterns: Pattern analysis results
        """
        try:
            # Calculate adjustment factors based on false positive/negative rates
            total = patterns["total_corrections"]
            if total == 0:
                return

            fp_rate = patterns["false_positives"] / total
            fn_rate = patterns["false_negatives"] / total

            print(f"[Feedback Pipeline] Model Adjustment:")
            print(f" - False Positive Rate: {fp_rate:.2%}")
            print(f" - False Negative Rate: {fn_rate:.2%}")

            # Store adjustment metadata for inference
            adjustment_data = {
                "fp_rate": fp_rate,
                "fn_rate": fn_rate,
                "total_corrections": total,
                "timestamp": datetime.now().isoformat(),
                "recommendation": self._get_threshold_recommendation(fp_rate, fn_rate)
            }

            # Save adjustment data
            with open("model_adjustments.json", "w") as f:
                json.dump(adjustment_data, f, indent=2)

            print(f"[Feedback Pipeline] Recommendation: {adjustment_data['recommendation']}")

        except Exception as e:
            print(f"[Feedback Pipeline] Error applying adjustments: {e}")

    def _get_threshold_recommendation(self, fp_rate: float, fn_rate: float) -> str:
        """Generate threshold adjustment recommendation.

        The false-positive check takes precedence when both rates exceed 30%.
        """
        if fp_rate > 0.3:
            return "INCREASE detection threshold - too many false positives"
        elif fn_rate > 0.3:
            return "DECREASE detection threshold - too many false negatives"
        else:
            return "Current threshold is balanced"

    def should_retrain(self) -> bool:
        """
        Determine if model should be retrained based on feedback count

        Counts unprocessed feedback rows (up to 1000) past the stored cursor.

        Returns:
            True if retraining should be triggered
        """
        new_feedback_count = len(self.fetch_new_feedback(limit=1000))

        if new_feedback_count >= MIN_FEEDBACK_FOR_RETRAINING:
            print(f"[Feedback Pipeline] Retraining threshold met: {new_feedback_count} new feedback samples")
            return True

        print(f"[Feedback Pipeline] Not enough feedback for retraining: {new_feedback_count}/{MIN_FEEDBACK_FOR_RETRAINING}")
        return False

    def _advance_cursor(self, feedback_logs: List[Dict[str, Any]]):
        """Move the fetch cursor to the created_at of the last (newest) fetched row."""
        newest = feedback_logs[-1].get("created_at")
        # Never reset the cursor to None: the next fetch would start over from the
        # oldest feedback row and reprocess everything.
        if newest is not None:
            self.training_state["last_processed_feedback_id"] = newest

    @staticmethod
    def _extract_version_id(state):
        """Safely extract version_id whether state is a dict, str, or None."""
        if not state:
            return None
        if isinstance(state, dict):
            return state.get("version_id")
        if isinstance(state, str):
            return state
        return None

    def run_training_cycle(self):
        """
        Execute a full training cycle: fetch feedback, analyze patterns, apply adjustments

        Returns:
            A result dict whose "status" is "success", "no_feedback",
            "no_corrections" or "error". All exceptions are caught and reported
            through the "error" status.
        """
        try:
            print(f"\n[Feedback Pipeline] Starting training cycle at {datetime.now()}")

            # Local handle: a versioning failure disables versioning for THIS
            # cycle only, not for the lifetime of the pipeline instance.
            tracker = self.version_tracker

            # Capture model state BEFORE training
            before_state = None
            if tracker:
                try:
                    before_state = tracker.get_current_model_state()
                    tracker.log_model_version(before_state)
                    if before_state and isinstance(before_state, dict):
                        print(f"[Model Versioning] Captured state before training: {str(before_state.get('version_id', 'unknown'))[:8]}...")
                except Exception as e:
                    print(f"[Model Versioning] Warning: Could not capture before state: {e}")
                    before_state = None  # Reset to None on error
                    tracker = None  # Disable versioning for this run

            # Fetch new feedback
            feedback_logs = self.fetch_new_feedback(limit=1000)

            if not feedback_logs:
                print("[Feedback Pipeline] No new feedback available")
                return {
                    "status": "no_feedback",
                    "message": "No new feedback to process"
                }

            # Extract corrections
            corrections = self.extract_corrected_annotations(feedback_logs)

            if not corrections:
                # Move past these rows anyway: they hold no usable corrections, and
                # leaving the cursor in place would refetch them every cycle (and
                # keep should_retrain() counting them as new feedback).
                self._advance_cursor(feedback_logs)
                self._save_training_state()
                print("[Feedback Pipeline] No valid corrections found")
                return {
                    "status": "no_corrections",
                    "message": "No valid corrections extracted"
                }

            # Analyze patterns
            patterns = self.calculate_correction_patterns(corrections)

            # Apply adjustments
            self.apply_model_adjustments(patterns)

            # Update training state
            self.training_state["last_training_time"] = datetime.now().isoformat()
            self._advance_cursor(feedback_logs)
            self.training_state["total_feedback_processed"] += len(corrections)
            self.training_state["training_runs"].append({
                "timestamp": datetime.now().isoformat(),
                "corrections_processed": len(corrections),
                "patterns": patterns
            })
            self._save_training_state()

            # Capture model state AFTER training
            after_state = None
            training_cycle_id = None
            if tracker and before_state:
                try:
                    after_state = tracker.get_current_model_state()
                    tracker.log_model_version(after_state)
                    if after_state and isinstance(after_state, dict):
                        print(f"[Model Versioning] Captured state after training: {str(after_state.get('version_id', 'unknown'))[:8]}...")

                    # Only log training cycle if both states are valid dictionaries
                    if (before_state and isinstance(before_state, dict) and
                            after_state and isinstance(after_state, dict)):
                        # Log the training cycle with before/after comparison
                        training_cycle_id = tracker.log_training_cycle(
                            before_state=before_state,
                            after_state=after_state,
                            feedback_count=len(corrections),
                            patterns=patterns,
                            performance_metrics=None  # TODO: Calculate actual metrics
                        )
                        if training_cycle_id:
                            print(f"[Training History] Logged training cycle: {str(training_cycle_id)[:8]}...")
                    else:
                        print(f"[Training History] Skipping cycle logging - invalid state data")
                except Exception as e:
                    print(f"[Model Versioning] Error logging version: {e}")

            print(f"[Feedback Pipeline] Training cycle completed successfully")
            print(f"[Feedback Pipeline] Total feedback processed: {self.training_state['total_feedback_processed']}")

            return {
                "status": "success",
                "corrections_processed": len(corrections),
                "patterns": patterns,
                "total_feedback_processed": self.training_state["total_feedback_processed"],
                "before_version_id": self._extract_version_id(before_state),
                "after_version_id": self._extract_version_id(after_state),
                "training_cycle_id": training_cycle_id
            }
        except Exception as e:
            print(f"[Feedback Pipeline] CRITICAL ERROR in run_training_cycle: {e}")
            print(f"[Feedback Pipeline] Error type: {type(e).__name__}")
            import traceback
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "error_type": type(e).__name__
            }

    def get_feedback_stats(self) -> Dict[str, Any]:
        """Get statistics about feedback and training.

        NOTE(review): "ready_for_retraining" compares the TOTAL row count in the
        database with MIN_FEEDBACK_FOR_RETRAINING, whereas should_retrain() counts
        only rows newer than the cursor; the two can disagree.
        """
        try:
            # Get total feedback count from Supabase
            response = _get_supabase_client().table('feedback_logs').select('id', count='exact').execute()
            total_feedback = response.count if response.count else 0

            return {
                "total_feedback_in_db": total_feedback,
                "total_processed": self.training_state.get("total_feedback_processed", 0),
                "last_training_time": self.training_state.get("last_training_time"),
                "training_runs": len(self.training_state.get("training_runs", [])),
                "ready_for_retraining": total_feedback >= MIN_FEEDBACK_FOR_RETRAINING
            }
        except Exception as e:
            print(f"[Feedback Pipeline] Error getting stats: {e}")
            return {
                "error": str(e)
            }


def initialize_feedback_pipeline(model, device):
    """
    Initialize the feedback learning pipeline

    Args:
        model: PatchCore model instance
        device: PyTorch device

    Returns:
        FeedbackLearningPipeline instance
    """
    return FeedbackLearningPipeline(model, device)


def run_feedback_training(pipeline: FeedbackLearningPipeline):
    """
    Run a single training cycle

    Args:
        pipeline: FeedbackLearningPipeline instance

    Returns:
        Training results dictionary
    """
    return pipeline.run_training_cycle()