# Extraction artifact — Hugging Face Spaces status header, not part of the app:
# Spaces: Sleeping
# Sleeping
| from flask import Flask, render_template, request, jsonify, send_file | |
| import os | |
| import sqlite3 | |
| from datetime import datetime | |
| import requests | |
| import secrets | |
| import sys | |
| import subprocess | |
# Performance optimization: only emit debug prints in development.
# FLASK_DEBUG=true turns them on; on Hugging Face Spaces (SPACE_ID is set)
# they are always disabled for better performance.
DEBUG_MODE = (
    os.environ.get("FLASK_DEBUG", "false").lower() == "true"
    and not os.environ.get("SPACE_ID")
)
def debug_print(*args, **kwargs):
    """Forward to print() only when DEBUG_MODE is enabled; otherwise a no-op."""
    if not DEBUG_MODE:
        return
    print(*args, **kwargs)
# Add material detection path (import functions lazily to avoid blocking startup)
# NOTE(review): 'apps.err' reads like a literal directory name here — confirm it
# is not a mangled 'apps/err' path before deploying.
sys.path.append('apps.err/material_detection_naturalobjects')
# Import functions only when needed, not at module level
app = Flask(__name__)
# Session signing key: honor SECRET_KEY from the environment; otherwise a fresh
# random key is generated, so sessions will not survive a process restart.
app.secret_key = os.environ.get("SECRET_KEY", secrets.token_hex(32))
# Base directory for path resolution
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Models and datasets are siblings to spad_for_vision_space (same level)
# Locally: ../models and ../datasets relative to BASE_DIR
# On Hugging Face Spaces: models/ and datasets/ will be downloaded to BASE_DIR/models and BASE_DIR/datasets
def get_models_dir():
    """Get models directory path - sibling to BASE_DIR locally, or BASE_DIR/models on Spaces."""
    # Hugging Face Spaces sets SPACE_ID; weights are downloaded next to the app there.
    parent = BASE_DIR if os.environ.get("SPACE_ID") else os.path.dirname(BASE_DIR)
    return os.path.join(parent, "models")
def get_datasets_dir():
    """Get datasets directory path - sibling to BASE_DIR locally, or BASE_DIR/datasets on Spaces."""
    # Mirrors get_models_dir(): SPACE_ID marks a Hugging Face Spaces deployment.
    parent = BASE_DIR if os.environ.get("SPACE_ID") else os.path.dirname(BASE_DIR)
    return os.path.join(parent, "datasets")
# SQLite file holding visitor analytics; relative path, so it resolves against
# the process working directory (not BASE_DIR).
DB_PATH = "visitors.db"
def init_db():
    """Create the ``visitors`` table if it does not already exist.

    Fix: the original never closed the connection when the DDL raised; the
    close now happens in a ``finally`` so the handle cannot leak.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS visitors (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ip_address TEXT,
                country TEXT,
                city TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.commit()
    finally:
        conn.close()
# Initialize database on startup (with error handling)
try:
    init_db()
except Exception as e:
    # NOTE(review): init_db is the only place the table is created — if it fails
    # here, log_visitor's INSERT will likely fail too; confirm the "created on
    # first use" claim below.
    print(f"Warning: Database initialization failed (non-critical): {e}")
    # Continue anyway - database will be created on first use
def get_visitor_location(ip_address: str):
    """Best-effort IP geolocation via ipapi.co; returns Unknowns on any failure."""
    fallback = {"country": "Unknown", "city": "Unknown", "region": "Unknown"}
    try:
        resp = requests.get(f"http://ipapi.co/{ip_address}/json/", timeout=5)
        if resp.status_code != 200:
            return fallback
        data = resp.json()
        return {
            "country": data.get("country_name", "Unknown"),
            "city": data.get("city", "Unknown"),
            "region": data.get("region", "Unknown"),
        }
    except Exception:
        # Network/parse errors are non-fatal: visitor logging must not break pages.
        return fallback
def log_visitor(ip_address: str):
    """Geolocate ``ip_address`` (best effort) and persist one visit row.

    Fix: the original leaked the connection if the INSERT raised; the close is
    now guaranteed by a ``finally``.
    """
    location = get_visitor_location(ip_address)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            """
            INSERT INTO visitors (ip_address, country, city)
            VALUES (?, ?, ?)
            """,
            (ip_address, location["country"], location["city"]),
        )
        conn.commit()
    finally:
        conn.close()
def get_visitor_stats():
    """Return aggregate visitor counters and the top-10 countries.

    Returns a dict with ``total_visitors`` (row count), ``unique_visitors``
    (distinct IPs) and ``countries`` (list of (country, count) tuples).
    Fix: the connection is now closed in a ``finally`` so a failed query
    cannot leak it.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM visitors")
        total_visitors = cursor.fetchone()[0]
        cursor.execute("SELECT COUNT(DISTINCT ip_address) FROM visitors")
        unique_visitors = cursor.fetchone()[0]
        cursor.execute(
            """
            SELECT country, COUNT(*) AS count
            FROM visitors
            GROUP BY country
            ORDER BY count DESC
            LIMIT 10
            """
        )
        countries = cursor.fetchall()
    finally:
        conn.close()
    return {
        "total_visitors": total_visitors,
        "unique_visitors": unique_visitors,
        "countries": countries,
    }
def favicon():
    """Serve static/favicon.ico when present; respond 404 otherwise."""
    icon = os.path.join(app.static_folder, "favicon.ico")
    if not os.path.exists(icon):
        return "", 404
    return send_file(icon, mimetype="image/vnd.microsoft.icon")
def home():
    """Render the landing page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("index.html", stats=get_visitor_stats())
def custom_yolov8_demo():
    """Render the custom YOLOv8 demo page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("custom_yolov8_demo.html", stats=get_visitor_stats())
def spatiotemporal_detection():
    """Render the spatiotemporal detection page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("spatiotemporal_detection.html", stats=get_visitor_stats())
def flat_surface_detection():
    """Render the flat surface detection page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("flat_surface_detection.html", stats=get_visitor_stats())
def material_purity():
    """Render the material purity page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("material_purity.html", stats=get_visitor_stats())
def material_detection_head():
    """Render the material detection head page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("material_detection_head.html", stats=get_visitor_stats())
def fluid_purity_demo():
    """Render the fluid purity demo page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("fluid_purity_demo.html", stats=get_visitor_stats())
def demos():
    """Render the demos index page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("demos.html", stats=get_visitor_stats())
def articles():
    """Display the article index built from *.md files under BASE_DIR/articles."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    stats = get_visitor_stats()
    articles_dir = os.path.join(BASE_DIR, "articles")
    articles_list = []
    if os.path.exists(articles_dir):
        articles_list = [
            {
                'filename': fname,
                # "my_article.md" -> "My Article"
                'name': fname.replace('.md', '').replace('_', ' ').title(),
                'url': f"/articles/{fname.replace('.md', '')}",
            }
            for fname in os.listdir(articles_dir)
            if fname.endswith('.md')
        ]
    return render_template("articles.html", articles=articles_list, stats=stats)
def article_detail(article_name):
    """Render one markdown article; 404 page when the file is missing.

    Falls back to a raw <pre> dump when the `markdown` package is absent.
    """
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    stats = get_visitor_stats()
    article_path = os.path.join(BASE_DIR, "articles", f"{article_name}.md")
    if not os.path.exists(article_path):
        return render_template("article_not_found.html", article_name=article_name, stats=stats), 404
    with open(article_path, 'r', encoding='utf-8') as f:
        markdown_content = f.read()
    try:
        import markdown
        html_content = markdown.markdown(markdown_content, extensions=['extra', 'codehilite'])
    except ImportError:
        # markdown package unavailable: show the raw text.
        html_content = f"<pre>{markdown_content}</pre>"
    # Prefer the first heading line as the title, else derive from the filename.
    title = article_name.replace('_', ' ').title()
    if markdown_content.startswith('#'):
        title = markdown_content.split('\n')[0].replace('#', '').strip()
    return render_template("article_detail.html",
                           title=title,
                           content=html_content,
                           markdown_content=markdown_content,
                           stats=stats)
def dinov3_demo():
    """Render the DINOv3 demo page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("dinov3_demo.html", stats=get_visitor_stats())
def detect_yolov3():
    """Render the YOLOv3 demo page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("yolov3_demo.html", stats=get_visitor_stats())
def deliverables():
    """Render the deliverables overview (placeholder: empty list)."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("deliverables.html", stats=get_visitor_stats(), deliverables=[])
def create_deliverable():
    """Create-deliverable form (GET) / not-yet-implemented handler (POST)."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    stats = get_visitor_stats()
    if request.method == "POST":
        return jsonify({"success": True, "message": "Deliverable creation not yet implemented"})
    return render_template("create_deliverable.html", stats=stats, projects=[])
def deliverable_detail(deliverable_id):
    """Render a deliverable detail page (placeholder: id only)."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("deliverable_detail.html",
                           stats=get_visitor_stats(),
                           deliverable={"id": deliverable_id})
def edit_deliverable(deliverable_id):
    """Edit form (GET) / not-yet-implemented edit handler (POST)."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    stats = get_visitor_stats()
    if request.method == "POST":
        return jsonify({"success": True, "message": "Deliverable editing not yet implemented"})
    return render_template("edit_deliverable.html",
                           stats=stats,
                           deliverable={"id": deliverable_id},
                           projects=[])
def delete_deliverable(deliverable_id):
    """Placeholder delete endpoint; logs the visit then reports not-implemented."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    get_visitor_stats()  # retained for parity with the original (result unused)
    return jsonify({"success": True, "message": "Deliverable deletion not yet implemented"})
def toggle_deliverable_status(deliverable_id):
    """Placeholder status-toggle endpoint; logs the visit then reports not-implemented."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    get_visitor_stats()  # retained for parity with the original (result unused)
    return jsonify({"success": True, "message": "Status toggle not yet implemented"})
def deliverables_in_progress():
    """Render in-progress deliverables (placeholder: empty list)."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("in_progress_deliverables.html",
                           stats=get_visitor_stats(), deliverables=[])
def deliverables_completed():
    """Render completed deliverables (placeholder: empty list)."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("completed_deliverables.html",
                           stats=get_visitor_stats(), deliverables=[])
def version_control():
    """Placeholder 'version control' page returning static inline HTML."""
    visitor_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(visitor_ip)
    stats = get_visitor_stats()  # NOTE(review): unused — the page below is static
    # Create a simple version control page
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Version Control - MV+</title>
        <style>
            body { font-family: Arial, sans-serif; padding: 40px; text-align: center; }
            h1 { color: #B91D30; }
        </style>
    </head>
    <body>
        <h1>Version Control</h1>
        <p>Version control features coming soon.</p>
        <p><a href="/">Return to Home</a></p>
    </body>
    </html>
    """
