# Sanjay / test_yolo_detection.py
# Last commit by TheDeepDas: "Yolo" (6bbbfda)
#!/usr/bin/env python
"""
Test script for YOLO detection with version compatibility fix.
This script tests the YOLO object detection in isolation to verify
that the PyTorch/torchvision version compatibility fixes are working.
"""
import os
import sys
import logging
import tempfile
import uuid
from pathlib import Path
import numpy as np
import cv2
from datetime import datetime
# All log output goes to stdout so progress appears inline with other output.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("yolo_test")

# Input images are read from (or generated into) TEST_IMAGE_DIR; annotated
# outputs go under OUTPUT_DIR. Both paths are relative to the current
# working directory, not to this file.
TEST_IMAGE_DIR = Path("test_files")
OUTPUT_DIR = Path("test_output") / "yolo_test"
def create_test_image(filename, pattern_type=None, rng=None):
    """Create a synthetic test image with a potential pollution pattern.

    The image is a 600x400 (width x height) BGR canvas filled with a flat
    water colour. One of three patterns is drawn on it:

    * ``"oil_spill"``      - a cluster of overlapping dark filled circles,
    * ``"plastic_debris"`` - three filled rectangles (white / blue / yellow),
    * ``"foam"``           - twenty small white filled circles.

    Args:
        filename: File name to write inside ``TEST_IMAGE_DIR``. OpenCV picks
            the encoding from its extension.
        pattern_type: One of the pattern names above. If ``None`` (the
            default), a pattern is chosen at random.
        rng: Object that provides ``choice`` and ``randint`` with the
            ``numpy.random`` signatures, e.g. ``np.random.RandomState(0)``,
            for reproducible images. Defaults to the global ``np.random``.

    Returns:
        The ``Path`` of the written image.

    Raises:
        ValueError: If ``pattern_type`` is not one of the known patterns.
        OSError: If OpenCV reports that the image could not be written.
    """
    patterns = ("oil_spill", "plastic_debris", "foam")
    if rng is None:
        rng = np.random
    if pattern_type is None:
        pattern_type = str(rng.choice(patterns))
    elif pattern_type not in patterns:
        raise ValueError(
            f"Unknown pattern_type {pattern_type!r}; expected one of {patterns}"
        )

    # Make sure the destination (input-image) directory exists.
    TEST_IMAGE_DIR.mkdir(parents=True, exist_ok=True)

    # Base image: light-blue water. OpenCV images are in BGR channel order.
    img = np.zeros((400, 600, 3), dtype=np.uint8)
    img[:, :] = (220, 180, 90)

    if pattern_type == "oil_spill":
        # Irregular dark blob made of many overlapping circles scattered
        # around a random centre.
        center_x = rng.randint(200, 400)
        center_y = rng.randint(100, 300)
        for _ in range(100):
            x = center_x + rng.randint(-80, 80)
            y = center_y + rng.randint(-50, 50)
            radius = rng.randint(5, 30)
            cv2.circle(img, (int(x), int(y)), int(radius), (40, 40, 40), -1)
    elif pattern_type == "plastic_debris":
        # Candidate plastic colours, BGR order.
        colors = [
            (255, 255, 255),  # White
            (255, 200, 0),    # Blue-ish (B=255, G=200, R=0)
            (0, 200, 200),    # Yellow/olive (B=0, G=200, R=200)
        ]
        for _ in range(3):
            x = rng.randint(50, 550)
            y = rng.randint(50, 350)
            w = rng.randint(30, 60)
            h = rng.randint(20, 40)
            color = colors[rng.randint(0, len(colors))]
            cv2.rectangle(img, (int(x), int(y)), (int(x + w), int(y + h)), color, -1)
    else:  # "foam"
        for _ in range(20):
            x = rng.randint(50, 550)
            y = rng.randint(50, 350)
            radius = rng.randint(5, 20)
            cv2.circle(img, (int(x), int(y)), int(radius), (255, 255, 255), -1)

    output_path = TEST_IMAGE_DIR / filename
    # cv2.imwrite signals some failures (e.g. an unwritable path) through a
    # False return value rather than an exception, so check it explicitly.
    if not cv2.imwrite(str(output_path), img):
        raise OSError(f"Failed to write test image to {output_path}")
    logger.info(f"Created test image '{pattern_type}' at {output_path}")
    return output_path
# Sample marine-pollution images fetched when TEST_IMAGE_DIR has no images.
# Each entry is (url, local file name inside TEST_IMAGE_DIR).
_SAMPLE_IMAGE_URLS = (
    # Plastic bottles
    ("https://www.condorferries.co.uk/media/2455/plastic-bottles-on-beach.jpg", "plastic_bottle_beach.jpg"),
    ("https://oceanservice.noaa.gov/hazards/marinedebris/entanglement-or-ingestion-can-kill.jpg", "plastic_waste_beach.jpg"),
    # Plastic waste
    ("https://www.noaa.gov/sites/default/files/2021-03/Marine%20debris%20on%20a%20Hawaii%20beach%20NOAA.jpg", "beach_debris.jpg"),
    # Oil spill
    ("https://media.istockphoto.com/id/177162311/photo/oil-spill-on-beach.jpg", "oil_spill.jpg"),
    # Ship
    ("https://scx2.b-cdn.net/gfx/news/2018/shippingindi.jpg", "ship_water.jpg"),
)

