Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Test script for YOLO detection with version compatibility fix. | |
| This script tests the YOLO object detection in isolation to verify | |
| that the PyTorch/torchvision version compatibility fixes are working. | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| import tempfile | |
| import uuid | |
| from pathlib import Path | |
| import numpy as np | |
| import cv2 | |
| from datetime import datetime | |
# Logging setup: INFO-and-above records are written to stdout with a
# timestamp, the logger name and the level in front of each message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger("yolo_test")

# TEST_IMAGE_DIR holds the input images (existing, downloaded or generated);
# OUTPUT_DIR receives the annotated copies produced by the detection run.
TEST_IMAGE_DIR = Path("test_files")
OUTPUT_DIR = Path("test_output") / "yolo_test"
def create_test_image(filename, pattern_type=None):
    """Create a synthetic "polluted water" test image and save it to disk.

    The image is a 400-row by 600-column, 3-channel ``uint8`` canvas filled
    with a water colour (BGR order), onto which one pollution-like pattern
    is drawn:

    * ``"oil_spill"`` - 100 overlapping dark circles around a random centre,
    * ``"plastic_debris"`` - three filled rectangles in random colours,
    * ``"foam"`` - twenty filled white circles.

    Args:
        filename: Name of the image file to write inside ``TEST_IMAGE_DIR``.
        pattern_type: Optional name of the pattern to draw (one of the three
            listed above).  When ``None`` (the default) a pattern is chosen
            at random.

    Returns:
        pathlib.Path: ``TEST_IMAGE_DIR / filename``, the written image.

    Raises:
        ValueError: If ``pattern_type`` is given but is not a known pattern.
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    pattern_names = ("oil_spill", "plastic_debris", "foam")
    if pattern_type is None:
        # np.random.choice returns a numpy string; normalise to plain str.
        pattern_type = str(np.random.choice(pattern_names))
    elif pattern_type not in pattern_names:
        raise ValueError(
            f"Unknown pattern_type {pattern_type!r}; expected one of {pattern_names}"
        )

    # Make sure the directory that receives the test image exists.
    TEST_IMAGE_DIR.mkdir(parents=True, exist_ok=True)

    # Base canvas: uniform water colour (BGR order).
    img = np.zeros((400, 600, 3), dtype=np.uint8)
    img[:, :] = [220, 180, 90]

    # Coordinates are cast to built-in ints before being handed to OpenCV so
    # the drawing calls never receive NumPy scalar types.
    if pattern_type == "oil_spill":
        # Irregular dark blob: many overlapping circles scattered around a
        # random centre.
        center_x = int(np.random.randint(200, 400))
        center_y = int(np.random.randint(100, 300))
        for _ in range(100):
            x = center_x + int(np.random.randint(-80, 80))
            y = center_y + int(np.random.randint(-50, 50))
            r = int(np.random.randint(5, 30))
            cv2.circle(img, (x, y), r, (40, 40, 40), -1)  # Dark color
    elif pattern_type == "plastic_debris":
        # Candidate colours for the plastic items (BGR order).
        colors = [
            (255, 255, 255),  # White
            (255, 200, 0),    # Blue-ish
            (0, 200, 200),    # Yellow-ish (B=0, G=200, R=200)
        ]
        for _ in range(3):
            x = int(np.random.randint(50, 550))
            y = int(np.random.randint(50, 350))
            w = int(np.random.randint(30, 60))
            h = int(np.random.randint(20, 40))
            color = colors[np.random.randint(0, len(colors))]
            cv2.rectangle(img, (x, y), (x + w, y + h), color, -1)
    else:  # "foam"
        for _ in range(20):
            x = int(np.random.randint(50, 550))
            y = int(np.random.randint(50, 350))
            r = int(np.random.randint(5, 20))
            cv2.circle(img, (x, y), r, (255, 255, 255), -1)  # White foam

    output_path = TEST_IMAGE_DIR / filename
    # cv2.imwrite signals failure through its return value rather than an
    # exception, so check it instead of logging a success that did not happen.
    if not cv2.imwrite(str(output_path), img):
        raise OSError(f"Failed to write test image to {output_path}")
    logger.info(f"Created test image '{pattern_type}' at {output_path}")
    return output_path
def _download_test_images():
    """Best-effort download of sample pollution images into TEST_IMAGE_DIR.

    Returns:
        list[pathlib.Path]: Paths of the images that were downloaded; empty
        when ``requests`` is unavailable or every download failed.
    """
    # (URL, local file name) pairs covering different pollution types.
    test_urls = [
        # Plastic bottles
        ("https://www.condorferries.co.uk/media/2455/plastic-bottles-on-beach.jpg", "plastic_bottle_beach.jpg"),
        ("https://oceanservice.noaa.gov/hazards/marinedebris/entanglement-or-ingestion-can-kill.jpg", "plastic_waste_beach.jpg"),
        # Plastic waste
        ("https://www.noaa.gov/sites/default/files/2021-03/Marine%20debris%20on%20a%20Hawaii%20beach%20NOAA.jpg", "beach_debris.jpg"),
        # Oil spill
        ("https://media.istockphoto.com/id/177162311/photo/oil-spill-on-beach.jpg", "oil_spill.jpg"),
        # Ship
        ("https://scx2.b-cdn.net/gfx/news/2018/shippingindi.jpg", "ship_water.jpg"),
    ]
    downloaded = []
    try:
        import requests
        TEST_IMAGE_DIR.mkdir(exist_ok=True)
        for url, filename in test_urls:
            # Each download is independent: one failure must not stop the rest.
            try:
                response = requests.get(url, timeout=5)
                if response.status_code != 200:
                    logger.warning(
                        f"Failed to download test image {filename}: HTTP {response.status_code}"
                    )
                    continue
                file_path = TEST_IMAGE_DIR / filename
                with open(file_path, "wb") as f:
                    f.write(response.content)
                downloaded.append(file_path)
                logger.info(f"Downloaded test image to {file_path}")
            except Exception as e:
                logger.warning(f"Failed to download test image {filename}: {e}")
        logger.info(f"Downloaded {len(downloaded)} test images")
    except Exception as e:
        # Covers a missing ``requests`` package or a failing mkdir.
        logger.warning(f"Failed to download test images: {e}")
    return downloaded


def _collect_test_images():
    """Return the image paths to test: existing, else downloaded, else synthetic."""
    existing_images = []
    for pattern in ("*.jpg", "*.jpeg", "*.png"):
        existing_images.extend(TEST_IMAGE_DIR.glob(pattern))
    if existing_images:
        logger.info(f"Found {len(existing_images)} existing test images")
        logger.info(f"Using test images: {[str(img) for img in existing_images]}")
        return existing_images

    logger.info("No existing test images found, creating synthetic images")
    test_images = _download_test_images()
    if not test_images:
        # Last resort: generate three synthetic images locally.
        for i in range(3):
            filename = f"test_pollution_{i}_{uuid.uuid4().hex[:8]}.jpg"
            test_images.append(create_test_image(filename))
    return test_images


def _save_annotated_image(annotated_url, image_path, parent_dir):
    """Copy or download the annotated image into OUTPUT_DIR as ``annotated_<name>``."""
    output_file = OUTPUT_DIR / f"annotated_{image_path.name}"

    if annotated_url.startswith("http"):
        import requests
        # A timeout keeps the test from hanging on an unresponsive host, and
        # the context manager releases the streamed connection when done.
        with requests.get(annotated_url, stream=True, timeout=30) as resp:
            if resp.status_code == 200:
                with open(output_file, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        f.write(chunk)
                logger.info(f"Saved annotated image to {output_file}")
            else:
                logger.warning(
                    f"Failed to download annotated image: HTTP {resp.status_code}"
                )
        return

    from shutil import copy
    local_path = annotated_url
    if local_path.startswith(("/uploads/", "uploads/")):
        # Paths under "uploads" are resolved relative to the "app" directory
        # next to this script.
        if local_path.startswith("/"):
            local_path = local_path[1:]  # Remove leading slash
        absolute_path = Path(parent_dir) / "app" / local_path
        logger.info(f"Looking for local file at: {absolute_path}")
        if absolute_path.exists():
            copy(str(absolute_path), output_file)
            logger.info(f"Copied annotated image to {output_file}")
        else:
            logger.warning(f"Cannot find local file: {absolute_path}")
    else:
        copy(local_path, output_file)
        logger.info(f"Copied annotated image to {output_file}")


async def _detect_all(image_processing, test_images, parent_dir):
    """Run detection on every image and return one summary dict per result.

    ``image_processing.download_image`` is replaced for the duration of the
    run by a wrapper that serves ``file://`` URLs from the local disk, and is
    always restored afterwards, even when detection raises.
    """
    original_download = image_processing.download_image

    async def mock_download_image(url):
        if url.startswith("file://"):
            # Local file: strip the scheme prefix and read the bytes directly.
            with open(url[len("file://"):], "rb") as f:
                return f.read()
        # Regular URL: defer to the real implementation.
        return await original_download(url)

    results = []
    image_processing.download_image = mock_download_image
    try:
        for image_path in test_images:
            logger.info(f"Processing image: {image_path}")
            # The detector takes a URL, so local files are passed as file:// URLs.
            file_url = f"file://{os.path.abspath(image_path)}"
            try:
                detection_results = await image_processing.detect_objects_in_image(file_url)
            except Exception as e:
                logger.error(f"Error processing image {image_path}: {e}")
                continue

            if not detection_results:
                logger.warning(f"No detection results for {image_path}")
                continue

            logger.info(f"Detection results: {detection_results}")
            # Saving the annotated copy is optional; failures only warn.
            try:
                annotated_url = detection_results.get("annotated_image_url")
                if annotated_url:
                    _save_annotated_image(annotated_url, image_path, parent_dir)
            except Exception as e:
                logger.warning(f"Failed to save/copy annotated image: {e}")

            results.append({
                "image": str(image_path),
                "detection_count": detection_results.get("detection_count", 0),
                "detections": detection_results.get("detections", []),
                "method": detection_results.get("method", "yolo"),
            })
    finally:
        # Always undo the monkeypatch so the module is left as it was found.
        image_processing.download_image = original_download
    return results


def _torch_versions():
    """Return ``(torch_version, torchvision_version)``; raises ImportError if missing."""
    import torch
    import torchvision
    return torch.__version__, torchvision.__version__


def test_yolo_detection():
    """Test YOLO detection with our patched version fixes.

    Gathers test images (existing files in ``TEST_IMAGE_DIR``, otherwise
    downloaded samples, otherwise synthetic images), runs
    ``detect_objects_in_image`` on each one through a ``file://`` URL, saves
    any annotated image into ``OUTPUT_DIR``, logs a summary, and - when at
    least one image produced results - writes the torch/torchvision versions
    to ``requirements-version-fix.txt`` in the current directory.

    Returns:
        bool: ``True`` when the run completed without an unexpected
        exception (individual image failures are logged and skipped),
        ``False`` otherwise.
    """
    try:
        # Make the directory containing this script importable so that the
        # ``app`` package can be resolved.
        parent_dir = Path(__file__).parent
        if str(parent_dir) not in sys.path:
            sys.path.append(str(parent_dir))

        # Importing the individual names verifies they exist in the module.
        # NOTE(review): initialize_yolo_model is imported but never called
        # here; presumably the model is initialised inside
        # detect_objects_in_image - confirm.
        from app.services.image_processing import detect_objects_in_image, download_image, initialize_yolo_model  # noqa: F401
        from app.services import image_processing
        import asyncio

        OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
        test_images = _collect_test_images()

        logger.info("Initializing YOLO model directly...")
        # Log module versions for debugging
        try:
            torch_version, torchvision_version = _torch_versions()
            logger.info(f"PyTorch version: {torch_version}")
            logger.info(f"Torchvision version: {torchvision_version}")
        except ImportError as e:
            logger.warning(f"Could not import torch/torchvision: {e}")

        results = asyncio.run(_detect_all(image_processing, test_images, parent_dir))

        # Print summary
        logger.info("=== YOLO Detection Test Results ===")
        for i, result in enumerate(results):
            logger.info(f"Image {i+1}: {result['image']}")
            logger.info(f" Detection count: {result['detection_count']}")
            logger.info(f" Detection method: {result.get('method', 'yolo')}")
            for det in result.get("detections", []):
                logger.info(f" - {det.get('class')}: {det.get('confidence')}")
        logger.info(f"Annotated images saved to: {OUTPUT_DIR}")

        if results:
            # Record which version combination was successful
            logger.info("=== Successful Version Combination ===")
            try:
                torch_version, torchvision_version = _torch_versions()
                logger.info(f"PyTorch version: {torch_version}")
                logger.info(f"Torchvision version: {torchvision_version}")
                with open("requirements-version-fix.txt", "w") as f:
                    f.write(f"# Successfully tested on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                    f.write(f"torch=={torch_version}\n")
                    f.write(f"torchvision=={torchvision_version}\n")
                    f.write("ultralytics\n")
                    f.write("opencv-python\n")
                    f.write("cloudinary\n")
                    f.write("numpy\n")
                    f.write("requests\n")
                logger.info(f"Wrote successful versions to requirements-version-fix.txt")
            except ImportError:
                logger.warning("Could not determine torch/torchvision versions")
        else:
            logger.warning("No images produced detection results")
        return True
    except Exception as e:
        logger.error(f"Error in YOLO detection test: {e}", exc_info=True)
        return False
if __name__ == "__main__":
    # Run the test and report the outcome both in the log and through the
    # process exit status, so shells and CI jobs can detect a failure.
    success = test_yolo_detection()
    if success:
        logger.info("YOLO detection test completed successfully")
    else:
        logger.error("YOLO detection test failed")
    sys.exit(0 if success else 1)