| # Projects route removed - only Demos is accessible via dropdown menu | |
def create_project():
    """Placeholder create-project route: JSON on POST, static HTML on GET."""
    visitor_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(visitor_ip)
    stats = get_visitor_stats()  # NOTE(review): unused — GET page is static
    if request.method == "POST":
        # Handle project creation (placeholder)
        return jsonify({"success": True, "message": "Project creation not yet implemented"})
    # For GET, return a simple message page ({{ }} escapes braces in the f-string)
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Create Project - MV+</title>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 40px; text-align: center; }}
            h1 {{ color: #B91D30; }}
            a {{ color: #00CED1; text-decoration: none; }}
        </style>
    </head>
    <body>
        <h1>Create Project</h1>
        <p>Project creation feature coming soon.</p>
        <p><a href="/projects">← Back to Projects</a></p>
    </body>
    </html>
    """
def project_detail(project_id):
    """Placeholder project-detail page; interpolates project_id into static HTML."""
    visitor_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(visitor_ip)
    stats = get_visitor_stats()  # NOTE(review): unused — the page below is static
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Project Details - MV+</title>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 40px; text-align: center; }}
            h1 {{ color: #B91D30; }}
            a {{ color: #00CED1; text-decoration: none; }}
        </style>
    </head>
    <body>
        <h1>Project Details</h1>
        <p>Project ID: {project_id}</p>
        <p>Project details feature coming soon.</p>
        <p><a href="/projects">← Back to Projects</a></p>
    </body>
    </html>
    """
def edit_project(project_id):
    """Placeholder edit-project route: JSON on POST, static HTML on GET."""
    visitor_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(visitor_ip)
    stats = get_visitor_stats()  # NOTE(review): unused — GET page is static
    if request.method == "POST":
        # Handle project edit (placeholder)
        return jsonify({"success": True, "message": "Project editing not yet implemented"})
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Edit Project - MV+</title>
        <style>
            body {{ font-family: Arial, sans-serif; padding: 40px; text-align: center; }}
            h1 {{ color: #B91D30; }}
            a {{ color: #00CED1; text-decoration: none; }}
        </style>
    </head>
    <body>
        <h1>Edit Project</h1>
        <p>Project ID: {project_id}</p>
        <p>Project editing feature coming soon.</p>
        <p><a href="/projects">← Back to Projects</a></p>
    </body>
    </html>
    """
def about():
    """Render the about page, logging the visit first."""
    client_ip = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    log_visitor(client_ip)
    return render_template("about.html", stats=get_visitor_stats())
def api_detect_spatiotemporal():
    """
    API endpoint for spatiotemporal detection using STO files.
    Wraps process_sto_file from material_detection_naturalobjects/spatiotemporal_detection_script.py
    """
    try:
        # Lazy import keeps startup fast and isolates the optional dependency.
        import sys
        sys.path.append('apps.err/material_detection_naturalobjects')
        from spatiotemporal_detection_script import process_sto_file as _process_sto

        upload = request.files.get("file")
        if upload is None:
            return jsonify({"error": "No file uploaded"}), 400
        if upload.filename == "":
            return jsonify({"error": "No file selected"}), 400
        if not upload.filename.lower().endswith(".sto"):
            return jsonify({
                "error": "Please upload a STO spatiotemporal file.",
                "error_type": "wrong_format",
            }), 400

        # Stage the upload on disk for the processing script, then clean up.
        temp_path = f"temp_spatiotemporal_{datetime.now().strftime('%Y%m%d_%H%M%S')}.sto"
        upload.save(temp_path)
        try:
            result = _process_sto(temp_path)
            if not result.get("success"):
                return jsonify({
                    "success": False,
                    "error": result.get("error", "Unknown error"),
                }), 400
            return jsonify({
                "success": True,
                "image": result.get("image"),
                "image_size": result.get("image_size"),
                "image_mode": result.get("image_mode"),
                "sto_length": result.get("sto_length"),
                "filename": upload.filename,
            })
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Spatiotemporal API Error: {e}")
        print(error_details)
        return jsonify({
            "error": str(e),
            "error_type": "server_error",
            "details": error_details,
        }), 500
def api_stats():
    """Expose visitor statistics as JSON."""
    return jsonify(get_visitor_stats())
def fetch_weights_from_hub(repo_id, subfolder=None):
    """Fetch list of weight files from a Hugging Face Hub repository.

    Args:
        repo_id: Repository ID (e.g., 'mvplus/dinov3')
        subfolder: Optional subfolder path (None for individual repos)

    Returns:
        List of file paths relative to repo root; [] on any failure
        (missing huggingface_hub, network error, unknown repo, ...).
    """
    weight_extensions = ('.pth', '.pt', '.ckpt', '.h5', '.pb', '.onnx')
    try:
        from huggingface_hub import HfApi
        matches = []
        for path in HfApi().list_repo_files(repo_id=repo_id, repo_type="model"):
            if subfolder:
                if not (path.startswith(f"{subfolder}/") or path == subfolder):
                    continue
                # Strip the subfolder prefix only for the extension check;
                # the full path is kept for the hub:// format.
                rel = path[len(subfolder) + 1:] if path.startswith(f"{subfolder}/") else path
                if rel.lower().endswith(weight_extensions):
                    matches.append(path)
            elif path.lower().endswith(weight_extensions):
                matches.append(path)
        return matches
    except Exception as e:
        print(f"Error fetching weights from Hub {repo_id}: {e}")
        import traceback
        traceback.print_exc()
        return []
def api_flat_surface_detection_weights():
    """API endpoint to get flat surface detection model weights from local directory or Hugging Face Hub.

    Returns JSON ``{'success': bool, 'weights': [...]}``; each entry carries
    filename, path (``hub://`` or local), display_name, accuracy, epoch,
    weight_type and source.

    Changes vs. original: the duplicated filename-parsing logic is factored
    into a local helper, and — for consistency with the other weight
    endpoints — local files already listed from the Hub are skipped and local
    entries are tagged with ``"source": "local"`` (additive, backward-compatible).
    """
    import os
    import glob
    import re

    def _parse_checkpoint_name(filename):
        """Return (epoch, accuracy, display_name, weight_type) parsed from a
        name like 'epoch_150_Accuracy_99.605.pth'; defaults when unparsable."""
        match = re.search(r'epoch_(\d+)_Accuracy_([\d\.]+)', filename)
        if match:
            try:
                epoch = int(match.group(1))
                accuracy = float(match.group(2).rstrip('.'))  # drop trailing dots
                return epoch, accuracy, f"Epoch {epoch} ({accuracy:.2f}% accuracy)", "Checkpoint"
            except ValueError:
                pass  # fall through to defaults
        return 0, 0.0, filename.replace('.pth', ''), "Checkpoint"

    try:
        weights = []
        weights_dir = os.path.join(get_models_dir(), "flat_surface")
        # Individual repo, no subfolder. MODEL_REPO_MAP is defined elsewhere in this module.
        repo_id = MODEL_REPO_MAP.get('flat_surface', 'mvplus/flat_surface')

        # First, list weights available on the Hugging Face Hub.
        for hub_file in fetch_weights_from_hub(repo_id, None):
            filename = os.path.basename(hub_file)
            epoch, accuracy, display_name, weight_type = _parse_checkpoint_name(filename)
            weights.append({
                "filename": filename,
                "path": f"hub://{repo_id}/{hub_file}",  # Hub path for on-demand download
                "display_name": display_name,
                "accuracy": accuracy,
                "epoch": epoch,
                "weight_type": weight_type,
                "source": "hub",
                "repo_url": f"https://huggingface.co/{repo_id}",
                "repo_id": repo_id
            })

        # Also check local files, skipping duplicates already listed from the Hub.
        if os.path.exists(weights_dir):
            for weight_file in glob.glob(os.path.join(weights_dir, "*.pth")):
                filename = os.path.basename(weight_file)
                if any(w['filename'] == filename for w in weights):
                    continue  # already provided by the Hub
                if filename == "1.pth":
                    # Special case for the default bundled weight.
                    epoch, accuracy = 0, 0.0
                    display_name, weight_type = "Model Weight (1.pth)", "Default"
                else:
                    epoch, accuracy, display_name, weight_type = _parse_checkpoint_name(filename)
                weights.append({
                    "filename": filename,
                    "path": weight_file,  # Use local path
                    "display_name": display_name,
                    "accuracy": accuracy,
                    "epoch": epoch,
                    "weight_type": weight_type,
                    "source": "local"
                })
            # Sort by accuracy (highest first), then by epoch
            weights.sort(key=lambda x: (x['accuracy'], x['epoch']), reverse=True)
        else:
            print(f"DEBUG: Weights directory not found: {weights_dir}")

        return jsonify({
            'success': True,
            'weights': weights
        })
    except Exception as e:
        import traceback
        print(f"Error loading flat surface weights: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'weights': []
        }), 500
def api_fluid_purity_weights():
    """API endpoint to get fluid purity model weights from local directory or Hugging Face Hub.

    Returns JSON ``{'success': bool, 'weights': [...]}``; each entry carries
    filename, path (``hub://`` or local), display_name, accuracy, epoch,
    weight_type and source. The duplicated filename-parsing logic is factored
    into a local helper (behavior unchanged).
    """
    import os
    import glob
    import re

    def _parse_checkpoint_name(filename):
        """Return (epoch, accuracy, display_name, weight_type) parsed from a
        name like 'epoch_105_Accuracy_100.0__2025-11-06 09:02:10.335104.pth';
        defaults when unparsable."""
        match = re.search(r'epoch_(\d+)_Accuracy_([\d\.]+)', filename)
        if match:
            try:
                epoch = int(match.group(1))
                accuracy = float(match.group(2).rstrip('.'))  # drop trailing dots
                return epoch, accuracy, f"Epoch {epoch} ({accuracy:.2f}% accuracy)", "Checkpoint"
            except ValueError:
                pass  # fall through to defaults
        return 0, 0.0, filename.replace('.pth', ''), "Checkpoint"

    try:
        weights = []
        weights_dir = os.path.join(get_models_dir(), "material_purity")
        # Individual repo, no subfolder. MODEL_REPO_MAP is defined elsewhere in this module.
        repo_id = MODEL_REPO_MAP.get('material_purity', 'mvplus/material_purity')

        # First, list weights available on the Hugging Face Hub.
        for hub_file in fetch_weights_from_hub(repo_id, None):
            filename = os.path.basename(hub_file)
            epoch, accuracy, display_name, weight_type = _parse_checkpoint_name(filename)
            weights.append({
                "filename": filename,
                "path": f"hub://{repo_id}/{hub_file}",
                "display_name": display_name,
                "accuracy": accuracy,
                "epoch": epoch,
                "weight_type": weight_type,
                "source": "hub"
            })

        # Also check local files, skipping duplicates already listed from the Hub.
        if os.path.exists(weights_dir):
            for weight_file in glob.glob(os.path.join(weights_dir, "*.pth")):
                filename = os.path.basename(weight_file)
                if any(w['filename'] == filename for w in weights):
                    continue  # already provided by the Hub
                epoch, accuracy, display_name, weight_type = _parse_checkpoint_name(filename)
                weights.append({
                    "filename": filename,
                    "path": weight_file,  # Use local path
                    "display_name": display_name,
                    "accuracy": accuracy,
                    "epoch": epoch,
                    "weight_type": weight_type,
                    "source": "local"
                })
            # Sort by accuracy (highest first), then by epoch
            weights.sort(key=lambda x: (x['accuracy'], x['epoch']), reverse=True)
        else:
            print(f"DEBUG: Weights directory not found: {weights_dir}")

        return jsonify({
            'success': True,
            'weights': weights
        })
    except Exception as e:
        import traceback
        print(f"Error loading fluid purity weights: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'weights': []
        }), 500
def api_material_detection_head_weights():
    """API endpoint to get material detection head model weights from local directory or Hugging Face Hub.

    Returns JSON ``{'success': bool, 'weights': [...]}``; each entry carries
    filename, path (``hub://`` or local), display_name, accuracy, epoch,
    weight_type and source. The duplicated filename-parsing logic is factored
    into a local helper (behavior unchanged).
    """
    import os
    import glob
    import re

    def _parse_checkpoint_name(filename):
        """Return (epoch, accuracy, display_name, weight_type) parsed from a
        name like 'epoch_399_Accuracy_98.25.pth'; defaults when unparsable."""
        match = re.search(r'epoch_(\d+)_Accuracy_([\d\.]+)', filename)
        if match:
            try:
                epoch = int(match.group(1))
                accuracy = float(match.group(2).rstrip('.'))  # drop trailing dots
                return epoch, accuracy, f"Epoch {epoch} ({accuracy:.2f}% accuracy)", "Checkpoint"
            except ValueError:
                pass  # fall through to defaults
        return 0, 0.0, filename.replace('.pth', ''), "Checkpoint"

    try:
        weights = []
        weights_dir = os.path.join(get_models_dir(), "material_detection_head")
        # Individual repo, no subfolder. MODEL_REPO_MAP is defined elsewhere in this module.
        repo_id = MODEL_REPO_MAP.get('material_detection_head', 'mvplus/material_detection_head')

        # First, list weights available on the Hugging Face Hub.
        for hub_file in fetch_weights_from_hub(repo_id, None):
            filename = os.path.basename(hub_file)
            epoch, accuracy, display_name, weight_type = _parse_checkpoint_name(filename)
            weights.append({
                "filename": filename,
                "path": f"hub://{repo_id}/{hub_file}",
                "display_name": display_name,
                "accuracy": accuracy,
                "epoch": epoch,
                "weight_type": weight_type,
                "source": "hub"
            })

        # Also check local files, skipping duplicates already listed from the Hub.
        if os.path.exists(weights_dir):
            for weight_file in glob.glob(os.path.join(weights_dir, "*.pth")):
                filename = os.path.basename(weight_file)
                if any(w['filename'] == filename for w in weights):
                    continue  # already provided by the Hub
                epoch, accuracy, display_name, weight_type = _parse_checkpoint_name(filename)
                weights.append({
                    "filename": filename,
                    "path": weight_file,  # Use local path
                    "display_name": display_name,
                    "accuracy": accuracy,
                    "epoch": epoch,
                    "weight_type": weight_type,
                    "source": "local"
                })
            # Sort by accuracy (highest first), then by epoch
            weights.sort(key=lambda x: (x['accuracy'], x['epoch']), reverse=True)
        else:
            print(f"DEBUG: Weights directory not found: {weights_dir}")

        return jsonify({
            'success': True,
            'weights': weights
        })
    except Exception as e:
        import traceback
        print(f"Error loading material detection head weights: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'weights': []
        }), 500
def api_get_model_info():
    """API endpoint to get model architecture and classes based on weight path."""
    weight_path = request.form.get('weight_path', '')
    if not weight_path:
        return jsonify({'error': 'No weight path provided'}), 400
    lowered = weight_path.lower()
    # Infer the model family from the weight path text.
    if 'material_purity' in weight_path or 'fluid_purity' in weight_path or 'purity' in lowered:
        class_names = ['Impure', 'Pure']
        model_architecture = 'Material Purity Classifier (Binary)'
    elif 'flat_surface' in lowered or 'flatsurface' in lowered:
        class_names = ['BCB', 'BNT', 'WGF', 'WNT']
        model_architecture = 'Flat Surface Detection ConvNet'
    else:
        # Default fallback when the path matches no known family.
        class_names = []
        model_architecture = 'Unknown'
    classes_str = ', '.join(str(c) for c in class_names) if class_names else '-'
    return jsonify({
        'success': True,
        'architecture': model_architecture,
        'classes': class_names,
        'classes_display': classes_str,
        'num_classes': len(class_names)
    })
def api_detect_material_head():
    """API endpoint for material detection head - handles both flat_surface and material_purity.

    Expects a multipart POST containing:
      * 'file'        - a .sto pickle container or a regular image file
      * 'weight_path' - local path (absolute, relative, or relative to BASE_DIR)
                        or a 'hub://' Hugging Face Hub path to the checkpoint

    The model family (binary purity / 4-class flat surface / 12-class material
    detection head) is inferred from substrings in the weight path. Returns
    JSON with the predicted class, confidence, top-3 predictions, a base64
    JPEG preview of the 16x16 input image, and model metadata.
    """
    import time
    import hashlib
    import tempfile
    import torch
    import torch.nn as nn
    import torchvision.transforms as transforms
    from PIL import Image
    import numpy as np
    import base64
    import io
    import pickle
    from io import BytesIO
    # Deterministic inference: fixed seeds + non-benchmarking cuDNN so the
    # same input always produces the same prediction.
    torch.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(42)
    start_time = time.time()
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        # Get selected weight path (can be local path or Hugging Face Hub path)
        weight_path = request.form.get('weight_path', '')
        if not weight_path:
            return jsonify({'error': 'No model weight selected'}), 400
        debug_print(f"DEBUG: Received weight_path: '{weight_path}'")
        debug_print(f"DEBUG: weight_path is absolute: {os.path.isabs(weight_path)}")
        debug_print(f"DEBUG: weight_path exists: {os.path.exists(weight_path)}")
        # Determine model type from weight path (substring heuristics)
        is_material_purity = ('material_purity' in weight_path.lower() or 'fluid_purity' in weight_path.lower() or 'purity' in weight_path.lower())
        is_flat_surface = ('flat_surface' in weight_path.lower() or 'flatsurface' in weight_path.lower())
        is_material_detection_head = ('material_detection_head' in weight_path.lower() or 'naturalobjects' in weight_path.lower())
        debug_print(f"DEBUG: Model type detection:")
        debug_print(f"DEBUG: is_material_purity: {is_material_purity}")
        debug_print(f"DEBUG: is_flat_surface: {is_flat_surface}")
        debug_print(f"DEBUG: is_material_detection_head: {is_material_detection_head}")
        debug_print(f"DEBUG: weight_path: {weight_path}")
        # Resolve weight_path to a local file: absolute path, path as-is,
        # path relative to BASE_DIR, then Hugging Face Hub ('hub://...').
        local_model_path = None
        if os.path.isabs(weight_path) and os.path.exists(weight_path):
            # Absolute local file path
            debug_print(f"DEBUG: Using absolute local weight file: {weight_path}")
            local_model_path = weight_path
        elif os.path.exists(weight_path):
            # Relative path that exists as-is (relative to the CWD)
            debug_print(f"DEBUG: Using local weight file (found as-is): {weight_path}")
            local_model_path = weight_path
        else:
            # Try relative to BASE_DIR
            relative_path = os.path.join(BASE_DIR, weight_path)
            if os.path.exists(relative_path):
                debug_print(f"DEBUG: Using relative local weight file: {relative_path}")
                local_model_path = relative_path
        # If not found as local path, check if it's a Hugging Face Hub path
        if local_model_path is None:
            if weight_path.startswith('hub://'):
                downloaded_path = download_model_from_hub(weight_path)
                if downloaded_path:
                    local_model_path = downloaded_path
                else:
                    return jsonify({'error': f'Failed to download model from Hub: {weight_path}'}), 400
            else:
                return jsonify({'error': f'Weight path not found: {weight_path}. Please check the path is correct.'}), 400
        # Verify the local model file exists
        if not local_model_path or not os.path.exists(local_model_path):
            return jsonify({'error': f'Model file not found: {local_model_path or weight_path}'}), 400
        # Check if file is STO or regular image
        filename_lower = file.filename.lower()
        is_sto_file = filename_lower.endswith('.sto')
        debug_print(f"DEBUG: File type check - is_sto_file: {is_sto_file}, filename: {file.filename}")
        if is_sto_file:
            # Handle STO file - extract index 1 (16x16 material detection image)
            # STO structure: Index 0=metadata, Index 1=16x16 material image, Index 2=OD metadata, Index 3=640x640 OD image
            # Write the upload to the system temp dir (not the CWD) with
            # microsecond resolution, matching api_extract_sto_index0, so
            # concurrent requests and read-only deployments don't collide.
            temp_sto_path = os.path.join(
                tempfile.gettempdir(),
                f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.sto")
            with open(temp_sto_path, 'wb') as temp_file:
                file.stream.seek(0)
                temp_file.write(file.stream.read())
            try:
                # SECURITY NOTE: pickle.load on an uploaded file can execute
                # arbitrary code; only accept STO files from trusted sources.
                with open(temp_sto_path, 'rb') as f:
                    sto_data = pickle.load(f)
                if len(sto_data) < 2:
                    return jsonify({'error': 'Invalid STO file - need at least 2 items'}), 400
                # Extract index 1 for material detection (16x16 image)
                sto_item = sto_data[1]
                if isinstance(sto_item, bytes):
                    image = Image.open(BytesIO(sto_item)).convert('RGB')
                    debug_print(f"DEBUG: Extracted image from STO index 1 (bytes), size: {image.size}, mode: {image.mode}")
                elif hasattr(sto_item, 'mode'):
                    image = sto_item.convert('RGB')
                    debug_print(f"DEBUG: Extracted image from STO index 1 (PIL), size: {image.size}, mode: {image.mode}")
                    # A mostly-black image usually means the wrong STO index
                    # was extracted - warn unconditionally (not debug-gated).
                    img_array = np.array(image)
                    if img_array.mean() < 5 and img_array.std() < 5:
                        print(f"⚠️ WARNING: Image at index 1 appears to be mostly black (mean={img_array.mean():.2f}, std={img_array.std():.2f})")
                        print(f"⚠️ This might indicate the wrong image is being extracted from the STO file")
                else:
                    return jsonify({'error': f'Invalid STO file structure at index 1: expected image, got {type(sto_item).__name__}'}), 400
            finally:
                if os.path.exists(temp_sto_path):
                    os.remove(temp_sto_path)
        else:
            # Regular image file - use material_detection_functions for processing
            file.stream.seek(0)
            try:
                # Lazy import to avoid blocking startup
                from material_detection_functions import process_png_bytes
                processed_image, _ = process_png_bytes(file.stream.read())
                image = processed_image
                debug_print(f"DEBUG: Processed image using material_detection_functions, size: {image.size}, mode: {image.mode}")
                img_array = np.array(image)
                debug_print(f"DEBUG: Image array shape: {img_array.shape}")
                debug_print(f"DEBUG: Image array min/max: {img_array.min()}/{img_array.max()}")
                debug_print(f"DEBUG: Image array mean: {img_array.mean():.2f}, std: {img_array.std():.2f}")
            except Exception as e:
                # Best-effort fallback: load the upload directly with PIL.
                print(f"DEBUG: Error using process_png_bytes, falling back to direct loading: {e}")
                file.stream.seek(0)
                image = Image.open(file.stream).convert('RGB')
                debug_print(f"DEBUG: Loaded regular image directly, size: {image.size}, mode: {image.mode}")
        # Preprocess image.
        # CRITICAL: all three model families were trained with
        # ToTensor + Normalize(0.5, 0.5, 0.5) on 16x16 RGB inputs, EXACTLY as
        # in the eval script - NOT ImageNet normalization.
        transform = transforms.Compose([
            transforms.ToTensor(),  # PIL -> float tensor in [0,1], CHW layout
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [0,1] -> [-1,1]
        ])
        # Ensure image is RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Resize to 16x16 if not already (EXACTLY as in eval script)
        if image.size != (16, 16):
            image = image.resize((16, 16), Image.Resampling.LANCZOS)
        debug_print(f"DEBUG: Before transform - image size: {image.size}, mode: {image.mode}")
        # Near-uniform input would make every class score similar - warn.
        img_array_before = np.array(image)
        debug_print(f"DEBUG: Image array before transform - shape: {img_array_before.shape}, min: {img_array_before.min()}, max: {img_array_before.max()}, mean: {img_array_before.mean():.2f}, std: {img_array_before.std():.2f}")
        if img_array_before.std() < 1.0:
            print(f"⚠️ WARNING: Image has very low std ({img_array_before.std():.2f}), might be nearly uniform!")
        image_tensor = transform(image)
        debug_print(f"DEBUG: After transform - tensor shape: {image_tensor.shape}, dtype: {image_tensor.dtype}")
        debug_print(f"DEBUG: After transform - tensor range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")
        debug_print(f"DEBUG: After transform - tensor mean: {image_tensor.mean():.4f}, std: {image_tensor.std():.4f}")
        if image_tensor.std() < 0.01:
            print(f"⚠️ WARNING: Tensor has very low std ({image_tensor.std():.4f}) after normalization, might cause uniform predictions!")
        # Normalize the tensor rank to [1, C, H, W] before inference.
        if len(image_tensor.shape) == 3:
            # Add batch dimension: [C, H, W] -> [1, C, H, W]
            image_tensor = image_tensor.unsqueeze(0)
            debug_print(f"DEBUG: Added batch dimension, shape: {image_tensor.shape}")
        elif len(image_tensor.shape) == 4:
            # Already batched; collapse any extra leading dims to a batch of 1
            if image_tensor.shape[0] != 1:
                while len(image_tensor.shape) > 3:
                    image_tensor = image_tensor.squeeze(0)
                image_tensor = image_tensor.unsqueeze(0)
            debug_print(f"DEBUG: Already had batch dimension, shape: {image_tensor.shape}")
        else:
            raise ValueError(f"Unexpected image tensor shape after transform: {image_tensor.shape}")
        # Final verification: tensor must be [1, 3, 16, 16]
        if image_tensor.shape != torch.Size([1, 3, 16, 16]):
            print(f"WARNING: Image tensor shape {image_tensor.shape} != expected [1, 3, 16, 16]")
            # Last-resort repair before giving up
            while len(image_tensor.shape) > 4:
                image_tensor = image_tensor.squeeze(0)
            if len(image_tensor.shape) == 3:
                image_tensor = image_tensor.unsqueeze(0)
            if image_tensor.shape != torch.Size([1, 3, 16, 16]):
                raise ValueError(f"Cannot fix image tensor shape: {image_tensor.shape}, expected [1, 3, 16, 16]")
        debug_print(f"DEBUG: Final tensor shape verified: {image_tensor.shape}")
        # Define model architectures (must match the training/eval scripts)
        class MaterialPurityClassifier(nn.Module):
            """Binary CNN classifier for material purity (sigmoid output)."""
            def __init__(self):
                super(MaterialPurityClassifier, self).__init__()
                self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
                self.bn1 = nn.BatchNorm2d(16)
                self.relu1 = nn.ReLU(inplace=True)
                self.pool1 = nn.MaxPool2d(kernel_size=2)
                self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
                self.bn2 = nn.BatchNorm2d(32)
                self.relu2 = nn.ReLU(inplace=True)
                self.pool2 = nn.MaxPool2d(kernel_size=2)
                self.fc1 = nn.Linear(32 * 4 * 4, 64)
                self.relu3 = nn.ReLU(inplace=True)
                self.fc2 = nn.Linear(64, 1)
                self.sigmoid = nn.Sigmoid()
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = self.pool1(x)
                x = self.conv2(x)
                x = self.bn2(x)
                x = self.relu2(x)
                x = self.pool2(x)
                x = x.view(-1, 32 * 4 * 4)
                x = self.fc1(x)
                x = self.relu3(x)
                x = self.fc2(x)
                x = self.sigmoid(x)
                return x
        class ConvNetFlatSurface(nn.Module):
            """ConvNet for flat surface detection - 4 classes (matches eval script)."""
            def __init__(self):
                super(ConvNetFlatSurface, self).__init__()
                self.conv1 = nn.Conv2d(3, 32, kernel_size=(3,3), stride=1, padding=1)
                self.act1 = nn.ReLU()
                self.drop1 = nn.Dropout(0.3)
                self.conv2 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=1, padding=1)
                self.act2 = nn.ReLU()
                self.pool2 = nn.MaxPool2d(kernel_size=(2,2))
                self.flat = nn.Flatten()
                self.fc3 = nn.Linear(2048, 512)  # 32 * 8 * 8 = 2048 after pooling
                self.act3 = nn.ReLU()
                self.drop3 = nn.Dropout(0.5)
                self.fc4 = nn.Linear(512, 4)  # 4 classes: BCB, BNT, WGF, WNT
            def forward(self, x):
                x = self.act1(self.conv1(x))
                x = self.drop1(x)
                x = self.act2(self.conv2(x))
                x = self.pool2(x)
                x = self.flat(x)
                x = self.act3(self.fc3(x))
                x = self.drop3(x)
                x = self.fc4(x)
                return x
        class ConvNetMaterialDetectionHead(nn.Module):
            """ConvNet for material detection head - 12 classes (matches eval script)."""
            def __init__(self):
                super(ConvNetMaterialDetectionHead, self).__init__()
                self.conv1 = nn.Conv2d(3, 32, kernel_size=(3,3), stride=1, padding=1)
                self.act1 = nn.ReLU()
                self.drop1 = nn.Dropout(0.3)
                self.conv2 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=1, padding=1)
                self.act2 = nn.ReLU()
                self.pool2 = nn.MaxPool2d(kernel_size=(2,2))
                self.flat = nn.Flatten()
                self.fc3 = nn.Linear(2048, 512)  # 32 * 8 * 8 = 2048 after pooling from 16x16
                self.act3 = nn.ReLU()
                self.drop3 = nn.Dropout(0.5)
                self.fc4 = nn.Linear(512, 12)  # 12 classes
            def forward(self, x):
                # EXACTLY as in eval script - dropout layers are defined (so the
                # state_dict keys match) but deliberately NOT called in forward.
                x = self.act1(self.conv1(x))
                x = self.act2(self.conv2(x))
                x = self.pool2(x)
                x = self.flat(x)
                x = self.act3(self.fc3(x))
                x = self.fc4(x)
                return x
        # Load checkpoint on CPU (no GPU assumed on Spaces)
        checkpoint = torch.load(local_model_path, map_location='cpu')
        # Initialize model based on type
        if is_material_purity:
            model = MaterialPurityClassifier()
            class_names = ['Impure', 'Pure']
            model_architecture = 'Material Purity Classifier (Binary)'
        elif is_flat_surface:
            model = ConvNetFlatSurface()
            class_names = ['BCB', 'BNT', 'WGF', 'WNT']
            model_architecture = 'Flat Surface Detection ConvNet'
        elif is_material_detection_head:
            model = ConvNetMaterialDetectionHead()
            # Class names in alphabetical order (as ImageFolder sorts them during training)
            # IMPORTANT: ImageFolder assigns class IDs based on alphabetical order of folder names
            # This is the order that the model was trained with, matching multiwebapp
            class_names = [
                '3dmodel',               # 0 (alphabetically first)
                'LEDscreen',             # 1 (lowercase 's' to match training directory)
                'bowl__purpleplastic',   # 2
                'bowl__whiteceramic',    # 3
                'carrot__natural',       # 4
                'eggplant__natural',     # 5
                'greenpepper__natural',  # 6
                'potato__natural',       # 7
                'redpepper__natural',    # 8
                'teacup__ceramic',       # 9
                'tomato__natural',       # 10
                'yellowpepper__natural'  # 11
            ]
            model_architecture = 'Material Detection Head ConvNet (12 classes)'
        else:
            return jsonify({'error': f'Unknown model type. Weight path: {weight_path}'}), 400
        # Load weights - handle different checkpoint formats
        debug_print(f"DEBUG: Loading model from checkpoint: {local_model_path}")
        debug_print(f"DEBUG: Model type: {model_architecture}")
        debug_print(f"DEBUG: Expected num_classes: {len(class_names)}")
        if isinstance(checkpoint, dict):
            if 'msd' in checkpoint:
                # Our trained models use 'msd' key
                debug_print(f"DEBUG: Loading from 'msd' key")
                model.load_state_dict(checkpoint['msd'], strict=True)
            elif 'state_dict' in checkpoint:
                debug_print(f"DEBUG: Loading from 'state_dict' key")
                model.load_state_dict(checkpoint['state_dict'], strict=True)
            elif 'model_state_dict' in checkpoint:
                debug_print(f"DEBUG: Loading from 'model_state_dict' key")
                model.load_state_dict(checkpoint['model_state_dict'], strict=True)
            else:
                # Try loading the dict itself as a raw state_dict
                try:
                    debug_print(f"DEBUG: Trying to load checkpoint dict directly as state_dict")
                    model.load_state_dict(checkpoint, strict=True)
                except Exception as e:
                    print(f"Warning: Could not load as state_dict: {e}")
                    # If checkpoint is the model itself (unlikely but possible)
                    if hasattr(checkpoint, 'forward'):
                        model = checkpoint
                        debug_print(f"DEBUG: Using checkpoint as model directly")
                    else:
                        raise ValueError(f"Unknown checkpoint format: {type(checkpoint)}")
        else:
            # Checkpoint might be the model itself (unlikely for our models)
            if hasattr(checkpoint, 'forward'):
                model = checkpoint
                debug_print(f"DEBUG: Using checkpoint as model directly")
            else:
                try:
                    debug_print(f"DEBUG: Trying to load checkpoint as state_dict")
                    model.load_state_dict(checkpoint, strict=True)
                except Exception as e:
                    print(f"DEBUG: Failed to load checkpoint: {e}")
                    raise ValueError(f"Unknown checkpoint format: {type(checkpoint)}")
        # Verify model output shape matches the class list
        if hasattr(model, 'fc4'):
            debug_print(f"DEBUG: Model fc4 output features: {model.fc4.out_features}")
            if model.fc4.out_features != len(class_names):
                print(f"WARNING: Model output features ({model.fc4.out_features}) != num classes ({len(class_names)})")
        # CRITICAL: eval mode disables dropout/batchnorm updates
        model.eval()
        torch.set_grad_enabled(False)
        # CRITICAL: Ensure model and tensor are on CPU (matching multiwebapp)
        model = model.cpu()
        image_tensor = image_tensor.cpu()
        # Double-check eval mode survived the device move
        model.eval()
        # Verify dropout is disabled
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Dropout):
                if module.training:
                    print(f"WARNING: Dropout layer {name} is still in training mode!")
                    module.eval()
        debug_print(f"DEBUG: Input tensor shape: {image_tensor.shape}")
        debug_print(f"DEBUG: Input tensor range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")
        debug_print(f"DEBUG: Input tensor mean: {image_tensor.mean():.4f}, std: {image_tensor.std():.4f}")
        # Run inference
        with torch.no_grad():
            # Hash the input so different images are verifiably different in logs
            tensor_hash = hashlib.md5(image_tensor.cpu().numpy().tobytes()).hexdigest()
            debug_print(f"DEBUG: Input tensor hash (MD5): {tensor_hash}")
            if model.training:
                print("ERROR: Model is in training mode! Setting to eval mode...")
                model.eval()
            predictions = model(image_tensor)
            debug_print(f"DEBUG: Predictions hash (MD5): {hashlib.md5(predictions.cpu().numpy().tobytes()).hexdigest()}")
            # Degenerate outputs usually indicate a model/input problem
            pred_np = predictions.cpu().numpy()[0]
            if np.allclose(pred_np, 0):
                print(f"⚠️ ERROR: All predictions are zeros!")
            elif np.allclose(pred_np, pred_np[0]):
                print(f"⚠️ ERROR: All predictions are the same value: {pred_np[0]}")
            debug_print(f"DEBUG: Raw predictions (logits): {predictions.cpu().numpy()}")
        # Process predictions
        if is_material_purity:
            # Binary classification - single sigmoid output = P(Pure)
            prob_pure = predictions[0, 0].item() if predictions.shape == (1, 1) else predictions.flatten()[0].item()
            prob_impure = 1.0 - prob_pure
            predicted_class = 1 if prob_pure >= 0.5 else 0
            confidence = prob_pure if predicted_class == 1 else prob_impure
            top3_predictions = [
                {'class': 'Pure', 'probability': prob_pure},
                {'class': 'Impure', 'probability': prob_impure}
            ]
            top3_predictions.sort(key=lambda x: x['probability'], reverse=True)
        else:
            # Multi-class classification - softmax (EXACTLY as in eval script)
            import torch.nn.functional as F
            # Verify predictions tensor is valid before indexing
            if predictions.shape[0] == 0 or predictions.shape[1] != len(class_names):
                print(f"ERROR: Invalid predictions shape: {predictions.shape}, expected [1, {len(class_names)}]")
                return jsonify({'error': f'Invalid model output shape: {predictions.shape}'}), 500
            probabilities = F.softmax(predictions, dim=1)[0]  # probabilities for the single image
            if torch.any(torch.isnan(probabilities)) or torch.any(torch.isinf(probabilities)):
                print(f"ERROR: Invalid probabilities (NaN or Inf): {probabilities}")
                return jsonify({'error': 'Model produced invalid probabilities (NaN or Inf)'}), 500
            # Get predicted class - ensure it's within valid range
            predicted_class = torch.argmax(probabilities, dim=0).item()
            if predicted_class < 0 or predicted_class >= len(class_names):
                print(f"ERROR: Predicted class index {predicted_class} out of range [0, {len(class_names)-1}]")
                return jsonify({'error': f'Invalid predicted class index: {predicted_class}'}), 500
            confidence = probabilities[predicted_class].item()
            debug_print(f"DEBUG: Probabilities after softmax: {probabilities.tolist()}")
            debug_print(f"DEBUG: Predicted class index: {predicted_class}")
            debug_print(f"DEBUG: Predicted class name: {class_names[predicted_class]}")
            debug_print(f"DEBUG: Confidence: {confidence:.6f}")
            # Near-identical probabilities suggest a model or input issue
            prob_values = probabilities.tolist()
            if len(set([round(p, 6) for p in prob_values])) == 1:
                print(f"⚠️ WARNING: All probabilities are identical! This suggests a model or input issue.")
            elif max(prob_values) - min(prob_values) < 0.001:
                print(f"⚠️ WARNING: Probabilities are nearly identical (range < 0.001). This suggests a model or input issue.")
            # Get top 3 predictions (as in eval script topk usage)
            top3_probs, top3_indices = torch.topk(probabilities, min(3, len(class_names)))
            debug_print(f"DEBUG: Top 3 indices: {top3_indices.tolist()}")
            debug_print(f"DEBUG: Top 3 probabilities: {top3_probs.tolist()}")
            # Function to format class name for display (consolidate to materials only)
            def format_class_name(class_name):
                """Format class name to consolidated material format"""
                # Handle special cases
                if class_name == "3dmodel":
                    return "3D Model"
                elif class_name == "LEDscreen":
                    return "LED"
                elif "__" in class_name:
                    # Format: "carrot__natural" -> "natural carrot"
                    # Format: "bowl__purpleplastic" -> "purple plastic bowl"
                    parts = class_name.split("__")
                    if len(parts) == 2:
                        material, type_ = parts
                        # Handle compound words in type_ (e.g., "purpleplastic" -> "purple plastic")
                        if type_ == "purpleplastic":
                            type_ = "purple plastic"
                        elif type_ == "whiteceramic":
                            type_ = "white ceramic"
                        # Lowercase and combine: "natural carrot"
                        return f"{type_} {material}".lower()
                # Default: replace underscores and title case
                return class_name.replace("__", " ").replace("_", " ").title()
            # Get top 3 predictions
            top3_predictions = []
            for i in range(min(3, len(class_names))):
                class_idx = top3_indices[i].item()
                prob = top3_probs[i].item()
                class_name = class_names[class_idx] if class_idx < len(class_names) else f'class_{class_idx}'
                display_class = format_class_name(class_name)
                top3_predictions.append({
                    'class': class_name,            # Keep original for internal use
                    'display_class': display_class,  # Formatted for display
                    'probability': prob
                })
                debug_print(f"DEBUG: Top {i+1}: class_idx={class_idx}, class_name={class_name}, prob={prob:.4f}")
        inference_time = (time.time() - start_time) * 1000  # milliseconds
        # Convert image to base64 for display
        img_buffer = io.BytesIO()
        image.save(img_buffer, format='JPEG')
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode()
        return jsonify({
            'success': True,
            'predicted_class': class_names[predicted_class] if predicted_class < len(class_names) else 'unknown',
            'confidence': confidence,
            'top3_predictions': top3_predictions,
            'inference_time': inference_time,
            'image': f"data:image/jpeg;base64,{img_base64}",
            'architecture': model_architecture,
            'model_size': 'Variable',
            'input_size': '16x16',
            'batch_size': 1,
            'classes': class_names,
            'classes_display': ', '.join(class_names)
        })
    except Exception as e:
        import traceback
        error_traceback = traceback.format_exc()
        print(f"ERROR in api_detect_material_head: {str(e)}")
        print(f"Traceback: {error_traceback}")
        return jsonify({
            'success': False,
            'error': f'Detection failed: {str(e)}',
            'error_type': 'detection_failed',
            'traceback': error_traceback
        }), 500
def api_extract_sto_index0():
    """API endpoint to extract index 1 (16x16 material detection image) from STO file.

    Note: Despite the name 'index0', this extracts index 1 which is the
    material detection image.
    STO structure: Index 0=metadata, Index 1=16x16 material image,
    Index 2=OD metadata, Index 3=640x640 OD image.

    Returns JSON containing a base64 PNG data URI of the extracted image plus
    its size and mode; on failure returns success=False with an error
    description and an appropriate HTTP status code.
    """
    import pickle
    import io
    import base64
    import tempfile
    from io import BytesIO
    from PIL import Image
    temp_path = None
    try:
        if 'file' not in request.files:
            return jsonify({'success': False, 'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'success': False, 'error': 'No file selected'}), 400
        filename_lower = file.filename.lower()
        if not filename_lower.endswith('.sto'):
            return jsonify({'success': False, 'error': 'Please upload a STO file'}), 400
        # Save the upload to the system temp dir before unpickling
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"temp_sto_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.sto")
        file.save(temp_path)
        if not os.path.exists(temp_path):
            return jsonify({'success': False, 'error': 'Failed to save uploaded file'}), 500
        # SECURITY NOTE: pickle.load on an uploaded file can execute arbitrary
        # code; only accept STO files from trusted sources.
        with open(temp_path, 'rb') as f:
            sto_data = pickle.load(f)
        if not isinstance(sto_data, (list, tuple)):
            return jsonify({'success': False, 'error': f'Invalid STO file format: expected list/tuple, got {type(sto_data).__name__}'}), 400
        if len(sto_data) == 0:
            return jsonify({'success': False, 'error': 'STO file is empty'}), 400
        # Extract index 1 (16x16 material detection image)
        if len(sto_data) < 2:
            return jsonify({'success': False, 'error': 'STO file does not have index 1 (material detection image)'}), 400
        sto_item = sto_data[1]
        image = None
        if isinstance(sto_item, bytes):
            # Encoded image bytes
            image = Image.open(BytesIO(sto_item)).convert('RGB')
        elif isinstance(sto_item, Image.Image):
            image = sto_item.convert('RGB')
        elif hasattr(sto_item, 'mode'):
            # Duck-typed PIL-like object (e.g. pickled from another PIL build)
            image = sto_item.convert('RGB')
        elif isinstance(sto_item, (list, tuple)) and len(sto_item) > 0:
            # Nested structure, try first element
            nested_item = sto_item[0]
            if isinstance(nested_item, bytes):
                image = Image.open(BytesIO(nested_item)).convert('RGB')
            elif isinstance(nested_item, Image.Image):
                image = nested_item.convert('RGB')
            else:
                return jsonify({'success': False, 'error': f'Invalid STO structure at index 1: unexpected type {type(sto_item).__name__}'}), 400
        else:
            return jsonify({'success': False, 'error': f'Invalid STO structure at index 1: expected image, got {type(sto_item).__name__}'}), 400
        if image is None:
            return jsonify({'success': False, 'error': 'Failed to extract image from STO file'}), 500
        # Convert to base64 using PNG (lossless) to preserve pixel values:
        # JPEG compression perturbs pixel values, which materially affects
        # model predictions on 16x16 inputs.
        img_buffer = io.BytesIO()
        image.save(img_buffer, format='PNG')
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
        return jsonify({
            'success': True,
            'image': f"data:image/png;base64,{img_base64}",
            'image_size': image.size,
            'image_mode': image.mode
        })
    except pickle.UnpicklingError as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': f'Failed to unpickle STO file: {str(e)}',
            'error_type': 'unpickling_failed'
        }), 500
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'error_type': type(e).__name__
        }), 500
    finally:
        # Best-effort cleanup; swallow only filesystem errors, not everything
        # (a bare `except:` would also hide KeyboardInterrupt/SystemExit).
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
def api_extract_sto_index1():
    """API endpoint to extract index 3 (640x640 object detection image) from STO file.

    Note: Despite the name 'index1', this extracts index 3 which is the object
    detection image.
    STO structure: Index 0=metadata, Index 1=16x16 material image,
    Index 2=OD metadata, Index 3=640x640 OD image.

    Returns JSON {'success': True, 'image': <data URL>, 'image_size', 'image_mode'}
    on success, or {'success': False, 'error': ...} with a 4xx/5xx status on failure.
    """
    import pickle
    import io
    import base64
    from io import BytesIO
    from PIL import Image
    temp_path = None
    try:
        if 'file' not in request.files:
            return jsonify({'success': False, 'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'success': False, 'error': 'No file selected'}), 400
        filename_lower = file.filename.lower()
        if not filename_lower.endswith('.sto'):
            return jsonify({'success': False, 'error': 'Please upload a STO file'}), 400
        # Persist the upload to a uniquely-named temp file so it can be unpickled from disk
        import tempfile
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"temp_sto_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.sto")
        file.save(temp_path)
        if not os.path.exists(temp_path):
            return jsonify({'success': False, 'error': 'Failed to save uploaded file'}), 500
        # SECURITY NOTE: pickle.load on user-uploaded data can execute arbitrary code;
        # STO uploads must only be accepted from trusted sources.
        with open(temp_path, 'rb') as f:
            sto_data = pickle.load(f)
        if not isinstance(sto_data, (list, tuple)):
            return jsonify({'success': False, 'error': f'Invalid STO file format: expected list/tuple, got {type(sto_data).__name__}'}), 400
        if len(sto_data) < 4:
            return jsonify({'success': False, 'error': 'STO file does not have index 3 (object detection image)'}), 400
        # Extract index 3 (640x640 object detection image)
        sto_item = sto_data[3]
        image = None
        if isinstance(sto_item, bytes):
            image = Image.open(BytesIO(sto_item)).convert('RGB')
        elif isinstance(sto_item, Image.Image):
            image = sto_item.convert('RGB')
        elif hasattr(sto_item, 'mode'):
            # Duck-typed PIL Image object (e.g. pickled by a different PIL version)
            image = sto_item.convert('RGB')
        elif isinstance(sto_item, (list, tuple)) and len(sto_item) > 0:
            # Nested structure, try first element
            nested_item = sto_item[0]
            if isinstance(nested_item, bytes):
                image = Image.open(BytesIO(nested_item)).convert('RGB')
            elif isinstance(nested_item, Image.Image):
                image = nested_item.convert('RGB')
            else:
                # Fixed: this message previously reported "index 1" for an index-3 failure
                return jsonify({'success': False, 'error': f'Invalid STO structure at index 3: unexpected type {type(sto_item).__name__}'}), 400
        else:
            return jsonify({'success': False, 'error': f'Invalid STO structure at index 3: unexpected type {type(sto_item).__name__}'}), 400
        if image is None:
            return jsonify({'success': False, 'error': 'Failed to extract image from STO file'}), 500
        # Convert to base64 (JPEG is acceptable here; the 640x640 OD image is for display)
        img_buffer = io.BytesIO()
        image.save(img_buffer, format='JPEG', quality=95)
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
        return jsonify({
            'success': True,
            'image': f"data:image/jpeg;base64,{img_base64}",
            'image_size': image.size,
            'image_mode': image.mode
        })
    except pickle.UnpicklingError as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': f'Failed to unpickle STO file: {str(e)}',
            'error_type': 'unpickling_failed'
        }), 500
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'error_type': type(e).__name__
        }), 500
    finally:
        # Best-effort cleanup of the temp file; ignore filesystem races only
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
def api_yolov3_weights():
    """API endpoint to get YOLOv3 model weights from local directory or Hugging Face Hub.

    Returns JSON {'success', 'weights': [...], 'repo_url', 'repo_id'}; weights are
    listed Hub-first, with local .pt files added only when the Hub lacks that
    filename, sorted Best -> Last -> Checkpoint (stable by filename within a type).
    """
    try:
        import glob
        import re

        def _describe(filename):
            """Derive (weight_type, display_name) from a YOLOv3 checkpoint filename."""
            # Accuracy is optionally encoded in the filename as e.g. 'acc_96.5%'
            acc_match = re.search(r'acc_(\d+\.?\d*)%', filename)
            suffix = f" ({float(acc_match.group(1))}% accuracy)" if acc_match else ""
            lowered = filename.lower()
            if "best" in lowered:
                return "Best", f"YOLOv3 Best{suffix}"
            if "last" in lowered:
                return "Last", f"YOLOv3 Last{suffix}"
            return "Checkpoint", f"YOLOv3 Checkpoint ({filename.replace('.pt', '')})"

        yolov3_weights = []
        weights_dir = os.path.join(get_models_dir(), "yolov3")
        repo_id = MODEL_REPO_MAP.get('yolov3', 'mvplus/yolov3')
        # Individual repo, no subfolder
        hub_files = fetch_weights_from_hub(repo_id, None)
        for hub_file in hub_files:
            filename = os.path.basename(hub_file)
            weight_type, display_name = _describe(filename)
            yolov3_weights.append({
                "filename": filename,
                "path": f"hub://{repo_id}/{hub_file}",
                "display_name": display_name,
                "weight_type": weight_type,
                "source": "hub",
                "repo_url": f"https://huggingface.co/{repo_id}",
                "repo_id": repo_id
            })
        # Also include local .pt files not already provided by the Hub
        if os.path.exists(weights_dir):
            known = {w['filename'] for w in yolov3_weights}
            for weight_file in glob.glob(os.path.join(weights_dir, "*.pt")):
                filename = os.path.basename(weight_file)
                if filename in known:
                    continue
                weight_type, display_name = _describe(filename)
                yolov3_weights.append({
                    "filename": filename,
                    "path": weight_file,
                    "display_name": display_name,
                    "weight_type": weight_type,
                    "source": "local"
                })
        # Sort Best first, then Last, then checkpoints — consistent with the
        # YOLOv8/DINOv3 endpoints. (The previous reverse-sort on a boolean flag
        # interleaved 'Last' and 'Checkpoint' entries by descending filename.)
        weight_priority = {"Best": 0, "Last": 1, "Checkpoint": 2}
        yolov3_weights.sort(key=lambda x: (weight_priority.get(x['weight_type'], 3), x['filename']))
        return jsonify({
            'success': True,
            'weights': yolov3_weights,
            'repo_url': f"https://huggingface.co/{repo_id}",
            'repo_id': repo_id
        })
    except Exception as e:
        import traceback
        print(f"Error loading YOLOv3 weights: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'weights': []
        }), 500
def api_yolov8_custom_weights():
    """API endpoint to get YOLOv8 custom model weights from local directory or Hugging Face Hub.

    Returns JSON {'success', 'weights': [...], 'repo_url', 'repo_id'}; weights are
    listed Hub-first, with local .pt files added only when the Hub lacks that
    filename, sorted Best -> Last -> Checkpoint.
    """
    import os
    import glob
    import re
    try:
        def _describe(filename):
            """Derive (weight_type, display_name) from a YOLOv8 checkpoint filename."""
            # mAP is optionally encoded in the filename as e.g. '85.2mAp' (case-insensitive)
            map_match = re.search(r'(\d+\.?\d*)mAp', filename, re.IGNORECASE)
            suffix = f" ({float(map_match.group(1))}% mAP)" if map_match else ""
            lowered = filename.lower()
            if "best" in lowered:
                return "Best", f"YOLOv8 Best{suffix}"
            if "last" in lowered:
                return "Last", f"YOLOv8 Last{suffix}"
            return "Checkpoint", f"YOLOv8 Checkpoint ({filename.replace('.pt', '')})"

        yolov8_weights = []
        weights_dir = os.path.join(get_models_dir(), "yolov8")
        repo_id = MODEL_REPO_MAP.get('yolov8', 'mvplus/yolov8')
        # Individual repo, no subfolder
        hub_files = fetch_weights_from_hub(repo_id, None)
        for hub_file in hub_files:
            filename = os.path.basename(hub_file)
            weight_type, display_name = _describe(filename)
            yolov8_weights.append({
                "filename": filename,
                "path": f"hub://{repo_id}/{hub_file}",
                "display_name": display_name,
                "weight_type": weight_type,
                "source": "hub",
                "repo_url": f"https://huggingface.co/{repo_id}",
                "repo_id": repo_id
            })
        # Also include local .pt files not already provided by the Hub
        if os.path.exists(weights_dir):
            known = {w['filename'] for w in yolov8_weights}
            for weight_file in glob.glob(os.path.join(weights_dir, "*.pt")):
                filename = os.path.basename(weight_file)
                if filename in known:
                    continue
                weight_type, display_name = _describe(filename)
                yolov8_weights.append({
                    "filename": filename,
                    "path": weight_file,
                    "display_name": display_name,
                    "weight_type": weight_type,
                    "source": "local"
                })
        # Sort by weight type (Best first, then Last, then others)
        weight_priority = {"Best": 0, "Last": 1, "Checkpoint": 2}
        yolov8_weights.sort(key=lambda x: (weight_priority.get(x['weight_type'], 3), x['filename']))
        return jsonify({
            'success': True,
            'weights': yolov8_weights,
            'repo_url': f"https://huggingface.co/{repo_id}",
            'repo_id': repo_id
        })
    except Exception as e:
        import traceback
        print(f"Error loading YOLOv8 weights: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'weights': []
        }), 500
def api_dinov3_weights():
    """API endpoint to get DINOv3 model weights from local directory or Hugging Face Hub.

    Returns JSON {'success', 'weights': [...], 'repo_url', 'repo_id'}; weights are
    listed Hub-first, with local .pth files added only when the Hub lacks that
    filename, sorted Best -> Last -> Checkpoint, then by accuracy and epoch
    (both descending).
    """
    import os
    import glob
    import re
    try:
        def _parse(filename):
            """Extract (epoch, accuracy, weight_type, display_name) from a DINOv3 filename.

            Filenames look like 'lastweight_epoch_82_train_0_0411_val_0_1622_acc_96_88%.pth'
            where '_' inside the accuracy token acts as the decimal point.
            """
            epoch_match = re.search(r'epoch_(\d+)', filename)
            acc_match = re.search(r'acc_(\d+)_(\d+)%', filename)
            epoch = int(epoch_match.group(1)) if epoch_match else 0
            # Keep the fractional digits verbatim: round-tripping them through int()
            # dropped leading zeros (e.g. 'acc_96_08%' used to become 96.8 instead of 96.08).
            accuracy = float(f"{acc_match.group(1)}.{acc_match.group(2)}") if acc_match else 0.0
            # Determine weight type from the name, with accuracy/epoch heuristics as fallback
            if "best" in filename.lower() or accuracy >= 97.0:
                weight_type = "Best"
            elif "last" in filename.lower() or epoch >= 80:
                weight_type = "Last"
            else:
                weight_type = "Checkpoint"
            if accuracy > 0:
                display_name = f"DINOv3 Epoch {epoch} ({accuracy}% accuracy)"
            else:
                display_name = f"DINOv3 Epoch {epoch}"
            return epoch, accuracy, weight_type, display_name

        dinov3_weights = []
        weights_dir = os.path.join(get_models_dir(), "dinov3")
        repo_id = MODEL_REPO_MAP.get('dinov3', 'mvplus/dinov3')
        # Individual repo, no subfolder
        hub_files = fetch_weights_from_hub(repo_id, None)
        for hub_file in hub_files:
            filename = os.path.basename(hub_file)
            epoch, accuracy, weight_type, display_name = _parse(filename)
            dinov3_weights.append({
                "filename": filename,
                "path": f"hub://{repo_id}/{hub_file}",
                "display_name": display_name,
                "weight_type": weight_type,
                "epoch": epoch,
                "accuracy": accuracy,
                "source": "hub",
                "repo_url": f"https://huggingface.co/{repo_id}",
                "repo_id": repo_id
            })
        # Also include local .pth files not already provided by the Hub
        if os.path.exists(weights_dir):
            known = {w['filename'] for w in dinov3_weights}
            for weight_file in glob.glob(os.path.join(weights_dir, "*.pth")):
                filename = os.path.basename(weight_file)
                if filename in known:
                    continue
                epoch, accuracy, weight_type, display_name = _parse(filename)
                dinov3_weights.append({
                    "filename": filename,
                    "path": weight_file,
                    "display_name": display_name,
                    "weight_type": weight_type,
                    "epoch": epoch,
                    "accuracy": accuracy,
                    "source": "local"
                })
        # Sort by weight type (Best first, then Last, then by accuracy/epoch)
        weight_priority = {"Best": 0, "Last": 1, "Checkpoint": 2}
        dinov3_weights.sort(key=lambda x: (
            weight_priority.get(x.get('weight_type', ''), 4),
            -x.get('accuracy', 0),  # Higher accuracy first
            -x.get('epoch', 0)  # Higher epoch first
        ))
        return jsonify({
            'success': True,
            'weights': dinov3_weights,
            'repo_url': f"https://huggingface.co/{repo_id}",
            'repo_id': repo_id
        })
    except Exception as e:
        import traceback
        print(f"Error loading DINOv3 weights: {e}")
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'weights': []
        }), 500
def api_detect_dinov3():
    """API endpoint for DINOv3 model detection - CLASSIFICATION.

    Expects multipart form data:
      - 'file': the image to classify
      - 'weight_path': a local checkpoint path or a 'hub://repo/file' reference

    Returns JSON with the predicted class, confidence, top-3 and full probability
    tables, timing, and model metadata; JSON errors with 4xx/5xx status on failure.

    Fixes vs. previous revision: DEBUG prints routed through the module's
    debug_print() gate (disabled on Spaces per the DEBUG_MODE convention) and a
    bare `except:` narrowed to `except Exception:`.
    """
    import time
    from PIL import Image
    import numpy as np
    import base64
    import io
    import torch
    import torch.nn as nn
    from torchvision import transforms
    from torchvision.models import vit_b_16
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        # Get selected weight
        weight_path = request.form.get('weight_path', '')
        debug_print(f"DEBUG: Received weight_path: '{weight_path}'")
        if not weight_path:
            return jsonify({'error': 'No model weight selected'}), 400
        # Handle Hub paths - download on demand
        if weight_path.startswith('hub://'):
            downloaded_path = download_model_from_hub(weight_path)
            if downloaded_path:
                weight_path = downloaded_path
            else:
                return jsonify({'error': f'Failed to download model from Hub: {weight_path}'}), 400
        # Load image (rewind the stream in case it was read earlier in the request)
        file.stream.seek(0)
        image = Image.open(file.stream).convert('RGB')
        debug_print(f"DEBUG: Original image size: {image.size}")
        # Check if weight file exists
        if not os.path.exists(weight_path):
            debug_print(f"DEBUG: Weight file does not exist: {weight_path}")
            return jsonify({'error': f'Model weight file not found: {weight_path}'}), 400
        debug_print(f"DEBUG: Weight file exists, size: {os.path.getsize(weight_path)} bytes")
        # Start timing
        start_time = time.time()
        try:
            # Import the DINOv3 model class
            import sys
            dinov3_path = os.path.join(BASE_DIR, "apps", "dinov3_custom")
            sys.path.append(dinov3_path)
            try:
                from train_dinov3 import DINOv3Model
            except ImportError as import_error:
                debug_print(f"DEBUG: Failed to import DINOv3Model: {import_error}")
                # Define DINOv3Model inline if import fails
                class DINOv3Model(nn.Module):
                    """DINOv3 model with Vision Transformer backbone"""
                    def __init__(self, num_classes, pretrained_path=None):
                        super().__init__()
                        self.num_classes = num_classes
                        self.backbone = vit_b_16(pretrained=True)
                        original_head = self.backbone.heads[0]
                        # Replace the classification head with one sized for our classes
                        self.backbone.heads = nn.Sequential(
                            nn.Linear(original_head.in_features, num_classes)
                        )
                        if pretrained_path and os.path.exists(pretrained_path):
                            try:
                                checkpoint = torch.load(pretrained_path, map_location='cpu')
                                # Load backbone weights only; the head above is re-initialized
                                backbone_state = {}
                                for key, value in checkpoint.items():
                                    if not key.startswith('head'):
                                        backbone_state[key] = value
                                self.backbone.load_state_dict(backbone_state, strict=False)
                            except Exception as e:
                                print(f"Warning: Could not load DINOv3 weights: {e}")
                    def forward(self, x):
                        return self.backbone(x)
            # Load checkpoint
            ckpt = torch.load(weight_path, map_location='cpu')
            debug_print(f"DEBUG: Checkpoint keys: {list(ckpt.keys())}")
            # Infer num_classes from model state dict
            if 'model_state_dict' in ckpt:
                msd = ckpt['model_state_dict']
                # Find the head layer to get num_classes
                for key in msd.keys():
                    if 'heads' in key and 'weight' in key:
                        num_classes = msd[key].shape[0]
                        debug_print(f"DEBUG: Inferred num_classes from {key}: {num_classes}")
                        break
                else:
                    raise ValueError("Could not infer num_classes from model state dict")
            else:
                raise ValueError("Checkpoint does not contain model_state_dict")
            # Get class names from checkpoint or data directory
            if 'class_names' in ckpt:
                class_names = ckpt['class_names']
                if isinstance(class_names, str):
                    class_names = [class_names]
            else:
                # Try to load class names from data directory
                data_dir = os.path.join(BASE_DIR, "apps", "dinov3_custom", "data", "train")
                class_names = []
                if os.path.exists(data_dir):
                    from pathlib import Path
                    class_set = set()
                    for img_path in Path(data_dir).glob('*.jpg'):
                        filename = img_path.stem
                        class_name = filename.split('__')[0]
                        class_set.add(class_name)
                    class_names = sorted(list(class_set))
                    debug_print(f"DEBUG: Loaded class names from data directory: {class_names}")
            # If still no class names, use mapping based on sorted alphabetical order
            # DINOv3 training uses sorted(class_set) which gives alphabetical order
            # This is different from YOLOv8's order!
            if not class_names or len(class_names) != num_classes:
                # Default class names in ALPHABETICAL ORDER (as DINOv3 training uses sorted())
                # This matches how train_dinov3.py creates class_names: sorted(list(class_set))
                # Actual classes from data directory (9 classes, sorted alphabetically):
                DEFAULT_CLASS_NAMES_ALPHABETICAL = [
                    'bowl',  # 0
                    'carrot',  # 1
                    'eggplant',  # 2
                    'greenpepper',  # 3
                    'potato',  # 4
                    'redpepper',  # 5
                    'teacup',  # 6
                    'tomato',  # 7
                    'yellowpepper'  # 8
                ]
                # Use default names if we have the right number, otherwise use generic
                if num_classes <= len(DEFAULT_CLASS_NAMES_ALPHABETICAL):
                    class_names = DEFAULT_CLASS_NAMES_ALPHABETICAL[:num_classes]
                    debug_print(f"DEBUG: Using alphabetical default class names for DINOv3: {class_names}")
                else:
                    class_names = [f'class_{i}' for i in range(num_classes)]
                    debug_print(f"DEBUG: Warning - using generic class names: {class_names}")
            # Ensure class_names length matches num_classes
            if len(class_names) != num_classes:
                debug_print(f"DEBUG: Warning - class_names length ({len(class_names)}) != num_classes ({num_classes}), adjusting...")
                if len(class_names) < num_classes:
                    # Add generic names for missing classes
                    for i in range(len(class_names), num_classes):
                        class_names.append(f'class_{i}')
                else:
                    # Truncate if too many
                    class_names = class_names[:num_classes]
            debug_print(f"DEBUG: DINOv3 - num_classes: {num_classes}, class_names: {class_names}")
            # Create model instance
            model = DINOv3Model(num_classes=num_classes)
            # Load weights - support the several checkpoint layouts seen in training runs
            if 'model_state_dict' in ckpt:
                model.load_state_dict(ckpt['model_state_dict'])
            elif 'state_dict' in ckpt:
                model.load_state_dict(ckpt['state_dict'])
            elif 'msd' in ckpt:
                model.load_state_dict(ckpt['msd'])
            else:
                # Try loading checkpoint as state_dict directly
                try:
                    model.load_state_dict(ckpt)
                except Exception:
                    raise ValueError(f"Unknown checkpoint format. Keys: {list(ckpt.keys())}")
            model.eval()
            debug_print(f"DEBUG: DINOv3 model loaded successfully")
            # Preprocess image - DINOv3 uses standard ImageNet normalization
            img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            img_tensor = img_transform(image).unsqueeze(0)
            # Run inference
            with torch.no_grad():
                logits = model(img_tensor)
            debug_print(f"DEBUG: Inference complete, logits shape: {logits.shape}")
            # Apply softmax to get probabilities
            import torch.nn.functional as F
            probabilities = F.softmax(logits, dim=1)[0]
            predicted_class = torch.argmax(probabilities, dim=0).item()
            confidence = probabilities[predicted_class].item()
            # Get top 3 predictions
            top3_probs, top3_indices = torch.topk(probabilities, min(3, num_classes))
            top3_predictions = []
            for i in range(min(3, num_classes)):
                class_idx = top3_indices[i].item()
                prob = top3_probs[i].item()
                top3_predictions.append({
                    'class': class_names[class_idx] if class_idx < len(class_names) else f'class_{class_idx}',
                    'probability': float(prob)
                })
            # Get all predictions, highest probability first
            all_predictions = []
            for i in range(num_classes):
                all_predictions.append({
                    'class': class_names[i] if i < len(class_names) else f'class_{i}',
                    'probability': float(probabilities[i].item())
                })
            all_predictions.sort(key=lambda x: x['probability'], reverse=True)
            inference_time = (time.time() - start_time) * 1000
            # Convert image to base64 for display
            img_buffer = io.BytesIO()
            image.save(img_buffer, format='JPEG')
            img_base64 = base64.b64encode(img_buffer.getvalue()).decode()
            # Extract fitness/accuracy from checkpoint if available
            fitness_score = None
            if 'fitness' in ckpt:
                fitness_score = float(ckpt.get('fitness', 0.0))
            elif 'best_fitness' in ckpt:
                fitness_score = float(ckpt.get('best_fitness', 0.0))
            elif 'accuracy' in ckpt:
                fitness_score = float(ckpt.get('accuracy', 0.0))
            return jsonify({
                'success': True,
                'predicted_class': class_names[predicted_class] if predicted_class < len(class_names) else 'unknown',
                'confidence': float(confidence),
                'top3_predictions': top3_predictions,
                'all_predictions': all_predictions,
                'inference_time': float(inference_time),
                'image': f"data:image/jpeg;base64,{img_base64}",
                'architecture': 'DINOv3 (Vision Transformer)',
                'model_size': '86M parameters (ViT-B/16)',
                'input_size': '224x224',
                'batch_size': 1,
                'classes': class_names,
                'classes_display': ', '.join(class_names),
                'backbone': 'Vision Transformer (ViT-B/16)',
                'detection_heads': '1',
                'anchors': 'N/A (Classification)',
                'fitness_score': fitness_score if fitness_score is not None else 'N/A',
                'used_weight_path': weight_path,
                'inferred_num_classes': int(num_classes)
            })
        except Exception as model_error:
            # Model-level failures always surface with a full traceback
            print(f"DEBUG: Model loading/inference failed: {model_error}")
            import traceback
            traceback_str = traceback.format_exc()
            print(f"FULL TRACEBACK:\n{traceback_str}")
            return jsonify({
                'success': False,
                'error': str(model_error),
                'traceback': traceback_str
            }), 500
    except Exception as e:
        import traceback
        error_traceback = traceback.format_exc()
        print(f"ERROR in api_detect_dinov3: {str(e)}")
        print(f"Traceback: {error_traceback}")
        return jsonify({
            'error': f'Detection failed: {str(e)}',
            'error_type': 'detection_failed'
        }), 500
| def api_detect_yolov8_custom(): | |
| """API endpoint for YOLOv8 custom model detection - PROPER OBJECT DETECTION""" | |
| import time | |
| from ultralytics import YOLO | |
| from PIL import Image | |
| import numpy as np | |
| import base64 | |
| import io | |
| from PIL import ImageDraw, ImageFont | |
| try: | |
| if 'file' not in request.files: | |
| return jsonify({'error': 'No file uploaded'}), 400 | |
| file = request.files['file'] | |
| if file.filename == '': | |
| return jsonify({'error': 'No file selected'}), 400 | |
| # Get selected weight | |
| weight_path = request.form.get('weight_path', '') | |
| print(f"DEBUG: Received weight_path: '{weight_path}'") | |
| if not weight_path: | |
| return jsonify({'error': 'No model weight selected'}), 400 | |
| # Handle Hub paths (hub://repo_id/path/to/file) | |
| if weight_path.startswith('hub://'): | |
| # Download from Hub - pass the full hub path | |
| downloaded_path = download_model_from_hub(weight_path) | |
| if downloaded_path: | |
| weight_path = downloaded_path | |
| else: | |
| return jsonify({'error': f'Failed to download model from Hub: {weight_path}'}), 400 | |
| # Load image - ensure stream is at beginning | |
| file.stream.seek(0) | |
| image = Image.open(file.stream).convert('RGB') | |
| print(f"DEBUG: Original image size: {image.size}") | |
| # Check if weight file exists | |
| if not os.path.exists(weight_path): | |
| print(f"DEBUG: Weight file does not exist: {weight_path}") | |
| return jsonify({'error': f'Model weight file not found: {weight_path}'}), 400 | |
| print(f"DEBUG: Weight file exists, size: {os.path.getsize(weight_path)} bytes") | |
| # Start timing | |
| start_time = time.time() | |
| try: | |
| # Load the YOLOv8 model using Ultralytics | |
| # Handle potential DFLoss attribute errors from version mismatches | |
| try: | |
| model = YOLO(weight_path) | |
| debug_print(f"DEBUG: YOLOv8 model loaded successfully") | |
| except (AttributeError, ImportError, RuntimeError) as model_load_error: | |
| error_msg = str(model_load_error) | |
| if 'DFLoss' in error_msg or 'loss' in error_msg.lower(): | |
| # Try loading with weights_only to bypass loss function loading | |
| import torch | |
| debug_print(f"DEBUG: DFLoss error detected, trying alternative loading method: {error_msg}") | |
| try: | |
| # Load checkpoint and extract just the model weights | |
| checkpoint = torch.load(weight_path, map_location='cpu', weights_only=False) | |
| # Create a new YOLO model and load state dict | |
| # First, try to determine model size from checkpoint | |
| if 'model' in checkpoint: | |
| model = checkpoint['model'] | |
| elif 'state_dict' in checkpoint: | |
| # Create a base YOLOv8 model and load state dict | |
| model = YOLO('yolov8n.pt') # Use nano as base | |
| model.model.load_state_dict(checkpoint['state_dict'], strict=False) | |
| else: | |
| # Last resort: try loading with ultralytics but catch the error | |
| raise model_load_error | |
| debug_print(f"DEBUG: YOLOv8 model loaded using alternative method") | |
| except Exception as alt_error: | |
| debug_print(f"DEBUG: Alternative loading also failed: {alt_error}") | |
| # Try one more time with a fresh YOLO instance | |
| model = YOLO(weight_path) | |
| else: | |
| raise model_load_error | |
| # Get model metadata | |
| num_classes = len(model.names) | |
| class_list = ', '.join(list(model.names.values())[:5]) + ('...' if num_classes > 5 else '') | |
| # Run inference - PROPER OBJECT DETECTION | |
| results = model(image, conf=0.01) # Lower threshold to detect more objects | |
| # Process detection results | |
| detections = [] | |
| all_class_confidences = {} | |
| # Class names from the model | |
| class_names = model.names | |
| print(f"DEBUG: Model class names: {class_names}") | |
| # Initialize all class confidences to 0 | |
| for class_id, class_name in class_names.items(): | |
| all_class_confidences[class_name] = 0.0 | |
| # Process each detection | |
| for result in results: | |
| if result.boxes is not None: | |
| boxes = result.boxes | |
| print(f"DEBUG: YOLOv8 found {len(boxes)} detections from model") | |
| for i in range(len(boxes)): | |
| # Get bounding box coordinates | |
| box = boxes.xyxy[i].cpu().numpy() # [x1, y1, x2, y2] | |
| confidence = boxes.conf[i].cpu().numpy() | |
| class_id = int(boxes.cls[i].cpu().numpy()) | |
| class_name = class_names[class_id] # Use class name directly from model | |
| # Add detection | |
| detections.append({ | |
| 'class': class_name, | |
| 'confidence': float(confidence), | |
| 'bbox': { | |
| 'x1': float(box[0]), | |
| 'y1': float(box[1]), | |
| 'x2': float(box[2]), | |
| 'y2': float(box[3]) | |
| } | |
| }) | |
| # Sort detections by confidence (highest first) | |
| detections.sort(key=lambda x: x.get('confidence', 0.0), reverse=True) | |
| # Aggregate detection confidences by class | |
| # Use max confidence per class (or sum, depending on what makes sense) | |
| # For YOLOv8, we'll use max confidence per class to represent the model's confidence in that class | |
| class_aggregated = {} | |
| for det in detections: | |
| class_name = det['class'] # Already using display name | |
| conf = det['confidence'] | |
| # Use max confidence for each class (represents best detection of that class) | |
| if class_name not in class_aggregated: | |
| class_aggregated[class_name] = 0.0 | |
| class_aggregated[class_name] = max(class_aggregated[class_name], conf) | |
| print(f"DEBUG: Raw aggregated class confidences (before normalization): {class_aggregated}") | |
| # Now normalize the aggregated class confidences to sum to 100% | |
| # This distributes the probability across detected classes | |
| total_class_confidence = sum(class_aggregated.values()) | |
| if total_class_confidence > 0: | |
| # Normalize so they sum to 1.0 (100%) | |
| for class_name in class_aggregated: | |
| class_aggregated[class_name] = float(class_aggregated[class_name] / total_class_confidence) | |
| else: | |
| # If all confidences are 0, distribute equally | |
| num_classes = len(class_aggregated) | |
| if num_classes > 0: | |
| equal_prob = 1.0 / num_classes | |
| for class_name in class_aggregated: | |
| class_aggregated[class_name] = float(equal_prob) | |
| # Verify sum is exactly 1.0 (100%) | |
| total_norm_conf = sum(class_aggregated.values()) | |
| if abs(total_norm_conf - 1.0) > 0.01 and len(class_aggregated) > 0: | |
| print(f"DEBUG: WARNING - Normalized class confidences sum to {total_norm_conf:.4f}, not 1.0. Re-normalizing...") | |
| if total_norm_conf > 0: | |
| for class_name in class_aggregated: | |
| class_aggregated[class_name] = float(class_aggregated[class_name] / total_norm_conf) | |
| print(f"DEBUG: Normalized aggregated class confidences (sum={sum(class_aggregated.values()):.4f}): {class_aggregated}") | |
| # Get ALL predictions from normalized class confidences | |
| all_predictions = [] | |
| if len(class_aggregated) > 0: | |
| # Sort classes by normalized confidence (highest first) | |
| sorted_class_aggregated = sorted(class_aggregated.items(), key=lambda x: x[1], reverse=True) | |
| # Create predictions for ALL detected classes | |
| # These are already normalized to sum to 100% | |
| for class_name, normalized_conf in sorted_class_aggregated: | |
| all_predictions.append({ | |
| 'class': class_name, | |
| 'probability': float(normalized_conf) | |
| }) | |
| else: | |
| # No detections, use all_class_confidences (all zeros) | |
| sorted_classes = sorted(all_class_confidences.items(), key=lambda x: x[1], reverse=True) | |
| for class_name, prob in sorted_classes: | |
| all_predictions.append({ | |
| 'class': class_name, | |
| 'probability': float(prob) | |
| }) | |
| # Get top 3 predictions for display | |
| top3_predictions = all_predictions[:3] | |
| all_predictions_debug = [(p['class'], f"{p['probability']*100:.2f}%") for p in all_predictions[:5]] | |
| top3_debug = [(p['class'], f"{p['probability']*100:.2f}%") for p in top3_predictions] | |
| print(f"DEBUG: All predictions (sum={sum(p['probability'] for p in all_predictions):.4f}): {all_predictions_debug}") | |
| print(f"DEBUG: Top 3 predictions: {top3_debug}") | |
| # Ensure first detection matches first prediction | |
| if len(detections) > 0 and len(all_predictions) > 0: | |
| first_pred_class = all_predictions[0]['class'] | |
| first_pred_conf = all_predictions[0]['probability'] | |
| # Find matching detection | |
| matching_detection = None | |
| for det in detections: | |
| if det.get('class') == first_pred_class: | |
| matching_detection = det | |
| break | |
| if matching_detection: | |
| if detections[0].get('class') != first_pred_class: | |
| detections.remove(matching_detection) | |
| detections.insert(0, matching_detection) | |
| detections[0]['confidence'] = float(first_pred_conf) | |
| else: | |
| if len(detections) > 0: | |
| detections[0]['class'] = first_pred_class | |
| detections[0]['confidence'] = float(first_pred_conf) | |
| # Draw bounding boxes | |
| try: | |
| result_image = image.copy() | |
| draw = ImageDraw.Draw(result_image) | |
| try: | |
| font = ImageFont.truetype("/System/Library/Fonts/Arial.ttf", 20) | |
| except: | |
| try: | |
| font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20) | |
| except: | |
| font = ImageFont.load_default() | |
| except Exception as draw_error: | |
| print(f"DEBUG: Error creating draw context: {draw_error}") | |
| result_image = image.copy() # Use original image if drawing fails | |
| draw = None | |
| # Draw only the highest confidence detection | |
| if len(detections) > 0: | |
| detection = detections[0] | |
| bbox = detection.get('bbox') | |
| class_name = detection.get('class', 'unknown') | |
| confidence = detection.get('confidence', 0.0) | |
| if isinstance(bbox, dict) and all(key in bbox for key in ['x1', 'y1', 'x2', 'y2']): | |
| x1 = float(bbox['x1']) | |
| y1 = float(bbox['y1']) | |
| x2 = float(bbox['x2']) | |
| y2 = float(bbox['y2']) | |
| if x1 > x2: | |
| x1, x2 = x2, x1 | |
| if y1 > y2: | |
| y1, y2 = y2, y1 | |
| x1 = max(0, min(image.width, x1)) | |
| y1 = max(0, min(image.height, y1)) | |
| x2 = max(0, min(image.width, x2)) | |
| y2 = max(0, min(image.height, y2)) | |
| if x1 < x2 and y1 < y2: | |
| draw.rectangle([x1, y1, x2, y2], outline='red', width=3) | |
| label = f"{class_name}: {confidence * 100:.2f}%" | |
| # Use textsize for compatibility (textbbox is newer) | |
| try: | |
| # Try textbbox first (Pillow >= 9.2.0) | |
| label_bbox = draw.textbbox((x1, y1 - 25), label, font=font) | |
| except AttributeError: | |
| # Fallback to textsize for older Pillow versions | |
| text_width, text_height = draw.textsize(label, font=font) | |
| label_bbox = (x1, y1 - 25, x1 + text_width, y1 - 25 + text_height) | |
| draw.rectangle(label_bbox, fill='red') | |
| draw.text((x1, y1 - 25), label, fill='white', font=font) | |
| # Convert to base64 | |
| try: | |
| img_buffer = io.BytesIO() | |
| result_image.save(img_buffer, format='JPEG', quality=95) | |
| img_base64 = base64.b64encode(img_buffer.getvalue()).decode() | |
| except Exception as img_error: | |
| print(f"DEBUG: Error encoding image: {img_error}") | |
| import traceback | |
| traceback.print_exc() | |
| # Fallback: use original image | |
| img_buffer = io.BytesIO() | |
| image.save(img_buffer, format='JPEG', quality=95) | |
| img_base64 = base64.b64encode(img_buffer.getvalue()).decode() | |
| inference_time = (time.time() - start_time) * 1000 | |
| return jsonify({ | |
| 'success': True, | |
| 'detections': detections, | |
| 'top3_predictions': top3_predictions, | |
| 'all_predictions': all_predictions, | |
| 'inference_time': float(inference_time), | |
| 'image': f"data:image/jpeg;base64,{img_base64}", | |
| 'filename': str(file.filename) if file.filename else 'unknown', | |
| 'architecture': 'YOLOv8 Object Detection', | |
| 'model_size': '25.9M parameters', | |
| 'input_size': '640x640', | |
| 'batch_size': 1, | |
| 'used_weight_path': weight_path, | |
| 'inferred_num_classes': int(num_classes), | |
| 'model_type': 'YOLOv8', | |
| 'backbone': 'CSPDarknet', | |
| 'detection_heads': '1', | |
| 'anchors': '3 scales', | |
| 'classes': f"{num_classes} ({class_list})", | |
| 'fitness_score': 'N/A' | |
| }) | |
| except Exception as model_error: | |
| print(f"DEBUG: Model loading/inference failed: {model_error}") | |
| import traceback | |
| traceback_str = traceback.format_exc() | |
| print(f"FULL TRACEBACK:\n{traceback_str}") | |
| return jsonify({ | |
| 'success': False, | |
| 'error': str(model_error), | |
| 'traceback': traceback_str | |
| }), 500 | |
| except Exception as e: | |
| import traceback | |
| error_traceback = traceback.format_exc() | |
| print(f"ERROR in api_detect_yolov8_custom: {str(e)}") | |
| print(f"Traceback: {error_traceback}") | |
| return jsonify({ | |
| 'error': f'Detection failed: {str(e)}', | |
| 'error_type': 'detection_failed' | |
| }), 500 | |
def api_detect_yolov3():
    """API endpoint for YOLOv3 model detection - PROPER OBJECT DETECTION.

    Flask view. Expects a multipart upload under 'file' plus a 'weight_path'
    form field (a local checkpoint path or a 'hub://...' reference that is
    downloaded on demand). Runs the custom UltraFastYOLOv3Model on a 224x224
    resize of the image, decodes the raw [1, 3, 5+num_classes, 7, 7] output
    tensor, and returns JSON containing:
      - 'detections': NMS-filtered boxes in original-image pixel coordinates
      - 'all_predictions' / 'top3_predictions': normalized class probabilities
      - an annotated JPEG as a base64 data URL, timing, and model metadata.
    Error responses are JSON with HTTP 400 (bad request) or 500 (model/infer
    failure, including the traceback text).

    NOTE(review): this endpoint uses bare print() for debug output instead of
    the module's debug_print() helper — presumably intentional for Space logs,
    but worth confirming.
    """
    # Heavy imports kept function-local so module import stays fast.
    import time
    from PIL import Image
    import numpy as np
    import base64
    import io
    import torch
    from torchvision import transforms
    try:
        # ---- Request validation ----
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        # Get selected weight
        weight_path = request.form.get('weight_path', '')
        print(f"DEBUG: Received weight_path: '{weight_path}'")
        if not weight_path:
            return jsonify({'error': 'No model weight selected'}), 400
        # Handle Hub paths - download on demand (resolves to a local file path)
        if weight_path.startswith('hub://'):
            downloaded_path = download_model_from_hub(weight_path)
            if downloaded_path:
                weight_path = downloaded_path
            else:
                return jsonify({'error': f'Failed to download model from Hub: {weight_path}'}), 400
        # Load image (force RGB so channel count matches the transform below)
        image = Image.open(file.stream).convert('RGB')
        print(f"DEBUG: Original image size: {image.size}")
        # Check if weight file exists
        if not os.path.exists(weight_path):
            print(f"DEBUG: Weight file does not exist: {weight_path}")
            return jsonify({'error': f'Model weight file not found: {weight_path}'}), 400
        print(f"DEBUG: Weight file exists, size: {os.path.getsize(weight_path)} bytes")
        # Start timing (reported as milliseconds in the response)
        start_time = time.time()
        # Initialize fitness_score variable (will be set if found)
        fitness_score = None
        try:
            # ---- Model architecture import ----
            # The custom architecture lives under apps/yolov3_custom.
            import sys
            yolov3_path = os.path.join(BASE_DIR, "apps", "yolov3_custom")
            sys.path.append(yolov3_path)
            try:
                from train_fast_yolov3 import UltraFastYOLOv3Model
            except ImportError as import_error:
                print(f"DEBUG: Failed to import UltraFastYOLOv3Model: {import_error}")
                return jsonify({
                    'success': False,
                    'error': f'Failed to import model: {import_error}'
                }), 500
            # ---- Checkpoint loading ----
            # Load checkpoint - handle old module path references (e.g., 'models.yolo', 'models.common')
            import sys
            import pickle
            # Create comprehensive dummy module structure to satisfy old checkpoint imports.
            # Make it compatible with Werkzeug reloader by ensuring __name__ is a string.
            class DummyModule:
                """Dummy module that can handle any attribute access"""
                __name__ = 'models'  # Set __name__ to prevent reloader issues
                __file__ = None
                def __getattr__(self, name):
                    # Return a dummy class for any attribute access so unpickling
                    # old checkpoints never fails on a missing symbol.
                    class DummyClass:
                        __name__ = name  # Set __name__ to string to prevent reloader errors
                        __file__ = None
                        def __init__(self, *args, **kwargs):
                            pass
                        def __call__(self, *args, **kwargs):
                            return self
                        def __getattr__(self, name):
                            # Fabricate nested attributes as new dummy subclasses.
                            new_class = type(name, (DummyClass,), {'__name__': name, '__file__': None})
                            return new_class()
                    return DummyClass()
            # Register the fake 'models' package before torch.load so pickled
            # references to the old training-repo layout resolve.
            # Set __name__ and __file__ to prevent Werkzeug reloader issues.
            if 'models' not in sys.modules:
                dummy_models = DummyModule()
                dummy_models.__name__ = 'models'
                dummy_models.__file__ = None
                sys.modules['models'] = dummy_models
            # Create all possible models submodules that might be referenced
            models_submodules = ['yolo', 'common', 'experimental', 'utils', 'loss', 'head', 'backbone']
            for submod in models_submodules:
                module_name = f'models.{submod}'
                if module_name not in sys.modules:
                    dummy_submod = DummyModule()
                    dummy_submod.__name__ = module_name
                    dummy_submod.__file__ = None
                    sys.modules[module_name] = dummy_submod
            # Custom unpickler that handles persistent IDs
            class SafeUnpickler(pickle.Unpickler):
                def persistent_load(self, pid):
                    # Return None for any persistent ID to bypass module loading
                    return None
            try:
                # First try normal torch.load
                ckpt = torch.load(weight_path, map_location='cpu', weights_only=False)
            except (ModuleNotFoundError, ImportError, AttributeError) as e:
                print(f"DEBUG: Warning - checkpoint import error (may be from old module path): {e}")
                # Try loading with custom unpickler
                try:
                    with open(weight_path, 'rb') as f:
                        unpickler = SafeUnpickler(f)
                        ckpt = unpickler.load()
                except Exception as pickle_error:
                    print(f"DEBUG: Safe unpickler also failed: {pickle_error}")
                    # Last resort: try loading with weights_only=True (if supported)
                    try:
                        ckpt = torch.load(weight_path, map_location='cpu', weights_only=True)
                    # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
                    # consider narrowing to Exception.
                    except:
                        raise ValueError(f"Failed to load checkpoint: {e}. Pickle error: {pickle_error}")
            print(f"DEBUG: Custom checkpoint keys: {list(ckpt.keys())}")
            # ---- Metadata extraction ----
            # Try to extract fitness score from checkpoint or filename
            if 'fitness' in ckpt:
                fitness_score = float(ckpt.get('fitness', 0.0))
                print(f"DEBUG: Found fitness in checkpoint: {fitness_score}")
            elif 'best_fitness' in ckpt:
                fitness_score = float(ckpt.get('best_fitness', 0.0))
                print(f"DEBUG: Found best_fitness in checkpoint: {fitness_score}")
            else:
                # Try to extract from filename (e.g. '..._fitness_0.85.pt')
                import re
                filename = os.path.basename(weight_path)
                fitness_match = re.search(r'fitness_([\d\.]+)', filename)
                if fitness_match:
                    fitness_score = float(fitness_match.group(1))
                    print(f"DEBUG: Extracted fitness from filename: {fitness_score}")
            # Get number of classes and class names from checkpoint.
            # Falls back to 80 generic 'object' labels when names are absent.
            class_names = ckpt.get('class_names', ['object'] * 80)
            if isinstance(class_names, str):
                class_names = [class_names]
            num_classes = len(class_names)
            print(f"DEBUG: Custom YOLOv3 - num_classes: {num_classes}, class_names: {class_names}")
            # ---- Model construction ----
            model = UltraFastYOLOv3Model(num_classes=num_classes)
            # Load weights (checkpoint must carry a 'model_state_dict' entry)
            if 'model_state_dict' not in ckpt:
                raise KeyError(f"Checkpoint does not contain 'model_state_dict' key. Available keys: {list(ckpt.keys())}")
            model.load_state_dict(ckpt['model_state_dict'])
            model.eval()
            print(f"DEBUG: Custom YOLOv3 model loaded successfully")
            # ---- Preprocessing (ImageNet mean/std normalization at 224x224) ----
            img_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            img_tensor = img_transform(image).unsqueeze(0)
            # ---- Inference ----
            with torch.no_grad():
                output = model(img_tensor)  # [1, 3, 5+num_classes, 7, 7]
            results = {'output': output, 'model': model, 'class_names': class_names}
            print(f"DEBUG: Inference complete, output shape: {output.shape}")
            # ---- Decode raw output tensor ----
            # Process detection results from custom YOLOv3 output
            detections = []
            all_class_confidences = {}
            # Initialize all class confidences to 0
            for class_name in class_names:
                all_class_confidences[class_name] = 0.0
            # Process the model output tensor
            output = results['output']
            # Convert to numpy if it's a tensor
            if hasattr(output, 'cpu'):
                output = output.cpu().numpy()
            elif hasattr(output, 'numpy'):
                output = output.numpy()
            print(f"DEBUG: Output numpy shape: {output.shape}")
            # Validate output shape
            if len(output.shape) != 5:
                raise ValueError(f"Expected output shape [1, 3, 5+num_classes, 7, 7], got {output.shape}")
            # Extract objectness (channel 4) and class scores (channels 5+)
            objectness = output[0, :, 4, :, :]  # [3, 7, 7]
            class_scores = output[0, :, 5:, :, :]  # [3, num_classes, 7, 7]
            # Scale factors from model space (224x224, 7x7 grid) to original image pixels
            scale_x = image.width / 224.0
            scale_y = image.height / 224.0
            cell_width = 224.0 / 7.0
            cell_height = 224.0 / 7.0
            # Apply sigmoid to objectness (raw logits)
            objectness_sigmoid = 1.0 / (1.0 + np.exp(-objectness))  # [3, 7, 7]
            # Apply softmax to class scores (raw logits) per anchor and position.
            # class_scores shape: [3, num_classes, 7, 7]
            # We need to apply softmax across classes for each (anchor, y, x) position.
            class_scores_reshaped = class_scores.transpose(0, 2, 3, 1)  # [3, 7, 7, num_classes]
            class_scores_flat = class_scores_reshaped.reshape(-1, len(class_names))  # [3*7*7, num_classes]
            # Numerically stable softmax across classes (subtract row max)
            max_scores = np.max(class_scores_flat, axis=1, keepdims=True)
            exp_scores = np.exp(class_scores_flat - max_scores)
            class_probs_flat = exp_scores / (np.sum(exp_scores, axis=1, keepdims=True) + 1e-8)
            class_probs = class_probs_flat.reshape(3, 7, 7, len(class_names))  # [3, 7, 7, num_classes]
            # Aggregate class probabilities using objectness-weighted mean:
            # cells with high objectness contribute more to the image-level class scores.
            objectness_flat = objectness_sigmoid.flatten()  # [3*7*7]
            class_probs_flat_for_agg = class_probs.reshape(-1, len(class_names))  # [3*7*7, num_classes]
            # Weighted average across all positions, weighted by objectness
            if np.sum(objectness_flat) > 0:
                aggregated_class_probs = np.average(class_probs_flat_for_agg, axis=0, weights=objectness_flat)
            else:
                aggregated_class_probs = np.mean(class_probs_flat_for_agg, axis=0)
            # Normalize to ensure probabilities sum to 1
            total_prob = np.sum(aggregated_class_probs)
            if total_prob > 0:
                normalized_class_probs = aggregated_class_probs / total_prob
            else:
                normalized_class_probs = aggregated_class_probs
            # Store normalized probabilities
            for i, class_name in enumerate(class_names):
                all_class_confidences[class_name] = float(normalized_class_probs[i])
            # ---- Per-cell detection extraction ----
            # Process candidate cells in descending objectness order
            flat_objectness = objectness_sigmoid.flatten()
            all_indices = np.argsort(flat_objectness)[::-1]  # Sort all by objectness
            # Use reasonable confidence threshold to avoid false positives
            confidence_threshold = 0.25  # Increased from 0.001 to reduce false positives
            for idx in all_indices:
                # Recover (anchor, grid_y, grid_x) from the flat index
                flat_idx = idx
                anchor_idx = flat_idx // (7 * 7)
                pos_in_grid = flat_idx % (7 * 7)
                y_idx = pos_in_grid // 7
                x_idx = pos_in_grid % 7
                # Get objectness (already sigmoided)
                confidence = float(objectness_sigmoid[anchor_idx, y_idx, x_idx])
                # Get class probabilities for this position
                anchor_class_probs = class_probs[anchor_idx, y_idx, x_idx, :]  # [num_classes]
                best_class_idx = int(np.argmax(anchor_class_probs))
                best_class_prob = float(anchor_class_probs[best_class_idx])
                # Combined confidence: objectness * class probability
                combined_confidence = confidence * best_class_prob
                # Use stricter thresholds to reduce false positives
                objectness_threshold = 0.3  # Increased from 0.1
                # Only keep detections with high combined confidence
                if combined_confidence > confidence_threshold and confidence > objectness_threshold and best_class_idx < len(class_names):
                    class_name = class_names[best_class_idx]
                    # Box is a heuristic: centered on the grid cell, 1.5 cells wide/tall
                    # (the raw box regression channels 0-3 are not decoded here —
                    # presumably intentional for this demo; confirm against training code).
                    x_center_cell = (x_idx + 0.5) * cell_width
                    y_center_cell = (y_idx + 0.5) * cell_height
                    bbox_width = cell_width * 1.5
                    bbox_height = cell_height * 1.5
                    x1_model = x_center_cell - bbox_width / 2
                    y1_model = y_center_cell - bbox_height / 2
                    x2_model = x_center_cell + bbox_width / 2
                    y2_model = y_center_cell + bbox_height / 2
                    # Scale to actual image dimensions
                    x1 = x1_model * scale_x
                    y1 = y1_model * scale_y
                    x2 = x2_model * scale_x
                    y2 = y2_model * scale_y
                    # Clamp to image bounds
                    x1 = max(0, min(image.width, x1))
                    y1 = max(0, min(image.height, y1))
                    x2 = max(0, min(image.width, x2))
                    y2 = max(0, min(image.height, y2))
                    if x1 < x2 and y1 < y2:
                        detections.append({
                            'class': str(class_name),
                            'confidence': float(combined_confidence),
                            'bbox': {
                                'x1': float(x1), 'y1': float(y1),
                                'x2': float(x2), 'y2': float(y2)
                            }
                        })
            # ---- Non-Maximum Suppression ----
            # Apply Non-Maximum Suppression (NMS) to filter overlapping detections.
            # This helps remove duplicate detections of the same object.
            def calculate_iou(bbox1, bbox2):
                """Calculate Intersection over Union (IoU) between two bounding boxes.

                Both boxes are dicts with 'x1'/'y1'/'x2'/'y2' keys; returns a float
                in [0, 1] (0.0 for disjoint boxes or zero-area union).
                """
                x1_1, y1_1, x2_1, y2_1 = bbox1['x1'], bbox1['y1'], bbox1['x2'], bbox1['y2']
                x1_2, y1_2, x2_2, y2_2 = bbox2['x1'], bbox2['y1'], bbox2['x2'], bbox2['y2']
                # Calculate intersection rectangle
                x1_i = max(x1_1, x1_2)
                y1_i = max(y1_1, y1_2)
                x2_i = min(x2_1, x2_2)
                y2_i = min(y2_1, y2_2)
                if x2_i <= x1_i or y2_i <= y1_i:
                    return 0.0
                intersection = (x2_i - x1_i) * (y2_i - y1_i)
                area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
                area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
                union = area1 + area2 - intersection
                if union == 0:
                    return 0.0
                return intersection / union
            # Sort detections by confidence (highest first)
            detections.sort(key=lambda x: x.get('confidence', 0.0), reverse=True)
            # Apply NMS: keep detections with highest confidence, remove overlapping ones
            nms_threshold = 0.5  # IoU threshold for NMS
            filtered_detections = []
            for det in detections:
                should_keep = True
                for kept_det in filtered_detections:
                    # Only apply NMS to detections of the same class
                    if det.get('class') == kept_det.get('class'):
                        iou = calculate_iou(det['bbox'], kept_det['bbox'])
                        if iou > nms_threshold:
                            should_keep = False
                            break
                if should_keep:
                    filtered_detections.append(det)
            detections = filtered_detections
            # ---- Build prediction lists from normalized class confidences ----
            all_predictions = []
            sorted_classes = sorted(all_class_confidences.items(), key=lambda x: x[1], reverse=True)
            for class_name, prob in sorted_classes:
                all_predictions.append({
                    'class': class_name,
                    'probability': float(prob)
                })
            # Get top 3 predictions
            top3_predictions = all_predictions[:3]
            # Ensure the first detection matches the first prediction (highest confidence class).
            # This ensures consistency between the detection class and top prediction.
            # NOTE(review): this overwrites the first detection's per-box confidence with
            # the image-level normalized probability — the two are different quantities.
            if len(detections) > 0 and len(all_predictions) > 0:
                first_pred_class = all_predictions[0]['class']
                first_pred_conf = all_predictions[0]['probability']
                print(f"DEBUG: First prediction (highest confidence from all_class_confidences): {first_pred_class} = {first_pred_conf*100:.2f}%")
                # Find the detection with the first prediction class
                matching_detection = None
                for det in detections:
                    if det.get('class') == first_pred_class:
                        matching_detection = det
                        break
                if matching_detection:
                    # Move matching detection to first position
                    if detections[0].get('class') != first_pred_class:
                        detections.remove(matching_detection)
                        detections.insert(0, matching_detection)
                        print(f"DEBUG: Moved detection with class '{first_pred_class}' to first position")
                    # Update detection confidence to match first prediction's confidence
                    detections[0]['confidence'] = float(first_pred_conf)
                    print(f"DEBUG: Updated first detection confidence to match first prediction: {first_pred_class} = {first_pred_conf*100:.2f}%")
                else:
                    # No detection found for first prediction class, update first detection
                    print(f"DEBUG: No detection found for first prediction class '{first_pred_class}', updating first detection")
                    if len(detections) > 0:
                        detections[0]['class'] = first_pred_class
                        detections[0]['confidence'] = float(first_pred_conf)
                        print(f"DEBUG: Updated first detection to class '{first_pred_class}' with confidence {first_pred_conf*100:.2f}%")
            print(f"DEBUG: Final detections count: {len(detections)}")
            print(f"DEBUG: Top 3 predictions: {top3_predictions}")
            if len(detections) > 0:
                print(f"DEBUG: First detection: {detections[0].get('class')} = {detections[0].get('confidence')*100:.2f}%")
            if len(all_predictions) > 0:
                print(f"DEBUG: First prediction: {all_predictions[0]['class']} = {all_predictions[0]['probability']*100:.2f}%")
            # ---- Annotate image ----
            from PIL import ImageDraw, ImageFont
            result_image = image.copy()
            draw = ImageDraw.Draw(result_image)
            try:
                # macOS system font; falls back to Pillow's built-in bitmap font
                font = ImageFont.truetype("/System/Library/Fonts/Arial.ttf", 20)
            except:
                font = ImageFont.load_default()
            # Draw only the highest confidence detection
            if len(detections) > 0:
                detection = detections[0]
                bbox = detection.get('bbox')
                class_name = detection.get('class', 'unknown')
                confidence = detection.get('confidence', 0.0)
                if isinstance(bbox, dict) and all(key in bbox for key in ['x1', 'y1', 'x2', 'y2']):
                    x1 = float(bbox['x1'])
                    y1 = float(bbox['y1'])
                    x2 = float(bbox['x2'])
                    y2 = float(bbox['y2'])
                    # Normalize corner order, then clamp to image bounds
                    if x1 > x2:
                        x1, x2 = x2, x1
                    if y1 > y2:
                        y1, y2 = y2, y1
                    x1 = max(0, min(image.width, x1))
                    y1 = max(0, min(image.height, y1))
                    x2 = max(0, min(image.width, x2))
                    y2 = max(0, min(image.height, y2))
                    if x1 < x2 and y1 < y2:
                        draw.rectangle([x1, y1, x2, y2], outline='red', width=3)
                        label = f"{class_name}: {confidence * 100:.2f}%"
                        # Use textsize for compatibility (textbbox is newer)
                        try:
                            # Try textbbox first (Pillow >= 9.2.0)
                            label_bbox = draw.textbbox((x1, y1 - 25), label, font=font)
                        except AttributeError:
                            # Fallback to textsize for older Pillow versions
                            text_width, text_height = draw.textsize(label, font=font)
                            label_bbox = (x1, y1 - 25, x1 + text_width, y1 - 25 + text_height)
                        draw.rectangle(label_bbox, fill='red')
                        draw.text((x1, y1 - 25), label, fill='white', font=font)
            # ---- Encode result as base64 JPEG ----
            img_buffer = io.BytesIO()
            result_image.save(img_buffer, format='JPEG')
            img_base64 = base64.b64encode(img_buffer.getvalue()).decode()
            # Elapsed wall-clock time in milliseconds
            inference_time = (time.time() - start_time) * 1000
            # Format fitness score ('N/A' when not found in checkpoint or filename)
            fitness_display = 'N/A'
            if fitness_score is not None:
                fitness_display = f"{fitness_score:.4f}"
            return jsonify({
                'success': True,
                'detections': detections,
                'top3_predictions': top3_predictions,
                'all_predictions': all_predictions,
                'inference_time': float(inference_time),
                'image': f"data:image/jpeg;base64,{img_base64}",
                'filename': str(file.filename) if file.filename else 'unknown',
                'architecture': 'YOLOv3 Object Detection',
                'model_size': '246M parameters',
                'input_size': '224x224',
                'batch_size': 1,
                'used_weight_path': weight_path,
                'inferred_num_classes': int(num_classes),
                'backbone': 'Darknet-53',
                'detection_heads': '3 scales',
                'anchors': '9 anchors (3 per scale)',
                'fitness_score': fitness_display
            })
        except Exception as model_error:
            # Model loading / inference failure: return 500 with the traceback text
            print(f"DEBUG: Model loading/inference failed: {model_error}")
            import traceback
            traceback_str = traceback.format_exc()
            print(f"FULL TRACEBACK:\n{traceback_str}")
            return jsonify({
                'success': False,
                'error': str(model_error),
                'traceback': traceback_str
            }), 500
    except Exception as e:
        # Catch-all for anything outside the model try (request parsing, image decode, ...)
        import traceback
        error_traceback = traceback.format_exc()
        print(f"ERROR in api_detect_yolov3: {str(e)}")
        print(f"Traceback: {error_traceback}")
        return jsonify({
            'error': f'Detection failed: {str(e)}',
            'error_type': 'detection_failed'
        }), 500
# Mapping of model types to their Hugging Face repositories.
# Presumably consumed by download_model_from_hub() when resolving 'hub://'
# weight paths (helper defined elsewhere in this file) — confirm against caller.
MODEL_REPO_MAP = {
    'dinov3': 'mvplus/dinov3',
    'yolov3': 'mvplus/yolov3',
    'yolov8': 'mvplus/yolov8',
    'flat_surface': 'mvplus/flat_surface',
    'material_purity': 'mvplus/material_purity',
    'material_detection_head': 'mvplus/material_detection_head',
    'spatiotemporal': 'mvplus/spatiotemporal'
}
# Mapping of dataset types to their Hugging Face repositories.
# Keys are app names; list_testimages() looks them up here to download a
# missing dataset on Hugging Face Spaces. Keys must stay in sync with the
# app names used by get_testimages_dirs().
DATASET_REPO_MAP = {
    'flat_surface_detection': 'mvplus/testmages__flatsurface',
    'fluid_purity_demo': 'mvplus/testmages__milkpurity',
    'dinov3_demo': 'mvplus/testmages_dino',
    'custom_yolov8_demo': 'mvplus/testmages__yolov8',
    'spatiotemporal_detection': 'mvplus/testmages_spatiotemporal',
    'detect_yolov3': 'mvplus/testmages__yolov3',
    'material_detection_head': 'mvplus/val_natural_material_detection'
}
| # Map app names to their testimages directories (datasets is sibling to BASE_DIR) | |
| # Lazy evaluation of TESTIMAGES_DIRS to avoid issues if directories don't exist at startup | |
def download_dataset_from_hub(dataset_repo_id):
    """Download a dataset from Hugging Face Hub on demand.

    Best-effort helper: every failure (missing huggingface_hub, network error,
    bad repo id) is caught, logged, and reported as None so callers can return
    a JSON error instead of crashing the request.

    Args:
        dataset_repo_id: Repository ID (e.g., 'mvplus/testmages_spatiotemporal')

    Returns:
        Local directory path (str) if successful, None otherwise.
    """
    try:
        # Imported lazily so the app can start even when huggingface_hub
        # is not installed; 'os' is already imported at module level.
        from huggingface_hub import snapshot_download

        datasets_dir = get_datasets_dir()
        dataset_name = dataset_repo_id.split('/')[-1]  # last path segment is the folder name
        local_dataset_dir = os.path.join(datasets_dir, dataset_name)
        # Skip the download when a non-empty local copy already exists.
        if os.path.exists(local_dataset_dir) and os.listdir(local_dataset_dir):
            print(f"Dataset {dataset_repo_id} already exists locally at {local_dataset_dir}")
            return local_dataset_dir
        # Create datasets directory if it doesn't exist
        os.makedirs(datasets_dir, exist_ok=True)
        print(f"Downloading dataset {dataset_repo_id} to {local_dataset_dir}...")
        downloaded_path = snapshot_download(
            repo_id=dataset_repo_id,
            repo_type="dataset",
            local_dir=local_dataset_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
            token=os.environ.get("HF_TOKEN")  # allows private repos when set
        )
        print(f"Dataset downloaded successfully to {downloaded_path}")
        return downloaded_path
    except Exception as e:
        # Deliberate best-effort: log with traceback and signal failure via None.
        print(f"Error downloading dataset {dataset_repo_id}: {e}")
        import traceback
        traceback.print_exc()
        return None
def get_testimages_dirs():
    """Get testimages directories dictionary - evaluated lazily to avoid startup issues.

    Returns:
        dict mapping app name -> absolute path of that app's test-image
        directory under the datasets root.
    """
    # Resolve the datasets root once instead of once per entry.
    datasets_dir = get_datasets_dir()
    # App name -> dataset folder name. Folder names mirror the repo names in
    # DATASET_REPO_MAP, so downloaded datasets land in the expected directory.
    dataset_folders = {
        'flat_surface_detection': 'testmages__flatsurface',
        'fluid_purity_demo': 'testmages__milkpurity',
        'dinov3_demo': 'testmages_dino',
        'custom_yolov8_demo': 'testmages__yolov8',
        'spatiotemporal_detection': 'testmages_spatiotemporal',
        'detect_yolov3': 'testmages__yolov3',
        'material_detection_head': 'val_natural_material_detection',
    }
    return {app: os.path.join(datasets_dir, folder) for app, folder in dataset_folders.items()}
# Keep for backward compatibility but use function instead
# NOTE: this snapshot is computed once at import time; prefer calling
# get_testimages_dirs() so paths reflect the environment at call time.
TESTIMAGES_DIRS = get_testimages_dirs()
def list_testimages(app_name, subpath=''):
    """List files and directories in an app's testimages directory.

    Args:
        app_name: Key into the testimages directory map (see
            get_testimages_dirs).
        subpath: Optional subdirectory, relative to the app's testimages
            root, to list instead of the root.

    Returns:
        A Flask JSON response. On success: directories first, then files
        (both sorted case-insensitively), plus dataset repo metadata. On
        failure: a JSON error payload with an appropriate HTTP status.
    """
    try:
        testimages_dirs = get_testimages_dirs()
        if app_name not in testimages_dirs:
            return jsonify({'error': f'Unknown app: {app_name}'}), 404
        testimages_dir = testimages_dirs[app_name]
        # On Hugging Face Spaces, download the dataset on demand if missing.
        if os.environ.get("SPACE_ID") and not os.path.exists(testimages_dir):
            dataset_repo_id = DATASET_REPO_MAP.get(app_name)
            if dataset_repo_id:
                print(f"Dataset not found locally, attempting to download {dataset_repo_id}...")
                downloaded_path = download_dataset_from_hub(dataset_repo_id)
                if downloaded_path:
                    # Use the freshly downloaded location instead.
                    testimages_dir = downloaded_path
                else:
                    return jsonify({
                        'success': False,
                        'error': f'Dataset not found and could not be downloaded. Please ensure {dataset_repo_id} exists on Hugging Face Hub.',
                        'repo_id': dataset_repo_id,
                        'files': []
                    }), 404
        # Build the full path to the directory to list.
        if subpath:
            current_dir = os.path.normpath(os.path.join(testimages_dir, subpath))
            testimages_dir = os.path.normpath(testimages_dir)
            # Security check: the resolved path must be the testimages root
            # itself or live inside it. A bare startswith() prefix test would
            # also accept sibling directories such as "<root>_evil", so
            # require a path separator after the root.
            if current_dir != testimages_dir and not current_dir.startswith(testimages_dir + os.sep):
                return jsonify({'error': 'Invalid directory path'}), 403
        else:
            current_dir = testimages_dir
        if not os.path.exists(current_dir):
            return jsonify({'error': f'Directory not found: {current_dir}'}), 404
        if not os.path.isdir(current_dir):
            return jsonify({'error': f'Path is not a directory: {current_dir}'}), 400
        files = []
        directories = []
        # These apps accept any file format (no extension filtering).
        accept_all = app_name in {'detect_yolov3', 'custom_yolov8_demo',
                                  'dinov3_demo', 'spatiotemporal_detection'}
        # Extension filter for the remaining apps. The comparison below is
        # case-insensitive, so a single lowercase set suffices; it is built
        # once here instead of being re-lowercased for every file.
        allowed_extensions = {'.png', '.jpg', '.jpeg', '.sto', '.heic'}
        for itemname in os.listdir(current_dir):
            itempath = os.path.join(current_dir, itemname)
            if os.path.isdir(itempath):
                dir_stat = os.stat(itempath)
                directories.append({
                    'name': itemname,
                    'type': 'directory',
                    'modified': datetime.fromtimestamp(dir_stat.st_mtime).isoformat()
                })
            elif os.path.isfile(itempath):
                _, ext = os.path.splitext(itemname)
                if accept_all or ext.lower() in allowed_extensions:
                    file_stat = os.stat(itempath)
                    files.append({
                        'name': itemname,
                        'type': 'file',
                        'size': file_stat.st_size,
                        'modified': datetime.fromtimestamp(file_stat.st_mtime).isoformat()
                    })
        # Sort: directories first, then files, both case-insensitively.
        directories.sort(key=lambda x: x['name'].lower())
        files.sort(key=lambda x: x['name'].lower())
        items = directories + files
        relative_path = subpath if subpath else ''
        # Link back to the dataset repo when one is configured for this app.
        dataset_repo_id = DATASET_REPO_MAP.get(app_name)
        dataset_repo_url = f"https://huggingface.co/datasets/{dataset_repo_id}" if dataset_repo_id else None
        return jsonify({
            'success': True,
            'app_name': app_name,
            'directory': current_dir,
            'relative_path': relative_path,
            'items': items,
            'directories': directories,
            'files': files,  # For backward compatibility
            'count': len(items),
            'repo_url': dataset_repo_url,
            'repo_id': dataset_repo_id
        })
    except Exception as e:
        print(f"ERROR listing testimages for {app_name}: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e),
            'files': []
        }), 500
def get_testimage(app_name, filename):
    """Serve a test image file from the testimages directory.

    Args:
        app_name: Key into the testimages directory map.
        filename: File name, optionally including a subdirectory path
            relative to the app's testimages root (e.g. "subdir/image.jpg").

    Returns:
        The file contents with a best-effort MIME type, or a JSON error
        payload with an appropriate HTTP status code.
    """
    try:
        testimages_dirs = get_testimages_dirs()
        if app_name not in testimages_dirs:
            return jsonify({'error': f'Unknown app: {app_name}'}), 404
        testimages_dir = testimages_dirs[app_name]
        # filename may include subdirectory components.
        filepath = os.path.normpath(os.path.join(testimages_dir, filename))
        testimages_dir = os.path.normpath(testimages_dir)
        # Security check: the resolved path must live inside the testimages
        # directory. A bare startswith() prefix test would also accept
        # sibling directories such as "<root>_evil", so require a path
        # separator after the root.
        if not filepath.startswith(testimages_dir + os.sep):
            return jsonify({'error': 'Invalid file path'}), 403
        if not os.path.exists(filepath):
            return jsonify({'error': 'File not found'}), 404
        if not os.path.isfile(filepath):
            return jsonify({'error': 'Path is not a file'}), 400
        # Map extension -> MIME type; lookup is case-insensitive (lowercased
        # below), so only lowercase keys are needed.
        _, ext = os.path.splitext(filename)
        mime_types = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.bmp': 'image/bmp',
            '.tiff': 'image/tiff',
            '.tif': 'image/tiff',
            '.heic': 'image/heic',
            '.sto': 'application/octet-stream'
        }
        mime_type = mime_types.get(ext.lower(), 'application/octet-stream')
        return send_file(filepath, mimetype=mime_type)
    except Exception as e:
        # Bug fix: the original log line printed the literal "(unknown)"
        # instead of the requested filename.
        print(f"ERROR getting testimage for {app_name}/{filename}: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
def kill_existing_processes_on_port(port):
    """Kill any existing processes running on the specified port (excluding current process).
    Only works on local systems, not on Hugging Face Spaces."""
    # Never attempt this on Hugging Face Spaces.
    if os.environ.get("SPACE_ID"):
        return
    try:
        my_pid = str(os.getpid())
        # 'lsof -ti :<port>' prints one PID per line (macOS/Linux only).
        lookup = subprocess.run(
            ['lsof', '-ti', f':{port}'],
            capture_output=True,
            text=True,
            timeout=5  # guard against lsof hanging
        )
        output = lookup.stdout.strip()
        if not output:
            print(f"No existing processes found on port {port}")
            return
        killed_any = False
        for pid in output.split('\n'):
            # Never kill ourselves; skip empty lines defensively.
            if not pid or pid == my_pid:
                continue
            try:
                subprocess.run(['kill', '-9', pid], check=False, timeout=2)
                print(f"Killed process {pid} on port {port}")
                killed_any = True
            except Exception as e:
                print(f"Error killing process {pid}: {e}")
        if not killed_any:
            print(f"No other processes found on port {port} (current PID: {my_pid})")
    except (FileNotFoundError, subprocess.TimeoutExpired):
        # lsof unavailable or timed out — non-critical, skip quietly.
        pass
    except Exception as e:
        # Non-critical: log and continue.
        print(f"Warning: Could not check for existing processes on port {port}: {e}")
def download_model_from_hub(hub_path):
    """Download a specific model file from Hugging Face Hub on demand.

    Args:
        hub_path: Path in format 'hub://repo_id/path/to/file' or just a
            local filesystem path.

    Returns:
        Local file path if successful, None otherwise.
    """
    try:
        # Fast path: not a Hub reference — treat it as a local path. Checked
        # BEFORE importing huggingface_hub so plain local paths still work
        # when that package is unavailable (the original imported first, so
        # a missing package made even local paths return None).
        if not hub_path.startswith('hub://'):
            return hub_path if os.path.exists(hub_path) else None
        from huggingface_hub import hf_hub_download
        # Strip only the leading scheme; str.replace('hub://', '') would
        # also mangle any later occurrence of 'hub://' in the path.
        path_without_prefix = hub_path[len('hub://'):]
        # Split at the LAST '/': repo_id is everything before it, the file
        # name everything after. NOTE(review): this assumes the file sits at
        # the repo root — a repo_id like "org/name" plus a nested
        # "sub/dir/file" would be mis-split; confirm against callers.
        last_slash_idx = path_without_prefix.rfind('/')
        if last_slash_idx == -1:
            print(f"Invalid Hub path format (no file path): {hub_path}")
            return None
        repo_id = path_without_prefix[:last_slash_idx]
        file_path = path_without_prefix[last_slash_idx + 1:]
        # Consistency: use the module's debug_print so this noisy line obeys
        # DEBUG_MODE instead of always printing.
        debug_print(f"DEBUG: Parsed Hub path - repo_id: {repo_id}, file_path: {file_path}")
        models_dir = get_models_dir()
        local_file_path = os.path.join(models_dir, file_path)
        # Reuse a previously downloaded copy if present.
        if os.path.exists(local_file_path):
            print(f"Using cached model: {local_file_path}")
            return local_file_path
        print(f"Downloading {file_path} from {repo_id}...")
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            repo_type="model",
            local_dir=models_dir,
            token=os.environ.get("HF_TOKEN"),
            resume_download=True
        )
        print(f"✓ Downloaded to: {downloaded_path}")
        return downloaded_path
    except Exception as e:
        print(f"Error downloading model {hub_path}: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    import time
    startup_start = time.time()
    print("=" * 60)
    print("Starting SPAD for Vision application...")
    print("=" * 60)
    print("Note: Models will be downloaded on-demand when selected from dropdowns")
    print("=" * 60)
    # Use PORT from env if set; otherwise 7860 on Hugging Face Spaces (the
    # HF default) and 7889 for local testing. Consistency fix: both defaults
    # are strings, so int() always receives the same type in either branch
    # (the original mixed "7860" with the int 7889).
    default_port = "7860" if os.environ.get("SPACE_ID") else "7889"
    port = int(os.environ.get("PORT", default_port))
    print(f"Starting Flask server on port {port}...")
    print(f"SPACE_ID: {os.environ.get('SPACE_ID', 'Not set')}")
    print(f"PORT env var: {os.environ.get('PORT', 'Not set')}")
    # Kill stale listeners before binding — local testing only; skipped on
    # Spaces and inside the Werkzeug reloader child process.
    if os.environ.get("WERKZEUG_RUN_MAIN") != "true" and not os.environ.get("SPACE_ID"):
        try:
            kill_existing_processes_on_port(port)
        except Exception as e:
            print(f"Warning: Could not kill existing processes on port {port}: {e}")
            print("Continuing anyway...")
    startup_time = time.time() - startup_start
    print(f"Startup completed in {startup_time:.2f} seconds")
    print("=" * 60)
    # Debug mode is always off in production (it enables the auto-reloader,
    # which is slow and can misbehave on Spaces); locally it may be enabled
    # with FLASK_DEBUG=true. Threading stays on for concurrent requests.
    debug_mode = False
    threaded = True
    if os.environ.get("SPACE_ID"):
        print("Running on Hugging Face Spaces (production mode)")
    else:
        print("Running locally (testing mode)")
        # Only enable debug locally if explicitly requested.
        debug_mode = os.environ.get("FLASK_DEBUG", "false").lower() == "true"
    if debug_mode:
        print("Debug mode: ENABLED")
    else:
        print("Debug mode: DISABLED (production mode)")
    app.run(host="0.0.0.0", port=port, debug=debug_mode, threaded=threaded)