# File extensions (compared lower-cased) treated as usable test images.
_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def _find_existing_test_images():
    """Return image files already in TEST_IMAGE_DIR, sorted for a stable order.

    Extensions are matched case-insensitively. A missing directory yields [].
    """
    return sorted(
        path
        for path in TEST_IMAGE_DIR.glob("*")
        if path.is_file() and path.suffix.lower() in _IMAGE_SUFFIXES
    )


def _download_sample_images():
    """Best-effort download of _SAMPLE_IMAGE_URLS into TEST_IMAGE_DIR.

    Every failure (missing ``requests``, network error, non-200 status) is
    logged as a warning and skipped. Returns the paths that were written.
    """
    downloaded = []
    try:
        import requests

        TEST_IMAGE_DIR.mkdir(exist_ok=True)
        for url, filename in _SAMPLE_IMAGE_URLS:
            try:
                response = requests.get(url, timeout=5)
                if response.status_code != 200:
                    logger.warning(
                        f"Failed to download test image {filename}: HTTP {response.status_code}"
                    )
                    continue
                file_path = TEST_IMAGE_DIR / filename
                file_path.write_bytes(response.content)
                downloaded.append(file_path)
                logger.info(f"Downloaded test image to {file_path}")
            except Exception as e:
                logger.warning(f"Failed to download test image {filename}: {e}")
        logger.info(f"Downloaded {len(downloaded)} test images")
    except Exception as e:
        logger.warning(f"Failed to download test images: {e}")
    return downloaded


def _collect_test_images():
    """Return the images to run detection on.

    Preference order: existing files in TEST_IMAGE_DIR, then downloaded
    samples, then three freshly generated synthetic images.
    """
    existing_images = _find_existing_test_images()
    if existing_images:
        logger.info(f"Found {len(existing_images)} existing test images")
        logger.info(f"Using test images: {[str(img) for img in existing_images]}")
        return existing_images

    logger.info("No existing test images found, downloading sample images")
    test_images = _download_sample_images()
    if not test_images:
        # Offline, or every download failed: fall back to synthetic images.
        logger.info("Creating synthetic test images")
        test_images = [
            create_test_image(f"test_pollution_{i}_{uuid.uuid4().hex[:8]}.jpg")
            for i in range(3)
        ]
    return test_images


def _log_torch_versions():
    """Log installed PyTorch/torchvision versions, or warn if either is missing."""
    try:
        import torch
        import torchvision
    except ImportError as e:
        logger.warning(f"Could not import torch/torchvision: {e}")
        return
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"Torchvision version: {torchvision.__version__}")


def _save_annotated_image(detection_results, image_path, app_root):
    """Copy or download the annotated image named in *detection_results* to OUTPUT_DIR.

    Handles three cases for ``annotated_image_url``: an ``http(s)`` URL (downloaded),
    an ``/uploads/...`` or ``uploads/...`` path (resolved under ``<app_root>/app``),
    and any other local path (copied as-is). Best-effort: every failure is logged
    as a warning and never raised.
    """
    try:
        annotated_url = detection_results.get("annotated_image_url")
        if not annotated_url:
            return
        output_file = OUTPUT_DIR / f"annotated_{image_path.name}"

        if annotated_url.startswith("http"):
            import requests

            # A timeout stops a stalled server from hanging the run, and the
            # context manager releases the streamed connection.
            with requests.get(annotated_url, stream=True, timeout=30) as resp:
                if resp.status_code == 200:
                    with open(output_file, "wb") as f:
                        for chunk in resp.iter_content(chunk_size=8192):
                            f.write(chunk)
                    logger.info(f"Saved annotated image to {output_file}")
            return

        from shutil import copy

        local_path = annotated_url
        if local_path.startswith(("/uploads/", "uploads/")):
            # These paths are relative to the application package directory.
            if local_path.startswith("/"):
                local_path = local_path[1:]
            absolute_path = Path(app_root) / "app" / local_path
            logger.info(f"Looking for local file at: {absolute_path}")
            if absolute_path.exists():
                copy(str(absolute_path), output_file)
                logger.info(f"Copied annotated image to {output_file}")
            else:
                logger.warning(f"Cannot find local file: {absolute_path}")
        else:
            copy(local_path, output_file)
            logger.info(f"Copied annotated image to {output_file}")
    except Exception as e:
        logger.warning(f"Failed to save/copy annotated image: {e}")


async def _process_images(image_processing, test_images, app_root):
    """Run detection on each image and return a summary dict per successful image.

    ``image_processing`` is the ``app.services.image_processing`` module. While
    this runs, its ``download_image`` is replaced by a wrapper that reads
    ``file://`` URLs from disk and delegates all other URLs to the original, so
    local files can go through the URL-based pipeline. The original function
    is always restored, even if detection raises.
    """
    original_download = image_processing.download_image

    async def mock_download_image(url):
        if url.startswith("file://"):
            with open(url[len("file://"):], "rb") as f:
                return f.read()
        return await original_download(url)

    results = []
    image_processing.download_image = mock_download_image
    try:
        for image_path in test_images:
            logger.info(f"Processing image: {image_path}")
            file_url = f"file://{os.path.abspath(image_path)}"
            try:
                detection_results = await image_processing.detect_objects_in_image(file_url)
            except Exception as e:
                logger.error(f"Error processing image {image_path}: {e}")
                continue

            if not detection_results:
                logger.warning(f"No detection results for {image_path}")
                continue

            logger.info(f"Detection results: {detection_results}")
            _save_annotated_image(detection_results, image_path, app_root)
            results.append({
                "image": str(image_path),
                "detection_count": detection_results.get("detection_count", 0),
                "detections": detection_results.get("detections", []),
                "method": detection_results.get("method", "yolo"),
            })
    finally:
        image_processing.download_image = original_download
    return results


def _log_summary(results):
    """Log a per-image summary of *results* from _process_images."""
    logger.info("=== YOLO Detection Test Results ===")
    for index, result in enumerate(results, start=1):
        logger.info(f"Image {index}: {result['image']}")
        logger.info(f" Detection count: {result['detection_count']}")
        logger.info(f" Detection method: {result.get('method', 'yolo')}")
        for det in result.get("detections", []):
            logger.info(f" - {det.get('class')}: {det.get('confidence')}")
    logger.info(f"Annotated images saved to: {OUTPUT_DIR}")


def _record_successful_versions():
    """Log torch/torchvision versions and pin them in requirements-version-fix.txt (cwd)."""
    logger.info("=== Successful Version Combination ===")
    try:
        import torch
        import torchvision
    except ImportError:
        logger.warning("Could not determine torch/torchvision versions")
        return
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"Torchvision version: {torchvision.__version__}")
    lines = [
        f"# Successfully tested on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"torch=={torch.__version__}",
        f"torchvision=={torchvision.__version__}",
        "ultralytics",
        "opencv-python",
        "cloudinary",
        "numpy",
        "requests",
    ]
    with open("requirements-version-fix.txt", "w", encoding="utf-8") as f:
        f.write("".join(f"{line}\n" for line in lines))
    logger.info("Wrote successful versions to requirements-version-fix.txt")


def test_yolo_detection():
    """Run the YOLO detection pipeline end-to-end on a handful of test images.

    Steps:
    1. Make the ``app`` package importable from this file's directory.
    2. Gather images: existing files first, then downloaded samples, then
       synthetic images.
    3. Log the torch/torchvision versions.
    4. Run ``detect_objects_in_image`` on each image and save annotated
       outputs under OUTPUT_DIR.
    5. If at least one image produced results, record the working versions
       in ``requirements-version-fix.txt``.

    Returns:
        True if the run completed without an unexpected error, False otherwise.
        Per-image failures are logged but do not by themselves produce False.

    NOTE(review): because of the ``test_`` prefix, pytest will collect this
    function. It reports failure through its return value rather than an
    assertion, so pytest will not mark it failed. An assert-based wrapper
    would fix that.
    """
    try:
        parent_dir = Path(__file__).parent
        # Guard against adding the same entry again on repeated calls.
        if str(parent_dir) not in sys.path:
            sys.path.append(str(parent_dir))

        # Importing these names up front fails fast if the service module or
        # its API is missing. initialize_yolo_model is imported but not
        # called here.
        from app.services import image_processing
        from app.services.image_processing import detect_objects_in_image, download_image, initialize_yolo_model  # noqa: F401
        import asyncio

        OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
        test_images = _collect_test_images()

        # NOTE(review): despite this message, the model is not initialized
        # explicitly here; only the detection call below exercises it.
        logger.info("Initializing YOLO model directly...")
        _log_torch_versions()

        results = asyncio.run(_process_images(image_processing, test_images, parent_dir))

        _log_summary(results)
        if results:
            _record_successful_versions()
        else:
            logger.warning("No images produced detection results; version combination not recorded")
        return True
    except Exception as e:
        logger.error(f"Error in YOLO detection test: {e}", exc_info=True)
        return False
if __name__ == "__main__":
    success = test_yolo_detection()
    if success:
        logger.info("YOLO detection test completed successfully")
    else:
        logger.error("YOLO detection test failed")
    # Report the outcome as the process exit status so shell scripts and CI
    # can detect a failed run (previously the script always exited 0).
    sys.exit(0 if success else 1)