diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c2f49f56d3cbbe977a841c2245820260cbc557f7 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+training_history/2019-12-19[[:space:]]01%3A53%3A15.480800.hdf5 filter=lfs diff=lfs merge=lfs -text
+training_history/2025-08-07_16-25-27.hdf5 filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..92b8ee37a84afe0071a8549ba8a90acc31884494
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1786 @@
+import glob
+import gradio as gr
+import matplotlib
+import numpy as np
+from PIL import Image
+import torch
+import tempfile
+from gradio_imageslider import ImageSlider
+import plotly.graph_objects as go
+import plotly.express as px
+import open3d as o3d
+from depth_anything_v2.dpt import DepthAnythingV2
+import os
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+
+# Classification imports
+from transformers import AutoImageProcessor, AutoModelForImageClassification
+import google.generativeai as genai
+
+import gdown
+import spaces
+import cv2
+
+
+# Import actual segmentation model components
+from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
+from utils.learning.metrics import dice_coef, precision, recall
+from utils.io.data import normalize
+
+# --- Classification Model Setup ---
+# Load classification model and processor
+# NOTE: these run at import time, so the first launch needs network access to
+# pull the "Hemg/Wound-classification" weights from the Hugging Face Hub.
+classification_processor = AutoImageProcessor.from_pretrained("Hemg/Wound-classification")
+classification_model = AutoModelForImageClassification.from_pretrained("Hemg/Wound-classification")
+
+# Configure Gemini AI
+try:
+ # Try to get API key from Hugging Face secrets
+ gemini_api_key = os.getenv("GOOGLE_API_KEY")
+ if not gemini_api_key:
+ raise ValueError("GEMINI_API_KEY not found in environment variables")
+
+ genai.configure(api_key=gemini_api_key)
+ gemini_model = genai.GenerativeModel("gemini-2.5-pro")
+ print("✅ Gemini AI configured successfully with API key from secrets")
+except Exception as e:
+ print(f"❌ Error configuring Gemini AI: {e}")
+ print("Please make sure GEMINI_API_KEY is set in your Hugging Face Space secrets")
+ gemini_model = None
+
+# --- Classification Functions ---
+def analyze_wound_with_gemini(image, predicted_label):
+    """
+    Analyze wound image using Gemini AI with classification context
+
+    Args:
+        image: PIL Image
+        predicted_label: The predicted wound type from classification model
+
+    Returns:
+        str: Gemini AI analysis
+    """
+    if image is None:
+        return "No image provided for analysis."
+
+    # gemini_model is set to None at import time when API-key configuration
+    # fails; return an explanatory string instead of raising.
+    if gemini_model is None:
+        return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."
+
+    try:
+        # Ensure image is in RGB format
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        # Create prompt that includes the classification result
+        prompt = f"""You are assisting in a medical education and research task.
+
+Based on the wound classification model, this image has been identified as: {predicted_label}
+
+Please provide an educational analysis of this wound image focusing on:
+1. Visible characteristics of the wound (size, color, texture, edges, surrounding tissue)
+2. Educational explanation about this type of wound based on the classification: {predicted_label}
+3. General wound healing stages if applicable
+4. Key features that are typically associated with this wound type
+
+Important guidelines:
+- This is for educational and research purposes only
+- Do not provide medical advice or diagnosis
+- Keep the analysis objective and educational
+- Focus on visible features and general wound characteristics
+- Do not recommend treatments or medical interventions
+
+Please provide a comprehensive educational analysis."""
+
+        # Multimodal request: prompt text plus the PIL image in one call.
+        response = gemini_model.generate_content([prompt, image])
+        return response.text
+
+    except Exception as e:
+        # Errors (quota, network, safety blocks) are surfaced as text so the
+        # Gradio UI can display them.
+        return f"Error analyzing image with Gemini: {str(e)}"
+
+def analyze_wound_depth_with_gemini(image, depth_map, depth_stats):
+ """
+ Analyze wound depth and severity using Gemini AI with depth analysis context
+
+ Args:
+ image: Original wound image (PIL Image or numpy array)
+ depth_map: Depth map (numpy array)
+ depth_stats: Dictionary containing depth analysis statistics
+
+ Returns:
+ str: Gemini AI medical assessment based on depth analysis
+ """
+ if image is None or depth_map is None:
+ return "No image or depth map provided for analysis."
+
+ if gemini_model is None:
+ return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets."
+
+ try:
+ # Convert numpy array to PIL Image if needed
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+
+ # Ensure image is in RGB format
+ if image.mode != 'RGB':
+ image = image.convert('RGB')
+
+ # Convert depth map to PIL Image for Gemini
+ if isinstance(depth_map, np.ndarray):
+ # Normalize depth map for visualization
+ norm_depth = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
+ depth_image = Image.fromarray(norm_depth.astype(np.uint8))
+ else:
+ depth_image = depth_map
+
+ # Create detailed prompt with depth statistics
+ prompt = f"""You are a medical AI assistant specializing in wound assessment. Analyze this wound using both the original image and depth map data.
+
+DEPTH ANALYSIS DATA PROVIDED:
+- Total Wound Area: {depth_stats['total_area_cm2']:.2f} cm²
+- Mean Depth: {depth_stats['mean_depth_mm']:.1f} mm
+- Maximum Depth: {depth_stats['max_depth_mm']:.1f} mm
+- Depth Standard Deviation: {depth_stats['depth_std_mm']:.1f} mm
+- Wound Volume: {depth_stats['wound_volume_cm3']:.2f} cm³
+- Deep Tissue Involvement: {depth_stats['deep_ratio']*100:.1f}%
+- Analysis Quality: {depth_stats['analysis_quality']}
+- Depth Consistency: {depth_stats['depth_consistency']}
+
+TISSUE DEPTH DISTRIBUTION:
+- Superficial Areas (0-2mm): {depth_stats['superficial_area_cm2']:.2f} cm²
+- Partial Thickness (2-4mm): {depth_stats['partial_thickness_area_cm2']:.2f} cm²
+- Full Thickness (4-6mm): {depth_stats['full_thickness_area_cm2']:.2f} cm²
+- Deep Areas (>6mm): {depth_stats['deep_area_cm2']:.2f} cm²
+
+STATISTICAL DEPTH ANALYSIS:
+- 25th Percentile Depth: {depth_stats['depth_percentiles']['25']:.1f} mm
+- Median Depth: {depth_stats['depth_percentiles']['50']:.1f} mm
+- 75th Percentile Depth: {depth_stats['depth_percentiles']['75']:.1f} mm
+
+Please provide a comprehensive medical assessment focusing on:
+
+1. **WOUND CHARACTERISTICS ANALYSIS**
+ - Visible wound features from the original image
+ - Correlation between visual appearance and depth measurements
+ - Tissue quality assessment based on color, texture, and depth data
+
+2. **DEPTH-BASED SEVERITY ASSESSMENT**
+ - Clinical significance of the measured depths
+ - Tissue layer involvement based on depth measurements
+ - Risk assessment based on deep tissue involvement percentage
+
+3. **HEALING PROGNOSIS**
+ - Expected healing timeline based on depth and area measurements
+ - Factors that may affect healing based on depth distribution
+ - Complexity assessment based on wound volume and depth variation
+
+4. **CLINICAL CONSIDERATIONS**
+ - Significance of depth consistency/inconsistency
+ - Areas of particular concern based on depth analysis
+ - Educational insights about this type of wound presentation
+
+5. **MEASUREMENT INTERPRETATION**
+ - Clinical relevance of the statistical depth measurements
+ - What the depth distribution tells us about wound progression
+ - Comparison to typical wound depth classifications
+
+IMPORTANT GUIDELINES:
+- This is for educational and research purposes only
+- Do not provide specific medical advice or treatment recommendations
+- Focus on objective analysis of the provided measurements
+- Correlate visual findings with quantitative depth data
+- Maintain educational and clinical terminology
+- Emphasize the relationship between depth measurements and clinical significance
+
+Provide a detailed, structured medical assessment that integrates both visual and quantitative depth analysis."""
+
+ # Send both images to Gemini for analysis
+ response = gemini_model.generate_content([prompt, image, depth_image])
+ return response.text
+
+ except Exception as e:
+ return f"Error analyzing wound with Gemini AI: {str(e)}"
+
+def classify_wound(image):
+    """
+    Classify wound type from uploaded image
+
+    Args:
+        image: PIL Image or numpy array
+
+    Returns:
+        dict: Classification results with confidence scores
+    """
+    # NOTE(review): on failure this returns a str, not a dict — callers must
+    # (and classify_and_analyze_wound does) check the type before use.
+    if image is None:
+        return "Please upload an image"
+
+    # Convert to PIL Image if needed
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    # Ensure image is in RGB format
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+
+    try:
+        # Process the image
+        inputs = classification_processor(images=image, return_tensors="pt")
+
+        # Get model predictions
+        with torch.no_grad():
+            outputs = classification_model(**inputs)
+            predictions = torch.nn.functional.softmax(outputs.logits[0], dim=-1)
+
+        # Get the predicted class labels and confidence scores
+        confidence_scores = predictions.numpy()
+
+        # Create results dictionary
+        results = {}
+        for i, score in enumerate(confidence_scores):
+            # Get class name from model config
+            class_name = classification_model.config.id2label[i] if hasattr(classification_model.config, 'id2label') else f"Class {i}"
+            results[class_name] = float(score)
+
+        return results
+
+    except Exception as e:
+        return f"Error processing image: {str(e)}"
+
+def classify_and_analyze_wound(image):
+    """
+    Combined function to classify wound and get Gemini analysis
+
+    Args:
+        image: PIL Image or numpy array
+
+    Returns:
+        tuple: (classification_results, gemini_analysis)
+    """
+    if image is None:
+        return "Please upload an image", "Please upload an image for analysis"
+
+    # Get classification results
+    classification_results = classify_wound(image)
+
+    # Get the top predicted label for Gemini analysis
+    # classify_wound returns a str on error, so guard with isinstance before
+    # treating the result as a {label: confidence} dict.
+    if isinstance(classification_results, dict) and classification_results:
+        # Get the label with highest confidence
+        top_label = max(classification_results.items(), key=lambda x: x[1])[0]
+
+        # Get Gemini analysis
+        gemini_analysis = analyze_wound_with_gemini(image, top_label)
+    else:
+        top_label = "Unknown"
+        gemini_analysis = "Unable to analyze due to classification error"
+
+    return classification_results, gemini_analysis
+
+def format_gemini_analysis(analysis):
+    """Format Gemini analysis as properly structured HTML"""
+    # NOTE(review): the template literals below look garbled — the HTML tags
+    # appear to have been stripped during extraction, leaving bare text inside
+    # the f-strings. Restore the original <div>/<h3> markup before relying on
+    # this output; code left byte-identical here pending confirmation.
+    # Also note the substring check "Error" can misfire on legitimate analyses
+    # that merely contain the word "Error" — TODO confirm intended behavior.
+    if not analysis or "Error" in analysis:
+        return f"""
+
+
 Analysis Error
+
 {analysis}
+
+    """
+
+    # Parse the markdown-style response and convert to HTML
+    formatted_analysis = parse_markdown_to_html(analysis)
+
+    return f"""
+
+
+    Initial Wound Analysis
+
+
+    {formatted_analysis}
+
+
+    """
+
+def format_gemini_depth_analysis(analysis):
+    """Format Gemini depth analysis as properly structured HTML for medical assessment"""
+    # NOTE(review): as in format_gemini_analysis, the HTML markup in these
+    # f-strings appears stripped/garbled; code left byte-identical pending
+    # restoration of the original tags.
+    if not analysis or "Error" in analysis:
+        return f"""
+
+
+        ❌ AI Analysis Error
+
+
+        {analysis}
+
+
+    """
+
+    # Parse the markdown-style response and convert to HTML
+    formatted_analysis = parse_markdown_to_html(analysis)
+
+    return f"""
+
+
+    🤖 AI-Powered Medical Assessment
+
+
+    {formatted_analysis}
+
+
+    """
+
+def parse_markdown_to_html(text):
+    """Convert markdown-style text to HTML"""
+    # NOTE(review): the replacement strings below appear garbled — HTML tags
+    # were stripped during extraction, splitting several re.sub replacement
+    # arguments across lines (which embeds literal newlines in the replacement
+    # text). Reconstruct the original <h3>/<h4>/<strong>/<em>/<li>/<ul>
+    # replacements before shipping; left byte-identical here.
+    import re
+
+    # Replace markdown headers
+    text = re.sub(r'^### \*\*(.*?)\*\*$', r'\1
', text, flags=re.MULTILINE)
+    text = re.sub(r'^#### \*\*(.*?)\*\*$', r'\1
', text, flags=re.MULTILINE)
+    text = re.sub(r'^### (.*?)$', r'\1
', text, flags=re.MULTILINE)
+    text = re.sub(r'^#### (.*?)$', r'\1
', text, flags=re.MULTILINE)
+
+    # Replace bold text
+    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
+
+    # Replace italic text
+    text = re.sub(r'\*(.*?)\*', r'\1', text)
+
+    # Replace bullet points
+    text = re.sub(r'^\* (.*?)$', r'\1', text, flags=re.MULTILINE)
+    text = re.sub(r'^  \* (.*?)$', r'\1', text, flags=re.MULTILINE)
+
+    # Wrap consecutive list items in ul tags
+    text = re.sub(r'((?:\s*)*)', r'', text, flags=re.DOTALL)
+
+    # Replace numbered lists
+    text = re.sub(r'^(\d+)\.\s+(.*?)$', r'\1. \2
', text, flags=re.MULTILINE)
+
+    # Convert paragraphs (double newlines)
+    paragraphs = text.split('\n\n')
+    formatted_paragraphs = []
+
+    for para in paragraphs:
+        para = para.strip()
+        if para:
+            # Skip if it's already wrapped in HTML tags
+            if not (para.startswith('<') or para.endswith('>')):
+                para = f'{para}
'
+            formatted_paragraphs.append(para)
+
+    return '\n'.join(formatted_paragraphs)
+
+def combined_analysis(image):
+    """Combined function for UI that returns both outputs"""
+    # Thin adapter for the Gradio click handler: runs classification plus the
+    # Gemini free-text analysis, then wraps the latter in display HTML.
+    classification, gemini_analysis = classify_and_analyze_wound(image)
+    formatted_analysis = format_gemini_analysis(gemini_analysis)
+    return classification, formatted_analysis
+
+
+
+
+
+# Define path and file ID
+checkpoint_dir = "checkpoints"
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# Depth Anything V2 (vitl) checkpoint, fetched once from Google Drive and
+# cached locally; torch.load reads it further down.
+model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
+gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"
+
+# Download if not already present
+if not os.path.exists(model_file):
+    print("Downloading model from Google Drive...")
+    gdown.download(gdrive_url, model_file, quiet=False)
+
+# --- TensorFlow: Check GPU Availability ---
+# Informational only; TF device placement is not altered here.
+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+    print("TensorFlow is using GPU")
+else:
+    print("TensorFlow is using CPU")
+
+
+
+# --- Load Actual Wound Segmentation Model ---
+class WoundSegmentationModel:
+    """Keras DeeplabV3-based wound segmentation wrapper.
+
+    Loads a trained .hdf5 checkpoint at construction time and exposes
+    segment_wound() which returns a binary mask plus a status message.
+    """
+
+    def __init__(self):
+        # Fixed network input size; preprocess_image resizes to this.
+        self.input_dim_x = 224
+        self.input_dim_y = 224
+        self.model = None
+        self.load_model()
+
+    def load_model(self):
+        """Load the trained wound segmentation model"""
+        try:
+            # Try to load the most recent model
+            weight_file_name = '2025-08-07_16-25-27.hdf5'
+            model_path = f'./training_history/{weight_file_name}'
+
+            # custom_objects is required because the checkpoint was saved with
+            # non-standard layers/metrics (relu6, BilinearUpsampling, and the
+            # custom Keras metrics imported at the top of the file).
+            self.model = load_model(model_path,
+                                  custom_objects={
+                                      'recall': recall,
+                                      'precision': precision,
+                                      'dice_coef': dice_coef,
+                                      'relu6': relu6,
+                                      'DepthwiseConv2D': DepthwiseConv2D,
+                                      'BilinearUpsampling': BilinearUpsampling
+                                  })
+            print(f"Segmentation model loaded successfully from {model_path}")
+        except Exception as e:
+            print(f"Error loading segmentation model: {e}")
+            # Fallback to the older model
+            try:
+                weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
+                model_path = f'./training_history/{weight_file_name}'
+
+                self.model = load_model(model_path,
+                                      custom_objects={
+                                          'recall': recall,
+                                          'precision': precision,
+                                          'dice_coef': dice_coef,
+                                          'relu6': relu6,
+                                          'DepthwiseConv2D': DepthwiseConv2D,
+                                          'BilinearUpsampling': BilinearUpsampling
+                                      })
+                print(f"Segmentation model loaded successfully from {model_path}")
+            except Exception as e2:
+                # Leave self.model as None; segment_wound reports the failure.
+                print(f"Error loading fallback segmentation model: {e2}")
+                self.model = None
+
+    def preprocess_image(self, image):
+        """Preprocess the uploaded image for model input"""
+        if image is None:
+            return None
+
+        # Convert to RGB if needed
+        if len(image.shape) == 3 and image.shape[2] == 3:
+            # Convert BGR to RGB if needed
+            # NOTE(review): this swap is applied to EVERY 3-channel input;
+            # Gradio typically supplies RGB already, so this may invert the
+            # channels — confirm the expected input source.
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Resize to model input size
+        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
+
+        # Normalize the image
+        image = image.astype(np.float32) / 255.0
+
+        # Add batch dimension
+        image = np.expand_dims(image, axis=0)
+
+        return image
+
+    def postprocess_prediction(self, prediction):
+        """Postprocess the model prediction"""
+        # Remove batch dimension
+        prediction = prediction[0]
+
+        # Apply threshold to get binary mask
+        threshold = 0.5
+        binary_mask = (prediction > threshold).astype(np.uint8) * 255
+
+        return binary_mask
+
+    def segment_wound(self, input_image):
+        """Main function to segment wound from uploaded image"""
+        # Returns (mask, message); mask is None on any failure.
+        if self.model is None:
+            return None, "Error: Segmentation model not loaded. Please check the model files."
+
+        if input_image is None:
+            return None, "Please upload an image."
+
+        try:
+            # Preprocess the image
+            processed_image = self.preprocess_image(input_image)
+
+            if processed_image is None:
+                return None, "Error processing image."
+
+            # Make prediction
+            prediction = self.model.predict(processed_image, verbose=0)
+
+            # Postprocess the prediction
+            segmented_mask = self.postprocess_prediction(prediction)
+
+            return segmented_mask, "Segmentation completed successfully!"
+
+        except Exception as e:
+            return None, f"Error during segmentation: {str(e)}"
+
+# Initialize the segmentation model
+segmentation_model = WoundSegmentationModel()
+
+# --- PyTorch: Set Device and Load Depth Model ---
+map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
+print(f"Using PyTorch device: {map_device}")
+
+# Encoder configurations for the Depth Anything V2 family; only 'vitl' is
+# instantiated below (its checkpoint was downloaded from Google Drive above).
+model_configs = {
+    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+}
+encoder = 'vitl'
+depth_model = DepthAnythingV2(**model_configs[encoder])
+state_dict = torch.load(
+    f'checkpoints/depth_anything_v2_{encoder}.pth',
+    map_location=map_device
+)
+depth_model.load_state_dict(state_dict)
+# Inference-only: move to the selected device and disable training-mode layers.
+depth_model = depth_model.to(map_device).eval()
+
+
+# --- Custom CSS for unified dark theme ---
+# Passed to gr.Blocks(css=...): styles the Gradio container, buttons, tabs,
+# fixed-height "wound cards", and a CSS loading spinner.
+css = """
+.gradio-container {
+    font-family: 'Segoe UI', sans-serif;
+    background-color: #121212;
+    color: #ffffff;
+    padding: 20px;
+}
+.gr-button {
+    background-color: #2c3e50;
+    color: white;
+    border-radius: 10px;
+}
+.gr-button:hover {
+    background-color: #34495e;
+}
+.gr-html, .gr-html div {
+    white-space: normal !important;
+    overflow: visible !important;
+    text-overflow: unset !important;
+    word-break: break-word !important;
+}
+#img-display-container {
+    max-height: 100vh;
+}
+#img-display-input {
+    max-height: 80vh;
+}
+#img-display-output {
+    max-height: 80vh;
+}
+#download {
+    height: 62px;
+}
+h1 {
+    text-align: center;
+    font-size: 3rem;
+    font-weight: bold;
+    margin: 2rem 0;
+    color: #ffffff;
+}
+h2 {
+    color: #ffffff;
+    text-align: center;
+    margin: 1rem 0;
+}
+.gr-tabs {
+    background-color: #1e1e1e;
+    border-radius: 10px;
+    padding: 10px;
+}
+.gr-tab-nav {
+    background-color: #2c3e50;
+    border-radius: 8px;
+}
+.gr-tab-nav button {
+    color: #ffffff !important;
+}
+.gr-tab-nav button.selected {
+    background-color: #34495e !important;
+}
+/* Card styling for consistent heights */
+.wound-card {
+    min-height: 200px !important;
+    display: flex !important;
+    flex-direction: column !important;
+    justify-content: space-between !important;
+}
+.wound-card-content {
+    flex-grow: 1 !important;
+    display: flex !important;
+    flex-direction: column !important;
+    justify-content: center !important;
+}
+/* Loading animation */
+.loading-spinner {
+    display: inline-block;
+    width: 20px;
+    height: 20px;
+    border: 3px solid #f3f3f3;
+    border-top: 3px solid #3498db;
+    border-radius: 50%;
+    animation: spin 1s linear infinite;
+}
+@keyframes spin {
+    0% { transform: rotate(0deg); }
+    100% { transform: rotate(360deg); }
+}
+"""
+
+
+
+
+
+# --- Enhanced Wound Severity Estimation Functions ---
+
+def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
+ """
+ Enhanced depth analysis with proper calibration and medical standards
+ Based on wound depth classification standards:
+ - Superficial: 0-2mm (epidermis only)
+ - Partial thickness: 2-4mm (epidermis + partial dermis)
+ - Full thickness: 4-6mm (epidermis + full dermis)
+ - Deep: >6mm (involving subcutaneous tissue)
+ """
+ # Convert pixel spacing to mm
+ pixel_spacing_mm = float(pixel_spacing_mm)
+
+ # Calculate pixel area in cm²
+ pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2
+
+ # Extract wound region (binary mask)
+ wound_mask = (mask > 127).astype(np.uint8)
+
+ # Apply morphological operations to clean the mask
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+ wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)
+
+ # Get depth values only for wound region
+ wound_depths = depth_map[wound_mask > 0]
+
+ if len(wound_depths) == 0:
+ return {
+ 'total_area_cm2': 0,
+ 'superficial_area_cm2': 0,
+ 'partial_thickness_area_cm2': 0,
+ 'full_thickness_area_cm2': 0,
+ 'deep_area_cm2': 0,
+ 'mean_depth_mm': 0,
+ 'max_depth_mm': 0,
+ 'depth_std_mm': 0,
+ 'deep_ratio': 0,
+ 'wound_volume_cm3': 0,
+ 'depth_percentiles': {'25': 0, '50': 0, '75': 0}
+ }
+
+ # Normalize depth relative to nearest point in wound area
+ normalized_depth_map, nearest_point_coords, max_relative_depth = normalize_depth_relative_to_nearest_point(depth_map, wound_mask)
+
+ # Calibrate the normalized depth map for more accurate measurements
+ calibrated_depth_map = calibrate_depth_map(normalized_depth_map, reference_depth_mm=depth_calibration_mm)
+
+ # Get calibrated depth values for wound region
+ wound_depths_mm = calibrated_depth_map[wound_mask > 0]
+
+ # Medical depth classification
+ superficial_mask = wound_depths_mm < 2.0
+ partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
+ full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
+ deep_mask = wound_depths_mm >= 6.0
+
+ # Calculate areas
+ total_pixels = np.sum(wound_mask > 0)
+ total_area_cm2 = total_pixels * pixel_area_cm2
+
+ superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
+ partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
+ full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
+ deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2
+
+ # Calculate depth statistics
+ mean_depth_mm = np.mean(wound_depths_mm)
+ max_depth_mm = np.max(wound_depths_mm)
+ depth_std_mm = np.std(wound_depths_mm)
+
+ # Calculate depth percentiles
+ depth_percentiles = {
+ '25': np.percentile(wound_depths_mm, 25),
+ '50': np.percentile(wound_depths_mm, 50),
+ '75': np.percentile(wound_depths_mm, 75)
+ }
+
+ # Calculate depth distribution statistics
+ depth_distribution = {
+ 'shallow_ratio': np.sum(wound_depths_mm < 2.0) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0,
+ 'moderate_ratio': np.sum((wound_depths_mm >= 2.0) & (wound_depths_mm < 5.0)) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0,
+ 'deep_ratio': np.sum(wound_depths_mm >= 5.0) / len(wound_depths_mm) if len(wound_depths_mm) > 0 else 0
+ }
+
+ # Calculate wound volume (approximate)
+ # Volume = area * average depth
+ wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)
+
+ # Deep tissue ratio
+ deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0
+
+ # Calculate analysis quality metrics
+ wound_pixel_count = len(wound_depths_mm)
+ analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"
+
+ # Calculate depth consistency (lower std dev = more consistent)
+ depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"
+
+ return {
+ 'total_area_cm2': total_area_cm2,
+ 'superficial_area_cm2': superficial_area_cm2,
+ 'partial_thickness_area_cm2': partial_thickness_area_cm2,
+ 'full_thickness_area_cm2': full_thickness_area_cm2,
+ 'deep_area_cm2': deep_area_cm2,
+ 'mean_depth_mm': mean_depth_mm,
+ 'max_depth_mm': max_depth_mm,
+ 'depth_std_mm': depth_std_mm,
+ 'deep_ratio': deep_ratio,
+ 'wound_volume_cm3': wound_volume_cm3,
+ 'depth_percentiles': depth_percentiles,
+ 'depth_distribution': depth_distribution,
+ 'analysis_quality': analysis_quality,
+ 'depth_consistency': depth_consistency,
+ 'wound_pixel_count': wound_pixel_count,
+ 'nearest_point_coords': nearest_point_coords,
+ 'max_relative_depth': max_relative_depth,
+ 'normalized_depth_map': normalized_depth_map
+ }
+
+def classify_wound_severity_by_enhanced_metrics(depth_stats):
+    """
+    Enhanced wound severity classification based on medical standards
+    Uses multiple criteria: depth, area, volume, and tissue involvement
+
+    Args:
+        depth_stats: dict produced by compute_enhanced_depth_statistics
+
+    Returns:
+        str: one of "Superficial", "Mild", "Moderate", "Severe",
+        "Very Severe", or "Unknown" (empty wound).
+    """
+    if depth_stats['total_area_cm2'] == 0:
+        return "Unknown"
+
+    # Extract key metrics
+    # NOTE(review): deep_area and full_thickness_area are extracted but never
+    # used by the scoring below (deep involvement is scored via deep_ratio).
+    total_area = depth_stats['total_area_cm2']
+    deep_area = depth_stats['deep_area_cm2']
+    full_thickness_area = depth_stats['full_thickness_area_cm2']
+    mean_depth = depth_stats['mean_depth_mm']
+    max_depth = depth_stats['max_depth_mm']
+    wound_volume = depth_stats['wound_volume_cm3']
+    deep_ratio = depth_stats['deep_ratio']
+
+    # Medical severity classification criteria: additive score across five
+    # independent criteria, then bucketed into a label at the end.
+    severity_score = 0
+
+    # Criterion 1: Maximum depth
+    if max_depth >= 10.0:
+        severity_score += 3  # Very severe
+    elif max_depth >= 6.0:
+        severity_score += 2  # Severe
+    elif max_depth >= 4.0:
+        severity_score += 1  # Moderate
+
+    # Criterion 2: Mean depth
+    if mean_depth >= 5.0:
+        severity_score += 2
+    elif mean_depth >= 3.0:
+        severity_score += 1
+
+    # Criterion 3: Deep tissue involvement ratio
+    if deep_ratio >= 0.5:
+        severity_score += 3  # More than 50% deep tissue
+    elif deep_ratio >= 0.25:
+        severity_score += 2  # 25-50% deep tissue
+    elif deep_ratio >= 0.1:
+        severity_score += 1  # 10-25% deep tissue
+
+    # Criterion 4: Total wound area
+    if total_area >= 10.0:
+        severity_score += 2  # Large wound (>10 cm²)
+    elif total_area >= 5.0:
+        severity_score += 1  # Medium wound (5-10 cm²)
+
+    # Criterion 5: Wound volume
+    if wound_volume >= 5.0:
+        severity_score += 2  # High volume
+    elif wound_volume >= 2.0:
+        severity_score += 1  # Medium volume
+
+    # Determine severity based on total score (max attainable = 12)
+    if severity_score >= 8:
+        return "Very Severe"
+    elif severity_score >= 6:
+        return "Severe"
+    elif severity_score >= 4:
+        return "Moderate"
+    elif severity_score >= 2:
+        return "Mild"
+    else:
+        return "Superficial"
+
+
+
+
+
+def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
+    """Enhanced wound severity analysis based on depth measurements.
+
+    Args:
+        image: original wound image
+        depth_map: relative depth map from the depth model
+        wound_mask: segmentation mask (resized to the depth map if needed)
+        pixel_spacing_mm: assumed physical pixel size, in mm
+        depth_calibration_mm: assumed depth of the deepest point, in mm
+
+    Returns:
+        str: HTML report combining depth statistics and the Gemini assessment.
+    """
+    if image is None or depth_map is None or wound_mask is None:
+        return "❌ Please upload image, depth map, and wound mask."
+
+    # Convert wound mask to grayscale if needed
+    if len(wound_mask.shape) == 3:
+        wound_mask = np.mean(wound_mask, axis=2)
+
+    # Ensure depth map and mask have same dimensions
+    if depth_map.shape[:2] != wound_mask.shape[:2]:
+        # Resize mask to match depth map
+        from PIL import Image
+        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
+        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
+        wound_mask = np.array(mask_pil)
+
+    # Compute enhanced statistics with relative depth normalization
+    stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)
+
+    # Get severity based on enhanced metrics
+    # NOTE(review): severity_level/severity_description/severity_color and
+    # depth_visualization are computed below but not referenced by the report
+    # template as it stands — confirm whether the garbled HTML (see below)
+    # originally interpolated them.
+    severity_level = classify_wound_severity_by_enhanced_metrics(stats)
+    severity_description = get_enhanced_severity_description(severity_level)
+
+    # Get Gemini AI analysis based on depth data
+    gemini_analysis = analyze_wound_depth_with_gemini(image, depth_map, stats)
+
+    # Format Gemini analysis for display
+    formatted_gemini_analysis = format_gemini_depth_analysis(gemini_analysis)
+
+    # Create depth analysis visualization
+    depth_visualization = create_depth_analysis_visualization(
+        stats['normalized_depth_map'], wound_mask,
+        stats['nearest_point_coords'], stats['max_relative_depth']
+    )
+
+    # Enhanced severity color coding
+    severity_color = {
+        "Superficial": "#4CAF50",    # Green
+        "Mild": "#8BC34A",           # Light Green
+        "Moderate": "#FF9800",       # Orange
+        "Severe": "#F44336",         # Red
+        "Very Severe": "#9C27B0"     # Purple
+    }.get(severity_level, "#9E9E9E")  # Gray for unknown
+
+    # Create comprehensive medical report
+    # NOTE(review): the HTML template below is garbled — markup tags were
+    # stripped during extraction and several label lines contain mojibake
+    # ('�'); restore the original markup/emoji before shipping. Left
+    # byte-identical here.
+    report = f"""
+
+
+    🩹 Enhanced Wound Severity Analysis
+
+
+
+
+    📊 Depth & Quality Analysis
+
+
+
+
 � Basic Measurements
+
 �📏 Mean Relative Depth: {stats['mean_depth_mm']:.1f} mm
+
 📐 Max Relative Depth: {stats['max_depth_mm']:.1f} mm
+
 📊 Depth Std Dev: {stats['depth_std_mm']:.1f} mm
+
 📦 Wound Volume: {stats['wound_volume_cm3']:.2f} cm³
+
 🔥 Deep Tissue Ratio: {stats['deep_ratio']*100:.1f}%
+
+
+
 📈 Statistical Analysis
+
 � 25th Percentile: {stats['depth_percentiles']['25']:.1f} mm
+
 📊 Median (50th): {stats['depth_percentiles']['50']:.1f} mm
+
 📊 75th Percentile: {stats['depth_percentiles']['75']:.1f} mm
+
 📊 Shallow Areas: {stats['depth_distribution']['shallow_ratio']*100:.1f}%
+
 📊 Moderate Areas: {stats['depth_distribution']['moderate_ratio']*100:.1f}%
+
+
+
 🔍 Quality Metrics
+
 🔍 Analysis Quality: {stats['analysis_quality']}
+
 📏 Depth Consistency: {stats['depth_consistency']}
+
 📊 Data Points: {stats['wound_pixel_count']:,}
+
 📊 Deep Areas: {stats['depth_distribution']['deep_ratio']*100:.1f}%
+
 🎯 Reference Point: Nearest to camera
+
+
+
+
+
+
+    📊 Medical Assessment Based on Depth Analysis
+
+    {formatted_gemini_analysis}
+
+
+    """
+
+    return report
+
+def normalize_depth_relative_to_nearest_point(depth_map, wound_mask):
+    """
+    Normalize depth map relative to the nearest point in the wound area
+    This assumes a top-down camera perspective where the closest point to camera = 0 depth
+
+    Args:
+        depth_map: Raw depth map
+        wound_mask: Binary mask of wound region
+
+    Returns:
+        normalized_depth: Depth values relative to nearest point (0 = nearest, positive = deeper)
+        nearest_point_coords: Coordinates of the nearest point
+        max_relative_depth: Maximum relative depth in the wound
+    """
+    if depth_map is None or wound_mask is None:
+        return depth_map, None, 0
+
+    # Convert mask to binary
+    binary_mask = (wound_mask > 127).astype(np.uint8)
+
+    # Find wound region coordinates
+    wound_coords = np.where(binary_mask > 0)
+
+    if len(wound_coords[0]) == 0:
+        # Empty mask: nothing to normalize against.
+        return depth_map, None, 0
+
+    # Get depth values only for wound region
+    wound_depths = depth_map[wound_coords]
+
+    # Find the nearest point (minimum depth value in wound region)
+    nearest_depth = np.min(wound_depths)
+    nearest_indices = np.where(wound_depths == nearest_depth)
+
+    # Get coordinates of the nearest point(s)
+    # Ties are broken by taking the first matching pixel in scan order.
+    nearest_point_coords = (wound_coords[0][nearest_indices[0][0]],
+                           wound_coords[1][nearest_indices[0][0]])
+
+    # Create normalized depth map (relative to nearest point)
+    # Note: the shift is applied to the WHOLE map, not just the wound region.
+    normalized_depth = depth_map.copy()
+    normalized_depth = normalized_depth - nearest_depth
+
+    # Ensure all values are non-negative (nearest point = 0, others = positive)
+    normalized_depth = np.maximum(normalized_depth, 0)
+
+    # Calculate maximum relative depth in wound region
+    wound_normalized_depths = normalized_depth[wound_coords]
+    max_relative_depth = np.max(wound_normalized_depths)
+
+    return normalized_depth, nearest_point_coords, max_relative_depth
+
+def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
+    """
+    Calibrate depth map to real-world measurements using reference depth
+    This helps convert normalized depth values to actual millimeters
+
+    Args:
+        depth_map: relative depth map (numpy array), or None
+        reference_depth_mm: real-world depth assumed for the map's maximum
+
+    Returns:
+        Linearly min-max rescaled map in [0, reference_depth_mm]; the input
+        unchanged when it is None or constant (avoids division by zero).
+    """
+    if depth_map is None:
+        return depth_map
+
+    # Find the maximum depth value in the depth map
+    max_depth_value = np.max(depth_map)
+    min_depth_value = np.min(depth_map)
+
+    # Constant map: rescaling is undefined, return as-is.
+    if max_depth_value == min_depth_value:
+        return depth_map
+
+    # Apply calibration to convert to millimeters
+    # Assuming the maximum depth in the map corresponds to reference_depth_mm
+    calibrated_depth = (depth_map - min_depth_value) / (max_depth_value - min_depth_value) * reference_depth_mm
+
+    return calibrated_depth
+
+def create_depth_analysis_visualization(depth_map, wound_mask, nearest_point_coords, max_relative_depth):
+    """
+    Create a visualization showing the depth analysis with nearest point and deepest point highlighted
+
+    Returns an RGB uint8 image: colormapped depth with the reference (nearest)
+    point circled in red, the deepest wound pixel circled in blue, and the
+    wound boundary contoured in green; None when inputs are missing.
+    """
+    if depth_map is None or wound_mask is None:
+        return None
+
+    # Create a copy of the depth map for visualization
+    vis_depth = depth_map.copy()
+
+    # Apply colormap for better visualization
+    # NOTE(review): a constant depth map makes max == min here and divides by
+    # zero — confirm upstream guarantees non-constant input or add a guard.
+    normalized_depth = (vis_depth - np.min(vis_depth)) / (np.max(vis_depth) - np.min(vis_depth))
+    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(normalized_depth)[:, :, :3] * 255).astype(np.uint8)
+
+    # Convert to RGB if grayscale
+    if len(colored_depth.shape) == 3 and colored_depth.shape[2] == 1:
+        colored_depth = cv2.cvtColor(colored_depth, cv2.COLOR_GRAY2RGB)
+
+    # Highlight the nearest point (reference point) with a red circle
+    # (color tuples are RGB here since the buffer comes from matplotlib,
+    # not cv2.imread)
+    if nearest_point_coords is not None:
+        y, x = nearest_point_coords
+        cv2.circle(colored_depth, (x, y), 10, (255, 0, 0), 2)  # Red circle for nearest point
+        cv2.putText(colored_depth, "REF", (x+15, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+
+    # Find and highlight the deepest point
+    binary_mask = (wound_mask > 127).astype(np.uint8)
+    wound_coords = np.where(binary_mask > 0)
+
+    if len(wound_coords[0]) > 0:
+        # Get depth values for wound region
+        wound_depths = vis_depth[wound_coords]
+        max_depth_idx = np.argmax(wound_depths)
+        deepest_point_coords = (wound_coords[0][max_depth_idx], wound_coords[1][max_depth_idx])
+
+        # Highlight the deepest point with a blue circle
+        y, x = deepest_point_coords
+        cv2.circle(colored_depth, (x, y), 12, (0, 0, 255), 3)  # Blue circle for deepest point
+        cv2.putText(colored_depth, "DEEP", (x+15, y+5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
+
+    # Overlay wound mask outline
+    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    cv2.drawContours(colored_depth, contours, -1, (0, 255, 0), 2)  # Green outline for wound boundary
+
+    return colored_depth
+
def get_enhanced_severity_description(severity):
    """Return a one-sentence clinical summary for a severity grade.

    Args:
        severity: One of the known grade labels (e.g. "Mild", "Severe").

    Returns:
        The matching description, or a fallback string for unknown grades.
    """
    fallback = "Severity assessment unavailable."
    summaries = dict([
        ("Superficial", "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care."),
        ("Mild", "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care."),
        ("Moderate", "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques."),
        ("Severe", "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention."),
        ("Very Severe", "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care."),
        ("Unknown", "Unable to determine severity due to insufficient data or poor image quality."),
    ])
    return summaries.get(severity, fallback)
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Build a binary test mask containing a single filled circle.

    Args:
        image_shape: Shape tuple of the target image (h, w, ...).
        center: (x, y) circle center; defaults to the image midpoint.
        radius: Circle radius in pixels.

    Returns:
        uint8 mask with 255 inside the circle and 0 elsewhere.
    """
    height, width = image_shape[0], image_shape[1]
    cx, cy = center if center is not None else (width // 2, height // 2)

    rows, cols = np.ogrid[:height, :width]
    # Euclidean distance from each pixel to the circle center.
    inside = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2) <= radius
    return np.where(inside, 255, 0).astype(np.uint8)
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Generate an irregular synthetic wound mask for testing.

    Args:
        image_shape: Target image shape (h, w, ...).
        method: 'elliptical' for a noisy ellipse, 'irregular' for a circle
            with three lobed extensions; any other value yields an empty mask.

    Returns:
        uint8 mask (255 = wound) smoothed with a morphological close.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)

    if method == 'elliptical':
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        inside_ellipse = ((x - center[0]) ** 2 / (radius_x ** 2) +
                          (y - center[1]) ** 2 / (radius_y ** 2)) <= 1

        # Sprinkle random pixels (~20% of the frame) to break up the outline.
        # NOTE(review): the noise is OR-ed over the WHOLE frame, not just near
        # the ellipse, so stray speckles can appear anywhere — confirm intended.
        speckle = np.random.random((h, w)) > 0.8
        mask = (inside_ellipse | speckle).astype(np.uint8) * 255

    elif method == 'irregular':
        center = (w // 2, h // 2)
        radius = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        core = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2) <= radius

        # Attach three smaller lobes evenly spaced around the core circle.
        lobes = np.zeros_like(core)
        for k in range(3):
            theta = k * 2 * np.pi / 3
            lobe_x = int(center[0] + radius * 0.8 * np.cos(theta))
            lobe_y = int(center[1] + radius * 0.8 * np.sin(theta))
            lobes = lobes | (np.sqrt((x - lobe_x) ** 2 + (y - lobe_y) ** 2) <= radius // 3)

        mask = (core | lobes).astype(np.uint8) * 255

    # Smooth jagged edges and close small gaps.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
+
+# --- Depth Estimation Functions ---
+
def predict_depth(image):
    """Run monocular depth inference on *image* via the module-level depth model.

    Args:
        image: Image array in the channel order the model expects
            (callers in this file pass BGR via ``image[:, :, ::-1]``).

    Returns:
        The raw depth map produced by ``depth_model.infer_image``.
    """
    return depth_model.infer_image(image)
+
def calculate_max_points(image):
    """Derive the 3D point budget for an image: three times its pixel count,
    clamped to the range [1000, 300000].

    Args:
        image: Input image array, or None.

    Returns:
        Point budget as an int; 10000 when no image is given.
    """
    if image is None:
        return 10000  # Default value when nothing is uploaded yet

    height, width = image.shape[:2]
    budget = 3 * height * width
    # Clamp to keep the slider usable and the point cloud tractable.
    return min(max(budget, 1000), 300000)
+
def update_slider_on_image_upload(image):
    """Rebuild the point-count slider to match a newly uploaded image.

    The maximum is set to the image's point budget and the default value to
    10% of that budget, capped at 10000.
    """
    ceiling = calculate_max_points(image)
    start_value = min(10000, ceiling // 10)
    return gr.Slider(
        minimum=1000,
        maximum=ceiling,
        value=start_value,
        step=1000,
        label=f"Number of 3D points (max: {ceiling:,})",
    )
+
+
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Create a colored Open3D point cloud by back-projecting a depth map.

    Uses the pinhole camera model: each pixel (u, v) with depth d maps to
    ((u - cx) / fx * d, (v - cy) / fy * d, d), with the principal point
    assumed at the image center.

    Args:
        image: RGB image aligned with the depth map (H, W, 3), 0-255 values.
        depth_map: 2-D depth array (H, W).
        focal_length_x: Camera focal length in pixels along x.
        focal_length_y: Camera focal length in pixels along y.
        max_points: Target point budget used to pick the sampling stride.

    Returns:
        An ``o3d.geometry.PointCloud`` with per-point RGB colors in [0, 1].
    """
    h, w = depth_map.shape

    # Use smaller step for higher detail (reduced downsampling).
    # NOTE(review): the 0.5 factor halves the stride, so the cloud can hold
    # roughly 4x max_points — confirm this oversampling is intended.
    step = max(1, int(np.sqrt(h * w / max_points) * 0.5))  # Reduce step size for more detail

    # Pixel grid at the chosen stride.
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Normalize pixel coordinates by the focal length (pinhole model).
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y

    # Depth sampled at the same stride keeps arrays aligned with the grid.
    depth_values = depth_map[::step, ::step]

    # Back-project: (x_cam * depth, y_cam * depth, depth).
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten into an (N, 3) point array.
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)

    # Per-point colors taken from the image at the same stride, scaled to [0, 1].
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0

    # Assemble the Open3D point cloud.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    return pcd
+
+
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Convert a point cloud to a triangle mesh via Poisson reconstruction.

    Args:
        pcd: Open3D point cloud. Its normals are estimated and re-oriented
            in place as a side effect.

    Returns:
        An ``o3d.geometry.TriangleMesh``; the per-vertex densities returned
        by Poisson reconstruction are discarded (no low-density filtering).
    """
    # Estimate and orient normals with high precision
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # Create surface mesh with maximum detail (depth=12 for very high resolution).
    # NOTE(review): octree depth 12 is memory- and CPU-heavy for large clouds.
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)

    # Return mesh without filtering low-density vertices
    return mesh
+
+
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Build an interactive Plotly 3D scatter of the back-projected scene.

    Pixels are back-projected with a pinhole model (fixed focal length of
    470.4 px, principal point at the image center) and colored with the
    corresponding image pixels.

    Args:
        image: RGB image aligned with the depth map (H, W, 3).
        depth_map: 2-D depth array (H, W).
        max_points: Approximate cap on rendered points; sets the stride.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one Scatter3d trace.
    """
    h, w = depth_map.shape

    # Downsample to avoid too many points for performance.
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Pixel grid at the chosen stride.
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Normalize pixel coordinates by the focal length (pinhole model).
    focal_length = 470.4  # Default focal length in pixels
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    # Depth sampled at the same stride keeps arrays aligned.
    depth_values = depth_map[::step, ::step]

    # Back-project: (x_cam * depth, y_cam * depth, depth).
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten for the scatter trace.
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Per-point colors taken from the image at the same stride.
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        # Fixed: the template was a broken multi-line string literal; Plotly
        # expects <br> separators and <extra></extra> to hide the trace box.
        hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Full depth pipeline behind the Gradio "Compute Depth" button.

    Predicts a depth map, writes raw/grayscale PNGs and a reconstructed
    mesh to temp files, and builds the interactive 3D plot.

    Args:
        image: Uploaded RGB image as a numpy array.
        num_points: Point budget for the cloud / 3D plot.
        focal_x: Camera focal length in pixels (x).
        focal_y: Camera focal length in pixels (y).

    Returns:
        [(original, colored_depth), gray_png_path, raw_png_path,
         mesh_ply_path, plotly_figure] — matching the Gradio outputs list.
    """
    original_image = image.copy()

    # NOTE(review): h and w are computed but never used below.
    h, w = image.shape[:2]

    # Predict depth using the model
    depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed

    # Save raw 16-bit depth
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to 0-255 and colorize for display
    norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Create point cloud.
    # NOTE(review): this uses the 0-255 normalized depth, not the raw metric
    # depth, so the cloud's z-scale is arbitrary — confirm intended.
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)

    # Reconstruct mesh from point cloud
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    # Save mesh with faces as .ply (delete=False: files persist for download).
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Create enhanced 3D scatter plot visualization
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
+
+# --- Actual Wound Segmentation Functions ---
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate a wound mask using the deep learning model.

    Args:
        image: Input image (numpy array), or None.
        method: Segmentation method; only 'deep_learning' is implemented,
            and any other value falls back to the same model.

    Returns:
        mask: Binary wound mask, or None when no image is given.
    """
    if image is None:
        return None

    # Fixed: the 'deep_learning' branch and its fallback were byte-identical,
    # so the duplicated if/else is collapsed into a single invocation.
    mask, _ = segmentation_model.segment_wound(image)
    return mask
+
def post_process_wound_mask(mask, min_area=100):
    """Clean up a raw segmentation mask.

    Applies close/open morphology, drops connected components smaller than
    ``min_area`` pixels, and fills remaining holes.

    Args:
        mask: Raw binary mask (any integer dtype), or None.
        min_area: Minimum contour area in pixels to keep.

    Returns:
        Cleaned uint8 mask, or None when the input is None.
    """
    if mask is None:
        return None

    # cv2 morphology requires uint8 input.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Close small gaps first, then open to strip isolated specks.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        mask = cv2.morphologyEx(mask, operation, kernel)

    # Keep only sufficiently large connected regions.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cleaned = np.zeros_like(mask)
    for outline in contours:
        if cv2.contourArea(outline) >= min_area:
            cv2.fillPoly(cleaned, [outline], 255)

    # Final close to fill any holes left inside the kept regions.
    return cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel)
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Run the full automatic pipeline: segment, clean the mask, grade severity.

    Args:
        image: Original wound image.
        depth_map: Matching depth map.
        pixel_spacing_mm: Physical size of one pixel in millimeters.
        segmentation_method: Forwarded to ``create_automatic_wound_mask``.

    Returns:
        The severity report from ``analyze_wound_severity``, or an error
        message string when any stage fails.
    """
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    raw_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if raw_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    cleaned_mask = post_process_wound_mask(raw_mask, min_area=500)
    wound_found = cleaned_mask is not None and np.sum(cleaned_mask > 0) > 0
    if not wound_found:
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, cleaned_mask, pixel_spacing_mm)
+
+# --- Main Gradio Interface ---
+with gr.Blocks(css=css, title="Wound Analysis System") as demo:
+ gr.HTML("Wound Analysis System
")
+ #gr.Markdown("### Complete workflow: Classification → Depth Estimation → Wound Severity Analysis")
+
+ # Shared states
+ shared_image = gr.State()
+ shared_depth_map = gr.State()
+
+ with gr.Tabs():
+
+ # Tab 1: Wound Classification
+ with gr.Tab("1. 🔍 Wound Classification & Initial Analysis"):
+ gr.Markdown("### Step 1: Classify wound type and get initial AI analysis")
+ #gr.Markdown("Upload an image to identify the wound type and receive detailed analysis from our Vision AI.")
+
+
+ with gr.Row():
+ # Left Column - Image Upload
+ with gr.Column(scale=1):
+ gr.HTML('Upload Wound Image
')
+ classification_image_input = gr.Image(
+ label="",
+ type="pil",
+ height=400
+ )
+ # Place Clear and Analyse buttons side by side
+ with gr.Row():
+ classify_clear_btn = gr.Button(
+ "Clear",
+ variant="secondary",
+ size="lg",
+ scale=1
+ )
+ analyse_btn = gr.Button(
+ "Analyse",
+ variant="primary",
+ size="lg",
+ scale=1
+ )
+ # Right Column - Classification Results
+ with gr.Column(scale=1):
+ gr.HTML('Classification Results
')
+ classification_output = gr.Label(
+ label="",
+ num_top_classes=5,
+ show_label=False
+ )
+
+ # Second Row - Full Width AI Analysis
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.HTML('Wound Visual Analysis
')
+ gemini_output = gr.HTML(
+ value="""
+
+ Upload an image to get AI-powered wound analysis
+
+ """
+ )
+
+ # Event handlers for classification tab
+ classify_clear_btn.click(
+ fn=lambda: (None, None, """
+
+ Upload an image to get AI-powered wound analysis
+
+ """),
+ inputs=None,
+ outputs=[classification_image_input, classification_output, gemini_output]
+ )
+
+ # Only run classification on image upload
+ def classify_and_store(image):
+ result = classify_wound(image)
+ return result
+
+ classification_image_input.change(
+ fn=classify_and_store,
+ inputs=classification_image_input,
+ outputs=classification_output
+ )
+
+ # Store image in shared state for next tabs
        def store_shared_image(image):
            """Identity pass-through that copies the uploaded image into shared state."""
            return image
+
+ classification_image_input.change(
+ fn=store_shared_image,
+ inputs=classification_image_input,
+ outputs=shared_image
+ )
+
+ # Run Gemini analysis only when Analyse button is clicked
+ def run_gemini_on_click(image, classification):
+ # Get top label
+ if isinstance(classification, dict) and classification:
+ top_label = max(classification.items(), key=lambda x: x[1])[0]
+ else:
+ top_label = "Unknown"
+ gemini_analysis = analyze_wound_with_gemini(image, top_label)
+ formatted_analysis = format_gemini_analysis(gemini_analysis)
+ return formatted_analysis
+
+ analyse_btn.click(
+ fn=run_gemini_on_click,
+ inputs=[classification_image_input, classification_output],
+ outputs=gemini_output
+ )
+
+ # Tab 2: Depth Estimation
+ with gr.Tab("2. 📏 Depth Estimation & 3D Visualization"):
+ gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
+ gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
+
+ with gr.Row():
+ load_from_classification_btn = gr.Button("🔄 Load Image from Classification Tab", variant="secondary")
+
+ with gr.Row():
+ depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
+
+ with gr.Row():
+ depth_submit = gr.Button(value="Compute Depth", variant="primary")
+
+ points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
+ label="Number of 3D points (upload image to update max)")
+
+ with gr.Row():
+ focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
+ label="Focal Length X (pixels)")
+ focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
+ label="Focal Length Y (pixels)")
+
+ # Reorganized layout: 2 columns - 3D visualization on left, file outputs stacked on right
+ with gr.Row():
+ with gr.Column(scale=2):
+ # 3D Visualization
+ gr.Markdown("### 3D Point Cloud Visualization")
+ gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
+ depth_3d_plot = gr.Plot(label="3D Point Cloud")
+
+ with gr.Column(scale=1):
+ gr.Markdown("### Download Files")
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
+ point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
+
+
+
+ # Tab 3: Wound Severity Analysis
+ with gr.Tab("3. 🩹 Wound Severity Analysis"):
+ gr.Markdown("### Step 3: Analyze wound severity using depth maps")
+ gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
+
+ with gr.Row():
+ # Load depth map from previous tab
+ load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
+
+ with gr.Row():
+ severity_input_image = gr.Image(label="Original Image", type='numpy')
+ severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
+
+ with gr.Row():
+ wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
+
+ with gr.Row():
+ severity_output = gr.HTML(
+ label="🤖 AI-Powered Medical Assessment",
+ value="""
+
+
+ 🩹 Wound Severity Analysis
+
+
+ ⏳ Waiting for Input...
+
+
+ Please upload an image and depth map, then click "🤖 Analyze Severity with Auto-Generated Mask" to begin AI-powered medical assessment.
+
+
+ """
+ )
+
+ gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")
+
+ with gr.Row():
+ auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
+ pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
+ label="Pixel Spacing (mm/pixel)")
+ depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
+ label="Depth Calibration (mm)",
+ info="Adjust based on expected maximum wound depth")
+
+ #gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
+ #gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm, for deep wounds use 15-30mm.")
+
+ #gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
+
+ # Update slider when image is uploaded
+ depth_input_image.change(
+ fn=update_slider_on_image_upload,
+ inputs=[depth_input_image],
+ outputs=[points_slider]
+ )
+
+ # Modified depth submit function to store depth map
        def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
            """Wrap on_depth_submit and append a uint8 depth map for shared state.

            NOTE(review): this runs depth inference a second time on the same
            image (on_depth_submit already predicted it) — consider reusing
            that result to halve inference cost.
            """
            results = on_depth_submit(image, num_points, focal_x, focal_y)
            # Extract depth map from results for severity analysis
            depth_map = None
            if image is not None:
                depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
                # Normalize depth to 0-255 for severity analysis
                norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
                depth_map = norm_depth.astype(np.uint8)
            return results + [depth_map]
+
+ depth_submit.click(on_depth_submit_with_state,
+ inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
+ outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, shared_depth_map])
+
+ # Function to load image from classification to depth tab
+ def load_image_from_classification(shared_img):
+ if shared_img is None:
+ return None, "❌ No image available from classification tab. Please upload an image in Tab 1 first."
+
+ # Convert PIL image to numpy array for depth estimation
+ if hasattr(shared_img, 'convert'):
+ # It's a PIL image, convert to numpy
+ img_array = np.array(shared_img)
+ return img_array, "✅ Image loaded from classification tab successfully!"
+ else:
+ # Already numpy array
+ return shared_img, "✅ Image loaded from classification tab successfully!"
+
+ # Connect the load button
+ load_from_classification_btn.click(
+ fn=load_image_from_classification,
+ inputs=shared_image,
+ outputs=[depth_input_image, gr.HTML()]
+ )
+
+ # Load depth map to severity tab and auto-generate mask
        def load_depth_to_severity(depth_map, original_image):
            """Copy the Tab-2 depth map into Tab 3 and auto-segment the wound.

            Args:
                depth_map: Shared depth map from Tab 2 (may be None).
                original_image: Image currently loaded in Tab 2.

            Returns:
                (depth_map, image, mask-or-None, status message) for the
                severity-tab widgets.
            """
            if depth_map is None:
                return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."

            # Auto-generate wound mask using segmentation model
            if original_image is not None:
                auto_mask, _ = segmentation_model.segment_wound(original_image)
                if auto_mask is not None:
                    # Post-process the mask (drop components under 500 px).
                    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
                    if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                        return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
                    else:
                        return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image."
                else:
                    return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image."
            else:
                return depth_map, original_image, None, "✅ Depth map loaded successfully!"
+
+ load_depth_btn.click(
+ fn=load_depth_to_severity,
+ inputs=[shared_depth_map, depth_input_image],
+ outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
+ )
+
+ # Loading state function
        def show_loading_state():
            """Return the placeholder HTML shown while severity analysis runs.

            NOTE(review): the markup below appears to have had its HTML tags
            stripped at some point; only the text content remains — restore
            the original tags from version control if available.
            """
            return """
            
            
             🩹 Wound Severity Analysis
            
            
             🔄 AI Analysis in Progress...
            
            
             • Generating wound mask with deep learning model
             • Computing depth measurements and statistics
             • Analyzing wound characteristics with Gemini AI
             • Preparing comprehensive medical assessment
            
            
            
            
            """
+
+ # Automatic severity analysis function
        def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
            """Generate a wound mask automatically and grade wound severity.

            Args:
                image: Original wound image from the severity tab.
                depth_map: Depth map loaded from Tab 2.
                pixel_spacing: Millimeters per pixel.
                depth_calibration: Expected maximum wound depth in mm.

            Returns:
                HTML report string, or an HTML error panel on failure.

            NOTE(review): the inline HTML panels below appear to have had
            their tags stripped; only text content remains.
            """
            if depth_map is None:
                return """
            
            
             ❌ Error
            
            
             Please load depth map from Tab 1 first.
            
            
            """

            # Generate automatic wound mask using the actual model
            auto_mask = create_automatic_wound_mask(image, method='deep_learning')

            if auto_mask is None:
                return """
            
            
             ❌ Error
            
            
             Failed to generate automatic wound mask. Please check if the segmentation model is loaded.
            
            
            """

            # Post-process the mask with fixed minimum area
            processed_mask = post_process_wound_mask(auto_mask, min_area=500)

            if processed_mask is None or np.sum(processed_mask > 0) == 0:
                return """
            
            
             ⚠️ No Wound Detected
            
            
             No wound region detected by the segmentation model. Try uploading a different image or use manual mask.
            
            
            """

            # Analyze severity using the automatic mask.
            # NOTE(review): this passes five arguments while
            # analyze_wound_severity_auto passes four — confirm the
            # analyze_wound_severity signature accepts an optional
            # depth-calibration parameter.
            return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)
+
+ # Connect event handler with loading state
+ auto_severity_button.click(
+ fn=show_loading_state,
+ inputs=[],
+ outputs=[severity_output]
+ ).then(
+ fn=run_auto_severity_analysis,
+ inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
+ outputs=[severity_output]
+ )
+
+
+
+ # Auto-generate mask when image is uploaded
+ def auto_generate_mask_on_image_upload(image):
+ if image is None:
+ return None, "❌ No image uploaded."
+
+ # Generate automatic wound mask using segmentation model
+ auto_mask, _ = segmentation_model.segment_wound(image)
+ if auto_mask is not None:
+ # Post-process the mask
+ processed_mask = post_process_wound_mask(auto_mask, min_area=500)
+ if processed_mask is not None and np.sum(processed_mask > 0) > 0:
+ return processed_mask, "✅ Wound mask auto-generated using deep learning model!"
+ else:
+ return None, "✅ Image uploaded but no wound detected. Try uploading a different image."
+ else:
+ return None, "✅ Image uploaded but segmentation failed. Try uploading a different image."
+
+ # Load shared image from classification tab
+ def load_shared_image(shared_img):
+ if shared_img is None:
+ return gr.Image(), "❌ No image available from classification tab"
+
+ # Convert PIL image to numpy array for depth estimation
+ if hasattr(shared_img, 'convert'):
+ # It's a PIL image, convert to numpy
+ img_array = np.array(shared_img)
+ return img_array, "✅ Image loaded from classification tab"
+ else:
+ # Already numpy array
+ return shared_img, "✅ Image loaded from classification tab"
+
+ # Auto-generate mask when image is uploaded to severity tab
+ severity_input_image.change(
+ fn=auto_generate_mask_on_image_upload,
+ inputs=[severity_input_image],
+ outputs=[wound_mask_input, gr.HTML()]
+ )
+
+
+
# Launch the Gradio app: listen on all interfaces, port 7860, and create a
# public share link. queue() serializes requests so long-running model
# inference calls do not overlap.
if __name__ == '__main__':
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
\ No newline at end of file
diff --git a/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc b/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..320c66e3854de2ab05a7ec6787958fc9a7bdef84
Binary files /dev/null and b/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc differ
diff --git a/depth_anything_v2/__pycache__/dpt.cpython-310.pyc b/depth_anything_v2/__pycache__/dpt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca627d373cf39e8a744e3c402fa0498348cf9d49
Binary files /dev/null and b/depth_anything_v2/__pycache__/dpt.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2.py b/depth_anything_v2/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cbfc7d24d37796d5310fd966b582bb3773685dc
--- /dev/null
+++ b/depth_anything_v2/dinov2.py
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply *fn* to *module* and all of its submodules.

    ``fn`` is invoked with ``module=`` and ``name=`` keyword arguments, where
    the name is the submodule's dotted path. With ``depth_first`` children are
    visited before their parent; ``include_root`` controls whether *fn* also
    runs on *module* itself. Returns *module* for chaining.
    """
    if not depth_first and include_root:
        # pre-order: visit the current node before descending
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child_module, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        # post-order: visit the current node after its children
        fn(module=module, name=name)
    return module
+
+
class BlockChunk(nn.ModuleList):
    """A chunk of transformer blocks applied sequentially.

    Behaves like ``nn.Sequential`` while keeping ``nn.ModuleList`` semantics,
    so block indices stay stable when the block list is chunked (for FSDP).
    """

    def forward(self, x):
        # thread the activation through every contained block, in order
        for block in self:
            x = block(x)
        return x
+
+
class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone.

    Patch embedding + cls token (+ optional register tokens), a stack of
    (optionally chunked) transformer blocks, and a final LayerNorm. The
    forward pass returns the normalized cls-token features through an
    identity head.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # learnable cls token and positional embeddings (cls + one per patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # token substituted for masked patches during masked-image modeling
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize pos/cls/register tokens and apply timm-style Linear init."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Resize the patch positional embeddings to match a (w, h) input.

        Returns the stored embeddings unchanged when the patch count matches
        and the image is square; otherwise bicubically interpolates the patch
        grid (the cls embedding is passed through untouched).
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # w0, h0 = w0 + 0.1, h0 + 0.1

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            # (int(w0), int(h0)), # to solve the upsampling shape issue
            mode="bicubic",
            antialias=self.interpolate_antialias
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patch-embed x, optionally replace masked patches with the mask
        token, then prepend cls (and register) tokens and add positions."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # masks is a boolean per-patch selector; True -> use mask_token
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # register tokens sit between cls token and patch tokens
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run several (image batch, mask) pairs through the blocks as a
        nested list; returns one output dict per input pair."""
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Full forward pass returning a dict of normalized cls/register/patch
        tokens plus the pre-norm activations."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect outputs of selected blocks when blocks are a flat list."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect outputs of selected blocks when blocks are chunked; the
        leading nn.Identity() padding in each chunk keeps indices global."""
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch-token features from n intermediate blocks, optionally
        normalized, reshaped to (B, C, H', W'), and paired with cls tokens."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # drop cls + register tokens, keeping only patch tokens
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the feature dict when training, else the cls-token features
        passed through the (identity) head."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
+
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """Initialize a module the way timm initializes ViT weights (reproducible).

    Only ``nn.Linear`` layers are touched: truncated-normal weights
    (std 0.02) and zeroed bias; every other module is left alone.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
+
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-S backbone (embed dim 384, depth 12, 6 heads)."""
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # duplicate keys in **kwargs raise TypeError, exactly as direct keyword calls do
    return DinoVisionTransformer(**base_kwargs, **kwargs)
+
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-B backbone (embed dim 768, depth 12, 12 heads)."""
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # duplicate keys in **kwargs raise TypeError, exactly as direct keyword calls do
    return DinoVisionTransformer(**base_kwargs, **kwargs)
+
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-L backbone (embed dim 1024, depth 24, 16 heads)."""
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # duplicate keys in **kwargs raise TypeError, exactly as direct keyword calls do
    return DinoVisionTransformer(**base_kwargs, **kwargs)
+
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    base_kwargs = dict(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # duplicate keys in **kwargs raise TypeError, exactly as direct keyword calls do
    return DinoVisionTransformer(**base_kwargs, **kwargs)
+
+
def DINOv2(model_name):
    """Instantiate a DINOv2 backbone by short name: "vits", "vitb", "vitl" or "vitg"."""
    factories = {
        "vits": vit_small,
        "vitb": vit_base,
        "vitl": vit_large,
        "vitg": vit_giant2,
    }
    factory = factories[model_name]  # unknown names raise KeyError, as before
    # only the giant variant uses the fused SwiGLU FFN
    ffn = "swiglufused" if model_name == "vitg" else "mlp"
    return factory(
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer=ffn,
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    )
diff --git a/depth_anything_v2/dinov2_layers/__init__.py b/depth_anything_v2/dinov2_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e59a83eb90512d763b03e4d38536b6ae07e87541
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0efe36e1b6abcdeff4f7611f0cb096ba8606f6f4
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34d1812bf2e59d6d5796ee57b16bea7364413731
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..322578e8683fdb1a35bfc8e23af92a87e4546b02
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3bd2ebcf2ead79e30a75a3ecd9b974d92310083
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52af677deee30b6e536f6369647598cc1c089ec9
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ea50613125836d15598205f316241a06e12e727
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b92b1b42a848aeb9b26363664e670e99c1593c8
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afbad238bb2fadccf537cb42e2b77cab95f01482
Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc differ
diff --git a/depth_anything_v2/dinov2_layers/attention.py b/depth_anything_v2/dinov2_layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea0c82d55f052bf4bcb5896ad8c37158ef523d5
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
class Attention(nn.Module):
    """Standard multi-head self-attention with optional attention/projection dropout."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # queries are scaled by 1/sqrt(head_dim)
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # project to q/k/v and split heads: (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # merge heads back into the channel dimension
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
+
+
class MemEffAttention(Attention):
    """Attention variant that uses the xFormers memory-efficient kernel when available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # fall back to the vanilla implementation; attn_bias is only
            # supported by the xFormers kernel
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = unbind(qkv, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, tokens, channels])

        return self.proj_drop(self.proj(out))
+
+
\ No newline at end of file
diff --git a/depth_anything_v2/dinov2_layers/block.py b/depth_anything_v2/dinov2_layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91f3f07bd15fba91c67068c8dce2bb22d505bf7
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
class Block(nn.Module):
    """Pre-norm transformer block: self-attention and an FFN, each wrapped in
    an optional LayerScale and a DropPath (stochastic depth) residual branch.

    For drop-path rates above 0.1 (training only) the residuals are computed on
    a random subset of the batch via ``drop_add_residual_stochastic_depth``,
    which is cheaper than masking after the fact.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy (None or 0 disables it)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # fixed: the FFN branch previously reused drop_path1 (upstream FIXME);
            # both are DropPath(drop_path), so the sampling distribution is unchanged,
            # but the dedicated module is now used as intended.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            # inference / no stochastic depth: plain residual additions
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
+
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-wise stochastic depth.

    Applies *residual_func* to a random subset of the batch only, then adds
    the residual back, rescaled so the expected magnitude matches computing
    it for every sample.
    """
    batch_size, _, _ = x.shape
    # keep at least one sample
    keep = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_indices = (torch.randperm(batch_size, device=x.device))[:keep]

    # residual is computed only on the kept subset
    residual = residual_func(x[kept_indices]).flatten(1)

    # upscale to compensate for the dropped samples
    scale = batch_size / keep
    updated = torch.index_add(x.flatten(1), 0, kept_indices, residual.to(dtype=x.dtype), alpha=scale)
    return updated.view_as(x)
+
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample the batch indices kept for stochastic depth.

    Returns ``(brange, residual_scale_factor)``: the kept batch indices and
    the factor compensating for the dropped samples (at least one is kept).
    """
    batch_size, _, _ = x.shape
    keep = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(batch_size, device=x.device))[:keep]
    return brange, batch_size / keep
+
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a scaled residual back into the batch rows selected by ``brange``.

    Without ``scaling_vector`` the result is the *flattened* (dim-1) tensor
    produced by ``index_add`` (callers ``view_as`` it back); with one, the
    fused xFormers ``scaled_index_add`` kernel applies per-channel scaling.
    """
    if scaling_vector is None:
        flat_residual = residual.flatten(1)
        return torch.index_add(
            x.flatten(1), 0, brange, flat_residual.to(dtype=x.dtype), alpha=residual_scale_factor
        )
    # fused per-channel scaling path (requires xFormers)
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
+
+
# Cache of xFormers block-diagonal attention masks, keyed by the tuple of
# (batch_size, seq_len) pairs of the tensor list they were built for.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    Concatenates a list of (B_i, N_i, D) tensors into one (1, sum(B_i*N_i), D)
    "nested" sequence and returns the matching block-diagonal attention bias,
    memoized in ``attn_bias_cache``. ``branges`` optionally selects a subset
    of batch rows per tensor before concatenation (used by stochastic depth).
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # one sequence-length entry per retained batch row
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # fused gather + concat (xFormers kernel)
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
+
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Batch-wise stochastic depth over a list of nested tensors.

    For each tensor a random subset of the batch is kept, the residual is
    computed once over the concatenated subsets (sharing one block-diagonal
    attention bias), then the rescaled residual is added back per tensor.

    Returns the list of updated tensors (same shapes as the inputs).
    Note: the return annotation was corrected from ``Tensor`` to
    ``List[Tensor]`` — a list has always been returned.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
+
+
class NestedTensorBlock(Block):
    """Block variant that can process a *list* of tensors as one nested batch
    using xFormers block-diagonal attention."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # fixed: this previously tested isinstance(self.ls1, ...) while
                # reading self.ls2.gamma; ls1/ls2 are always created together so
                # behavior was unaffected, but the check now matches the value.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # concatenate all sequences into one and attend behind a
            # block-diagonal mask, then split back into per-tensor outputs
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # plain Tensor -> standard Block.forward; list -> nested path (needs xFormers)
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
diff --git a/depth_anything_v2/dinov2_layers/drop_path.py b/depth_anything_v2/dinov2_layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..10c3bea8e40eec258bbe59087770d230a6375481
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors so
    the expected value is unchanged. No-op outside training or at prob 0."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # broadcastable mask with a single entry per sample
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # rescale survivors to preserve the expectation
        mask.div_(keep_prob)
    return x * mask
+
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # probability of dropping the whole residual branch for a sample
        self.drop_prob = drop_prob

    def forward(self, x):
        # delegate to the functional form; active only in training mode
        return drop_path(x, self.drop_prob, self.training)
diff --git a/depth_anything_v2/dinov2_layers/layer_scale.py b/depth_anything_v2/dinov2_layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a4d0eedb1dc974a45e06fbe77ff3d909e36e55
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch (CaiT-style layer scale)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # one learnable gain per channel, initialised to a small constant
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
diff --git a/depth_anything_v2/dinov2_layers/mlp.py b/depth_anything_v2/dinov2_layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..504987b635c9cd582a352fb2381228c9e6cd043c
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
class Mlp(nn.Module):
    """Two-layer feed-forward network (Linear -> activation -> Linear) with
    dropout applied after both the activation and the output projection."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # hidden/output widths default to the input width
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # fc1 -> act -> drop -> fc2 -> drop
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
diff --git a/depth_anything_v2/dinov2_layers/patch_embed.py b/depth_anything_v2/dinov2_layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..f880c042ee6a33ef520c6a8c8a686c1d065b8f49
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
def make_2tuple(x):
    """Coerce an int or a 2-tuple into a 2-tuple; assert on anything else."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        img_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_hw = (img_hw[0] // patch_hw[0], img_hw[1] // patch_hw[1])

        self.img_size = img_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # A non-overlapping convolution == an independent linear projection per patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        tokens = self.proj(x)  # B C H W
        grid_h, grid_w = tokens.size(2), tokens.size(3)
        tokens = tokens.flatten(2).transpose(1, 2)  # B HW C
        tokens = self.norm(tokens)
        if not self.flatten_embedding:
            tokens = tokens.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return tokens

    def flops(self) -> float:
        """Estimate of multiply-adds for one forward pass (projection + norm)."""
        grid_h, grid_w = self.patches_resolution
        total = grid_h * grid_w * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:  # always true here (Identity fallback); kept for parity
            total += grid_h * grid_w * self.embed_dim
        return total
diff --git a/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/depth_anything_v2/dinov2_layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..155a3dd9f6f1a7d0f7bdf9c8f1981e58acb3b19c
--- /dev/null
+++ b/depth_anything_v2/dinov2_layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: ``w3(silu(x @ W1) * (x @ W2))``.

    ``w12`` packs W1 and W2 into one fused linear layer whose output is
    split in half along the last dimension.

    Note: ``act_layer`` and ``drop`` are accepted only for interface
    compatibility with ``Mlp`` and are ignored.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,  # unused
        drop: float = 0.0,  # unused
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # One projection produces both the gate (x1) and value (x2) branches.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    # Fall back to the pure-PyTorch implementation when xformers is absent.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN with its hidden width rescaled by 2/3 and rounded up to a
    multiple of 8 before being passed to the (possibly fused xformers) base
    class. ``act_layer`` and ``drop`` are ignored, as in ``SwiGLUFFN``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,  # unused
        drop: float = 0.0,  # unused; not forwarded to the base class
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 2/3 scaling keeps the parameter budget comparable to a plain MLP;
        # rounding up to a multiple of 8 keeps the layer shapes aligned.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
diff --git a/depth_anything_v2/dpt.py b/depth_anything_v2/dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..acef20bfcf80318709dcf6c5e8c19b117394a06b
--- /dev/null
+++ b/depth_anything_v2/dpt.py
@@ -0,0 +1,221 @@
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+from .dinov2 import DINOv2
+from .util.blocks import FeatureFusionBlock, _make_scratch
+from .util.transform import Resize, NormalizeImage, PrepareForNet
+
+
def _make_fusion_block(features, use_bn, size=None):
    # Helper: build a FeatureFusionBlock with the fixed configuration used
    # throughout the DPT decoder — non-inplace ReLU, no deconvolution, no
    # channel expansion, and align_corners=True for its bilinear upsampling.
    # `size`, when given, pins the block's output spatial size.
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
+
+
class ConvBlock(nn.Module):
    """3x3 Conv -> BatchNorm -> ReLU block; padding=1 preserves spatial size."""

    def __init__(self, in_feature, out_feature):
        super().__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(True)  # in-place ReLU
        )

    def forward(self, x):
        return self.conv_block(x)
+
+
class DPTHead(nn.Module):
    """DPT (Dense Prediction Transformer) decoder head.

    Takes four sets of ViT tokens from intermediate backbone layers,
    reshapes each into a 2D feature map, resamples them to four pyramid
    scales, fuses them top-down with RefineNet-style blocks, and regresses
    a single-channel non-negative depth map.
    """

    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        out_channels=[256, 512, 1024, 1024],  # NOTE: mutable default, but only read
        use_clstoken=False
    ):
        super(DPTHead, self).__init__()

        self.use_clstoken = use_clstoken

        # 1x1 convs mapping the ViT embedding width to each pyramid level's width.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        # Per-level resampling: 4x up, 2x up, identity, 2x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            # Fold the class token back into every patch token: concat then project.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        # 3x3 convs bringing all pyramid levels to the common `features` width.
        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
        # Final regression stack; the last ReLU clamps depth to >= 0. The
        # trailing Identity is a no-op (presumably kept for checkpoint layout).
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True),
            nn.Identity(),
        )

    def forward(self, out_features, patch_h, patch_w):
        """Decode depth from backbone tokens.

        Args:
            out_features: sequence of four (tokens, cls_token) pairs; tokens
                are (B, patch_h*patch_w, C).
            patch_h, patch_w: token grid size (image size // 14 for ViT-14).

        Returns:
            (B, 1, patch_h*14, patch_w*14) depth logits (non-negative after
            the head's final ReLU).
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                # Broadcast the class token over all patch tokens and re-project.
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            # (B, N, C) -> (B, C, patch_h, patch_w)
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: each refinenet upsamples to the next level's size.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        # Upsample to the full input resolution (14 = ViT patch size).
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        out = self.scratch.output_conv2(out)

        return out
+
+
class DepthAnythingV2(nn.Module):
    """Depth Anything V2: DINOv2 ViT backbone + DPT decoder head for
    monocular relative-depth estimation.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=[256, 512, 1024, 1024],  # NOTE: mutable default, but only read
        use_bn=False,
        use_clstoken=False
    ):
        super(DepthAnythingV2, self).__init__()

        # Which backbone layers feed the four pyramid levels, per encoder size.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitb': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23],
            'vitg': [9, 19, 29, 39]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)

    def forward(self, x):
        """Predict depth for a normalized (B, 3, H, W) batch; H and W must be
        multiples of 14 (the ViT patch size). Returns (B, H, W), clamped >= 0."""
        patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14

        features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)

        depth = self.depth_head(features, patch_h, patch_w)
        depth = F.relu(depth)

        return depth.squeeze(1)

    @torch.no_grad()
    def infer_image(self, raw_image, input_size=518):
        """Run inference on a single BGR uint8 image (H, W, 3), e.g. from
        cv2.imread, and return an (H, W) float32 numpy depth map at the
        original resolution."""
        image, (h, w) = self.image2tensor(raw_image, input_size)

        depth = self.forward(image)

        # Resize the prediction back to the original image size.
        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]

        return depth.cpu().numpy()

    def image2tensor(self, raw_image, input_size=518):
        """Preprocess a BGR image into a normalized (1, 3, H', W') tensor.

        Resizes so the shorter side is at least `input_size` (keeping aspect
        ratio, dimensions rounded to multiples of 14), converts BGR->RGB,
        scales to [0, 1], and applies ImageNet mean/std. Returns the tensor
        (already moved to the chosen device) and the original (h, w).
        """
        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        h, w = raw_image.shape[:2]

        image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0

        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0)

        # Device is re-resolved on every call rather than cached at init.
        DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        image = image.to(DEVICE)

        return image, (h, w)
diff --git a/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc b/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..580ca2d9a11546d5bafd4dadc9600e91c7bf3b11
Binary files /dev/null and b/depth_anything_v2/util/__pycache__/blocks.cpython-310.pyc differ
diff --git a/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc b/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d265f0d3c4dcb77e24f86929a3fc602e5e30ffcc
Binary files /dev/null and b/depth_anything_v2/util/__pycache__/transform.cpython-310.pyc differ
diff --git a/depth_anything_v2/util/blocks.py b/depth_anything_v2/util/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb66c03702d653f411c59ab9966916c348c7c6e
--- /dev/null
+++ b/depth_anything_v2/util/blocks.py
@@ -0,0 +1,148 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+
+ return scratch
+
+
class ResidualConvUnit(nn.Module):
    """Residual convolution module: x + conv2(act(conv1(act(x)))),
    with optional batch norm after each conv.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation applied before each conv
            bn (bool): insert BatchNorm2d after each conv when True
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # Quantization-friendly addition for the skip connection.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        residual = self.conv1(self.activation(x))
        if self.bn:
            residual = self.bn1(residual)

        residual = self.conv2(self.activation(residual))
        if self.bn:
            residual = self.bn2(residual)

        if self.groups > 1:
            # NOTE: unreachable — groups is hard-coded to 1 above and
            # conv_merge is never defined; kept for parity with upstream.
            residual = self.conv_merge(residual)

        return self.skip_add.add(residual, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block: optionally adds a refined skip branch, refines
    the sum, bilinearly resizes, and projects with a 1x1 conv.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None
    ):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation used inside the residual units
            deconv (bool): stored only; not used by this implementation
            bn (bool): batch norm inside the residual units
            expand (bool): halve the channel count in the output conv
            align_corners (bool): passed to the bilinear interpolation
            size: fixed output spatial size (overridden by forward(size=...))
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1

        self.expand = expand
        out_features = features // 2 if self.expand else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Fuse one or two feature maps.

        Returns:
            tensor: output
        """
        output = xs[0]

        # Second input, when present, is a lateral skip: refine then add.
        if len(xs) == 2:
            output = self.skip_add.add(output, self.resConfUnit1(xs[1]))

        output = self.resConfUnit2(output)

        # Target size priority: explicit argument > configured size > 2x upsample.
        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        output = nn.functional.interpolate(output, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)

        return self.out_conv(output)
diff --git a/depth_anything_v2/util/transform.py b/depth_anything_v2/util/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cce234c86177e1ad5c84c81c7c1afb16877c9da
--- /dev/null
+++ b/depth_anything_v2/util/transform.py
@@ -0,0 +1,158 @@
+import numpy as np
+import cv2
+
+
class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round x to the nearest multiple of `ensure_multiple_of`, falling
        back to floor when that exceeds max_val and to ceil when it falls
        below min_val. Returns an int."""
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output (new_width, new_height) for an input of the
        given size, honoring keep_aspect_ratio, resize_method, and the
        multiple-of constraint."""
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            # Use a single scale for both axes, chosen per resize_method.
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize sample["image"] in place (and "depth"/"mask" with
        nearest-neighbor interpolation when resize_target is True)."""
        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        # resize sample
        sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)

        if self.__resize_target:
            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            if "mask" in sample:
                sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)

        return sample
+
+
class NormalizeImage(object):
    """Normalize sample["image"] in place with a fixed mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        image = sample["image"]
        sample["image"] = (image - self.__mean) / self.__std
        return sample
+
+
class PrepareForNet(object):
    """Convert sample entries to contiguous float32 arrays (image in CHW order)."""

    def __init__(self):
        pass

    def __call__(self, sample):
        # HWC -> CHW for the network input.
        chw = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)

        if "depth" in sample:
            sample["depth"] = np.ascontiguousarray(sample["depth"].astype(np.float32))

        if "mask" in sample:
            sample["mask"] = np.ascontiguousarray(sample["mask"].astype(np.float32))

        return sample
\ No newline at end of file
diff --git a/models/FCN.py b/models/FCN.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c69feca49bd6919e5ead7f57f60db06cee826cc
--- /dev/null
+++ b/models/FCN.py
@@ -0,0 +1,55 @@
import os

from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, MaxPooling2D, Dropout, UpSampling2D
from keras.regularizers import l2

from utils.BilinearUpSampling import BilinearUpSampling2D
+
+
def FCN_Vgg16_16s(input_shape=None, weight_decay=0., batch_momentum=0.9, batch_shape=None, classes=1):
    """Fully-convolutional network with a VGG16 backbone and 16x bilinear upsampling.

    Args:
        input_shape: (H, W, C) of the input; used when batch_shape is not given.
        weight_decay: L2 regularization factor applied to every conv kernel.
        batch_momentum: unused here (no batch-norm layers); kept for signature
            compatibility with sibling model builders.
        batch_shape: full (N, H, W, C) input shape; takes precedence over input_shape.
        classes: number of channels of the classifying layer.

    Returns:
        (model, model_name) tuple.
    """
    if batch_shape:
        img_input = Input(batch_shape=batch_shape)
    else:
        img_input = Input(shape=input_shape)

    # BUG FIX: kernel_regularizer='l2' ignored `weight_decay` and silently used
    # Keras' default l2 factor (0.01); honor the parameter instead.
    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', kernel_regularizer=l2(weight_decay))(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', kernel_regularizer=l2(weight_decay))(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5 (no pooling: keeps stride at 16 for the 16s variant)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', kernel_regularizer=l2(weight_decay))(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', kernel_regularizer=l2(weight_decay))(x)

    # Convolutional layers transferred from fully-connected layers
    x = Conv2D(4096, (7, 7), activation='relu', padding='same', dilation_rate=(2, 2),
               name='fc1', kernel_regularizer=l2(weight_decay))(x)
    x = Dropout(0.5)(x)
    x = Conv2D(4096, (1, 1), activation='relu', padding='same', name='fc2', kernel_regularizer=l2(weight_decay))(x)
    x = Dropout(0.5)(x)
    # Classifying layer (linear activation; per-pixel class scores)
    x = Conv2D(classes, (1, 1), kernel_initializer='he_normal', activation='linear', padding='valid', strides=(1, 1), kernel_regularizer=l2(weight_decay))(x)

    # Bilinear 16x upsampling back to the input resolution.
    x = BilinearUpSampling2D(size=(16, 16))(x)

    model = Model(img_input, x)
    model_name = 'FCN_Vgg16_16'
    return model, model_name
\ No newline at end of file
diff --git a/models/SegNet.py b/models/SegNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..650585c3185ec3494977ebe790973711650296f3
--- /dev/null
+++ b/models/SegNet.py
@@ -0,0 +1,33 @@
+from keras.models import Model
+from keras.layers import Input
+from keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Concatenate, UpSampling2D
+
+
class SegNet:
    """SegNet-style encoder-decoder for binary segmentation.

    Four conv + max-pool encoder stages (16x downsampling), a bottleneck
    conv, and four upsample + conv decoder stages ending in a 1x1 sigmoid
    conv that emits a single-channel mask.
    """

    def __init__(self, n_filters, input_dim_x, input_dim_y, num_channels):
        self.input_dim_x = input_dim_x
        self.input_dim_y = input_dim_y
        self.n_filters = n_filters
        self.num_channels = num_channels

    def get_SegNet(self):
        """Build the network and return (model, 'SegNet')."""
        net_in = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))

        # Encoder: progressively larger receptive field, 2x pooling per stage.
        x = Conv2D(self.n_filters, kernel_size=9, activation='relu', padding='same')(net_in)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Conv2D(self.n_filters, kernel_size=5, activation='relu', padding='same')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Conv2D(self.n_filters * 2, kernel_size=5, activation='relu', padding='same')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Conv2D(self.n_filters * 2, kernel_size=5, activation='relu', padding='same')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

        # Bottleneck.
        x = Conv2D(self.n_filters, kernel_size=5, activation='relu', padding='same')(x)

        # Decoder: upsample back to the input resolution.
        x = Conv2D(self.n_filters, kernel_size=7, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(x))
        x = Conv2D(self.n_filters, kernel_size=5, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(x))
        x = Conv2D(self.n_filters, kernel_size=5, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(x))
        mask = Conv2D(1, kernel_size=1, activation='sigmoid', padding='same')(UpSampling2D(size=(2, 2))(x))

        return Model(outputs=mask, inputs=net_in), 'SegNet'
\ No newline at end of file
diff --git a/models/__pycache__/FCN.cpython-37.pyc b/models/__pycache__/FCN.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8efdb47a204f521261dadb9743dcca856b522a9f
Binary files /dev/null and b/models/__pycache__/FCN.cpython-37.pyc differ
diff --git a/models/__pycache__/FCN.cpython-39.pyc b/models/__pycache__/FCN.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..735f22cedf94b409b7347b6f9fc09ad89a5d7348
Binary files /dev/null and b/models/__pycache__/FCN.cpython-39.pyc differ
diff --git a/models/__pycache__/SegNet.cpython-37.pyc b/models/__pycache__/SegNet.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ebd187de4eb22db3c0d18a9dc966ec1d0880a94
Binary files /dev/null and b/models/__pycache__/SegNet.cpython-37.pyc differ
diff --git a/models/__pycache__/SegNet.cpython-39.pyc b/models/__pycache__/SegNet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7816b13e59765cce0b3cc6c4c53720c932ba5ce
Binary files /dev/null and b/models/__pycache__/SegNet.cpython-39.pyc differ
diff --git a/models/__pycache__/deeplab.cpython-310.pyc b/models/__pycache__/deeplab.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad99672800379db699d7dec0668711e21fa14e29
Binary files /dev/null and b/models/__pycache__/deeplab.cpython-310.pyc differ
diff --git a/models/__pycache__/deeplab.cpython-313.pyc b/models/__pycache__/deeplab.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a71ee9a65fd1f578047531d9aea9a4e03c8269a3
Binary files /dev/null and b/models/__pycache__/deeplab.cpython-313.pyc differ
diff --git a/models/__pycache__/deeplab.cpython-37.pyc b/models/__pycache__/deeplab.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..889940f7b30762d822089c0c8023f0b636235d2e
Binary files /dev/null and b/models/__pycache__/deeplab.cpython-37.pyc differ
diff --git a/models/__pycache__/deeplab.cpython-39.pyc b/models/__pycache__/deeplab.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84d996fdf13f9c184ec09947f490e02d3f9315f8
Binary files /dev/null and b/models/__pycache__/deeplab.cpython-39.pyc differ
diff --git a/models/__pycache__/unets.cpython-37.pyc b/models/__pycache__/unets.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6876f7142f039468b23019771dbcfd9a48874762
Binary files /dev/null and b/models/__pycache__/unets.cpython-37.pyc differ
diff --git a/models/__pycache__/unets.cpython-39.pyc b/models/__pycache__/unets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e2d40a7e10376169b03e12094acd95c25a5a700
Binary files /dev/null and b/models/__pycache__/unets.cpython-39.pyc differ
diff --git a/models/deeplab.py b/models/deeplab.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfc53345d50c8fedfd693b326258dba2a7e7ebbd
--- /dev/null
+++ b/models/deeplab.py
@@ -0,0 +1,539 @@
+# -*- coding: utf-8 -*-
+
+""" Deeplabv3+ model for Keras.
+This model is based on this repo:
+https://github.com/bonlime/keras-deeplab-v3-plus
+
+MobileNetv2 backbone is based on this repo:
+https://github.com/JonathanCMitchell/mobilenet_v2_keras
+
+# Reference
+- [Encoder-Decoder with Atrous Separable Convolution
+ for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
+- [Xception: Deep Learning with Depthwise Separable Convolutions]
+ (https://arxiv.org/abs/1610.02357)
+- [Inverted Residuals and Linear Bottlenecks: Mobile Networks for
+ Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from keras.models import Model
+from keras import layers
+from keras.layers import Input
+from keras.layers import Activation
+from keras.layers import Concatenate
+from keras.layers import Add
+from keras.layers import Dropout
+from keras.layers import BatchNormalization
+from keras.layers import Conv2D
+from keras.layers import DepthwiseConv2D
+from keras.layers import ZeroPadding2D
+from keras.layers import AveragePooling2D
+from keras.layers import Layer
+from tensorflow.keras.layers import InputSpec
+from tensorflow.keras.utils import get_source_inputs
+from keras import backend as K
+from keras.applications import imagenet_utils
+from keras.utils import conv_utils
+from keras.utils.data_utils import get_file
+
+WEIGHTS_PATH_X = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_xception_tf_dim_ordering_tf_kernels.h5"
+WEIGHTS_PATH_MOBILE = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5"
+WEIGHTS_PATH_X_CS = "https://github.com/rdiazgar/keras-deeplab-v3-plus/releases/download/1.2/deeplabv3_xception_tf_dim_ordering_tf_kernels_cityscapes.h5"
+WEIGHTS_PATH_MOBILE_CS = "https://github.com/rdiazgar/keras-deeplab-v3-plus/releases/download/1.2/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels_cityscapes.h5"
+
+class BilinearUpsampling(Layer):
+    """Just a simple bilinear upsampling layer. Works only with TF.
+    Args:
+        upsampling: tuple of 2 numbers > 0. The upsampling ratio for h and w
+        output_size: used instead of upsampling arg if passed!
+    """
+
+    def __init__(self, upsampling=(2, 2), output_size=None, data_format=None, **kwargs):
+
+        super(BilinearUpsampling, self).__init__(**kwargs)
+
+        # NOTE(review): the data_format argument is ignored; the backend's
+        # image_data_format() is always used instead — confirm this is intended.
+        self.data_format = K.image_data_format()
+        self.input_spec = InputSpec(ndim=4)
+        if output_size:
+            # A fixed target size takes precedence over a relative upsampling ratio.
+            self.output_size = conv_utils.normalize_tuple(
+                output_size, 2, 'output_size')
+            self.upsampling = None
+        else:
+            self.output_size = None
+            self.upsampling = conv_utils.normalize_tuple(
+                upsampling, 2, 'upsampling')
+
+    def compute_output_shape(self, input_shape):
+        # Output spatial dims are either input * ratio (None stays None for
+        # dynamic dims) or the fixed output_size; batch and channels pass through.
+        if self.upsampling:
+            height = self.upsampling[0] * \
+                input_shape[1] if input_shape[1] is not None else None
+            width = self.upsampling[1] * \
+                input_shape[2] if input_shape[2] is not None else None
+        else:
+            height = self.output_size[0]
+            width = self.output_size[1]
+        return (input_shape[0],
+                height,
+                width,
+                input_shape[3])
+
+    def call(self, inputs):
+        # align_corners=True matches the original TF DeepLab implementation.
+        if self.upsampling:
+            return tf.compat.v1.image.resize_bilinear(inputs, (inputs.shape[1] * self.upsampling[0],
+                                                               inputs.shape[2] * self.upsampling[1]),
+                                                      align_corners=True)
+        else:
+            return tf.compat.v1.image.resize_bilinear(inputs, (self.output_size[0],
+                                                               self.output_size[1]),
+                                                      align_corners=True)
+
+    def get_config(self):
+        # Serialize constructor args so saved models can be reloaded with
+        # custom_objects={'BilinearUpsampling': BilinearUpsampling}.
+        config = {'upsampling': self.upsampling,
+                  'output_size': self.output_size,
+                  'data_format': self.data_format}
+        base_config = super(BilinearUpsampling, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
+def SepConv_BN(x, filters, prefix, stride=1, kernel_size=3, rate=1, depth_activation=False, epsilon=1e-3):
+    """ SepConv with BN between depthwise & pointwise. Optionally add activation after BN
+        Implements right "same" padding for even kernel sizes
+        Args:
+            x: input tensor
+            filters: num of filters in pointwise convolution
+            prefix: prefix before name
+            stride: stride at depthwise conv
+            kernel_size: kernel size for depthwise convolution
+            rate: atrous rate for depthwise convolution
+            depth_activation: flag to use activation between depthwise & poinwise convs
+            epsilon: epsilon to use in BN layer
+    """
+
+    if stride == 1:
+        depth_padding = 'same'
+    else:
+        # Strided case: pad manually so the effective (dilated) kernel is
+        # centered, then convolve with 'valid' to avoid the 1-px drift that
+        # 'same' padding would introduce.
+        kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
+        pad_total = kernel_size_effective - 1
+        pad_beg = pad_total // 2
+        pad_end = pad_total - pad_beg
+        x = ZeroPadding2D((pad_beg, pad_end))(x)
+        depth_padding = 'valid'
+
+    # When depth_activation is False a single ReLU is applied BEFORE the
+    # depthwise conv (pre-activation); when True, ReLU follows each BN instead.
+    if not depth_activation:
+        x = Activation('relu')(x)
+    x = DepthwiseConv2D((kernel_size, kernel_size), strides=(stride, stride), dilation_rate=(rate, rate),
+                        padding=depth_padding, use_bias=False, name=prefix + '_depthwise')(x)
+    x = BatchNormalization(name=prefix + '_depthwise_BN', epsilon=epsilon)(x)
+    if depth_activation:
+        x = Activation('relu')(x)
+    # 1x1 pointwise projection to `filters` channels.
+    x = Conv2D(filters, (1, 1), padding='same',
+               use_bias=False, name=prefix + '_pointwise')(x)
+    x = BatchNormalization(name=prefix + '_pointwise_BN', epsilon=epsilon)(x)
+    if depth_activation:
+        x = Activation('relu')(x)
+
+    return x
+
+
+def _conv2d_same(x, filters, prefix, stride=1, kernel_size=3, rate=1):
+    """Implements right 'same' padding for even kernel sizes
+        Without this there is a 1 pixel drift when stride = 2
+        Args:
+            x: input tensor
+            filters: num of filters in pointwise convolution
+            prefix: prefix before name
+            stride: stride at depthwise conv
+            kernel_size: kernel size for depthwise convolution
+            rate: atrous rate for depthwise convolution
+    """
+    if stride == 1:
+        # Unstrided: builtin 'same' padding is already symmetric.
+        return Conv2D(filters,
+                      (kernel_size, kernel_size),
+                      strides=(stride, stride),
+                      padding='same', use_bias=False,
+                      dilation_rate=(rate, rate),
+                      name=prefix)(x)
+    else:
+        # Strided: pad for the dilation-expanded kernel, then use 'valid'
+        # so the output grid stays aligned with the input.
+        kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
+        pad_total = kernel_size_effective - 1
+        pad_beg = pad_total // 2
+        pad_end = pad_total - pad_beg
+        x = ZeroPadding2D((pad_beg, pad_end))(x)
+        return Conv2D(filters,
+                      (kernel_size, kernel_size),
+                      strides=(stride, stride),
+                      padding='valid', use_bias=False,
+                      dilation_rate=(rate, rate),
+                      name=prefix)(x)
+
+
+def _xception_block(inputs, depth_list, prefix, skip_connection_type, stride,
+                    rate=1, depth_activation=False, return_skip=False):
+    """ Basic building block of modified Xception network
+        Args:
+            inputs: input tensor
+            depth_list: number of filters in each SepConv layer. len(depth_list) == 3
+            prefix: prefix before name
+            skip_connection_type: one of {'conv','sum','none'}
+            stride: stride at last depthwise conv
+            rate: atrous rate for depthwise convolution
+            depth_activation: flag to use activation between depthwise & pointwise convs
+            return_skip: flag to return additional tensor after 2 SepConvs for decoder
+    """
+    residual = inputs
+    # Three stacked separable convs; only the last one carries the stride.
+    for i in range(3):
+        residual = SepConv_BN(residual,
+                              depth_list[i],
+                              prefix + '_separable_conv{}'.format(i + 1),
+                              stride=stride if i == 2 else 1,
+                              rate=rate,
+                              depth_activation=depth_activation)
+        if i == 1:
+            # Capture the tensor after the 2nd conv for the decoder skip path.
+            skip = residual
+    if skip_connection_type == 'conv':
+        # Projection shortcut (1x1 conv) to match channel count / stride.
+        shortcut = _conv2d_same(inputs, depth_list[-1], prefix + '_shortcut',
+                                kernel_size=1,
+                                stride=stride)
+        shortcut = BatchNormalization(name=prefix + '_shortcut_BN')(shortcut)
+        outputs = layers.add([residual, shortcut])
+    elif skip_connection_type == 'sum':
+        # Identity shortcut (shapes must already match).
+        outputs = layers.add([residual, inputs])
+    elif skip_connection_type == 'none':
+        outputs = residual
+    # NOTE(review): any other skip_connection_type value leaves `outputs`
+    # unbound and raises NameError — callers only pass the three valid values.
+    if return_skip:
+        return outputs, skip
+    else:
+        return outputs
+
+
+def relu6(x):
+    """ReLU activation capped at 6, as used throughout MobileNetV2."""
+    return K.relu(x, max_value=6)
+
+
+def _make_divisible(v, divisor, min_value=None):
+    """Round `v` to the nearest multiple of `divisor`, at least `min_value`.
+
+    Used to keep channel counts hardware-friendly after applying the
+    MobileNetV2 width multiplier (alpha).
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, skip_connection, rate=1):
+    """MobileNetV2 inverted residual block: expand (1x1) -> depthwise (3x3)
+    -> project (1x1), with an optional identity skip connection.
+
+    block_id 0 skips the expansion stage (uses the 'expanded_conv_' prefix).
+    `rate` sets the depthwise dilation for atrous feature extraction.
+    """
+    in_channels = inputs.shape[-1]
+    # Apply the width multiplier and round to a divisor-friendly channel count.
+    pointwise_conv_filters = int(filters * alpha)
+    pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
+    x = inputs
+    prefix = 'expanded_conv_{}_'.format(block_id)
+    if block_id:
+        # Expand
+
+        x = Conv2D(expansion * in_channels, kernel_size=1, padding='same',
+                   use_bias=False, activation=None,
+                   name=prefix + 'expand')(x)
+        x = BatchNormalization(epsilon=1e-3, momentum=0.999,
+                               name=prefix + 'expand_BN')(x)
+        x = Activation(relu6, name=prefix + 'expand_relu')(x)
+    else:
+        prefix = 'expanded_conv_'
+    # Depthwise
+    x = DepthwiseConv2D(kernel_size=3, strides=stride, activation=None,
+                        use_bias=False, padding='same', dilation_rate=(rate, rate),
+                        name=prefix + 'depthwise')(x)
+    x = BatchNormalization(epsilon=1e-3, momentum=0.999,
+                           name=prefix + 'depthwise_BN')(x)
+
+    x = Activation(relu6, name=prefix + 'depthwise_relu')(x)
+
+    # Project (linear bottleneck: no activation after this 1x1 conv).
+    x = Conv2D(pointwise_filters,
+               kernel_size=1, padding='same', use_bias=False, activation=None,
+               name=prefix + 'project')(x)
+    x = BatchNormalization(epsilon=1e-3, momentum=0.999,
+                           name=prefix + 'project_BN')(x)
+
+    if skip_connection:
+        # Caller guarantees input/output shapes match when requesting the skip.
+        return Add(name=prefix + 'add')([inputs, x])
+
+    # if in_channels == pointwise_filters and stride == 1:
+    #     return Add(name='res_connect_' + str(block_id))([inputs, x])
+
+    return x
+
+
+def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3), classes=21, backbone='mobilenetv2'
+              , OS=16, alpha=1.):
+    """ Instantiates the Deeplabv3+ architecture
+
+    Optionally loads weights pre-trained
+    on PASCAL VOC. This model is available for TensorFlow only,
+    and can only be used with inputs following the TensorFlow
+    data format `(width, height, channels)`.
+    # Arguments
+        weights: one of 'pascal_voc' (pre-trained on pascal voc)
+            or None (random initialization)
+        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
+            to use as image input for the model.
+        input_shape: shape of input image. format HxWxC
+            PASCAL VOC model was trained on (512,512,3) images
+        classes: number of desired classes. If classes != 21,
+            last layer is initialized randomly
+        backbone: backbone to use. one of {'xception','mobilenetv2'}
+        OS: determines input_shape/feature_extractor_output ratio. One of {8,16}.
+            Used only for xception backbone.
+        alpha: controls the width of the MobileNetV2 network. This is known as the
+            width multiplier in the MobileNetV2 paper.
+                - If `alpha` < 1.0, proportionally decreases the number
+                    of filters in each layer.
+                - If `alpha` > 1.0, proportionally increases the number
+                    of filters in each layer.
+                - If `alpha` = 1, default number of filters from the paper
+                    are used at each layer.
+            Used only for mobilenetv2 backbone
+
+    # Returns
+        A Keras model instance.
+
+    # Raises
+        RuntimeError: If attempting to run this model with a
+            backend that does not support separable convolutions.
+        ValueError: in case of invalid argument for `weights` or `backbone`
+
+    """
+
+    # --- Argument validation ---
+    if not (weights in {'pascal_voc', 'cityscapes', None}):
+        raise ValueError('The `weights` argument should be either '
+                         '`None` (random initialization), `pascal_voc`, or `cityscapes` '
+                         '(pre-trained on PASCAL VOC)')
+
+    if K.backend() != 'tensorflow':
+        raise RuntimeError('The Deeplabv3+ model is only available with '
+                           'the TensorFlow backend.')
+
+    if not (backbone in {'xception', 'mobilenetv2'}):
+        raise ValueError('The `backbone` argument should be either '
+                         '`xception` or `mobilenetv2` ')
+
+    if input_tensor is None:
+        img_input = Input(shape=input_shape)
+    else:
+        if not K.is_keras_tensor(input_tensor):
+            # Input layer
+            img_input = Input(tensor=input_tensor, shape=input_shape)
+        else:
+            img_input = input_tensor
+
+    if backbone == 'xception':
+        # Strides/dilations depend on the requested output stride (OS).
+        if OS == 8:
+            entry_block3_stride = 1
+            middle_block_rate = 2  # ! Not mentioned in paper, but required
+            exit_block_rates = (2, 4)
+            atrous_rates = (12, 24, 36)
+        else:
+            entry_block3_stride = 2
+            middle_block_rate = 1
+            exit_block_rates = (1, 2)
+            atrous_rates = (6, 12, 18)
+
+        x = Conv2D(32, (3, 3), strides=(2, 2),
+                   name='entry_flow_conv1_1', use_bias=False, padding='same')(img_input)
+        x = BatchNormalization(name='entry_flow_conv1_1_BN')(x)
+        x = Activation('relu')(x)
+
+        x = _conv2d_same(x, 64, 'entry_flow_conv1_2', kernel_size=3, stride=1)
+        x = BatchNormalization(name='entry_flow_conv1_2_BN')(x)
+        x = Activation('relu')(x)
+
+        x = _xception_block(x, [128, 128, 128], 'entry_flow_block1',
+                            skip_connection_type='conv', stride=2,
+                            depth_activation=False)
+        # skip1 feeds the decoder's low-level feature path below.
+        x, skip1 = _xception_block(x, [256, 256, 256], 'entry_flow_block2',
+                                   skip_connection_type='conv', stride=2,
+                                   depth_activation=False, return_skip=True)
+
+        x = _xception_block(x, [728, 728, 728], 'entry_flow_block3',
+                            skip_connection_type='conv', stride=entry_block3_stride,
+                            depth_activation=False)
+        for i in range(16):
+            x = _xception_block(x, [728, 728, 728], 'middle_flow_unit_{}'.format(i + 1),
+                                skip_connection_type='sum', stride=1, rate=middle_block_rate,
+                                depth_activation=False)
+
+        x = _xception_block(x, [728, 1024, 1024], 'exit_flow_block1',
+                            skip_connection_type='conv', stride=1, rate=exit_block_rates[0],
+                            depth_activation=False)
+        x = _xception_block(x, [1536, 1536, 2048], 'exit_flow_block2',
+                            skip_connection_type='none', stride=1, rate=exit_block_rates[1],
+                            depth_activation=True)
+
+    else:
+        # MobileNetV2 backbone always runs at output stride 8; the OS
+        # argument is overridden here (as documented above).
+        OS = 8
+        first_block_filters = _make_divisible(32 * alpha, 8)
+        x = Conv2D(first_block_filters,
+                   kernel_size=3,
+                   strides=(2, 2), padding='same',
+                   use_bias=False, name='Conv')(img_input)
+        x = BatchNormalization(
+            epsilon=1e-3, momentum=0.999, name='Conv_BN')(x)
+        x = Activation(relu6, name='Conv_Relu6')(x)
+
+        x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1,
+                                expansion=1, block_id=0, skip_connection=False)
+
+        x = _inverted_res_block(x, filters=24, alpha=alpha, stride=2,
+                                expansion=6, block_id=1, skip_connection=False)
+        x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1,
+                                expansion=6, block_id=2, skip_connection=True)
+
+        x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2,
+                                expansion=6, block_id=3, skip_connection=False)
+        x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1,
+                                expansion=6, block_id=4, skip_connection=True)
+        x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1,
+                                expansion=6, block_id=5, skip_connection=True)
+
+        # stride in block 6 changed from 2 -> 1, so we need to use rate = 2
+        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1,  # 1!
+                                expansion=6, block_id=6, skip_connection=False)
+        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2,
+                                expansion=6, block_id=7, skip_connection=True)
+        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2,
+                                expansion=6, block_id=8, skip_connection=True)
+        x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2,
+                                expansion=6, block_id=9, skip_connection=True)
+
+        x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2,
+                                expansion=6, block_id=10, skip_connection=False)
+        x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2,
+                                expansion=6, block_id=11, skip_connection=True)
+        x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2,
+                                expansion=6, block_id=12, skip_connection=True)
+
+        x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=2,  # 1!
+                                expansion=6, block_id=13, skip_connection=False)
+        x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=4,
+                                expansion=6, block_id=14, skip_connection=True)
+        x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=4,
+                                expansion=6, block_id=15, skip_connection=True)
+
+        x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, rate=4,
+                                expansion=6, block_id=16, skip_connection=False)
+
+    # end of feature extractor
+
+    # branching for Atrous Spatial Pyramid Pooling
+
+    # Image Feature branch: global average pool -> 1x1 conv -> upsample back.
+    #out_shape = int(np.ceil(input_shape[0] / OS))
+    b4 = AveragePooling2D(pool_size=(int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS))))(x)
+    b4 = Conv2D(256, (1, 1), padding='same',
+                use_bias=False, name='image_pooling')(b4)
+    b4 = BatchNormalization(name='image_pooling_BN', epsilon=1e-5)(b4)
+    b4 = Activation('relu')(b4)
+    b4 = BilinearUpsampling((int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS))))(b4)
+
+    # simple 1x1
+    b0 = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp0')(x)
+    b0 = BatchNormalization(name='aspp0_BN', epsilon=1e-5)(b0)
+    b0 = Activation('relu', name='aspp0_activation')(b0)
+
+    # there are only 2 branches in mobilenetV2. not sure why
+    if backbone == 'xception':
+        # rate = 6 (12)
+        b1 = SepConv_BN(x, 256, 'aspp1',
+                        rate=atrous_rates[0], depth_activation=True, epsilon=1e-5)
+        # rate = 12 (24)
+        b2 = SepConv_BN(x, 256, 'aspp2',
+                        rate=atrous_rates[1], depth_activation=True, epsilon=1e-5)
+        # rate = 18 (36)
+        b3 = SepConv_BN(x, 256, 'aspp3',
+                        rate=atrous_rates[2], depth_activation=True, epsilon=1e-5)
+
+        # concatenate ASPP branches & project
+        x = Concatenate()([b4, b0, b1, b2, b3])
+    else:
+        x = Concatenate()([b4, b0])
+
+    x = Conv2D(256, (1, 1), padding='same',
+               use_bias=False, name='concat_projection')(x)
+    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
+    x = Activation('relu')(x)
+    x = Dropout(0.1)(x)
+
+    # DeepLab v.3+ decoder
+
+    if backbone == 'xception':
+        # Feature projection
+        # x4 (x2) block
+        x = BilinearUpsampling(output_size=(int(np.ceil(input_shape[0] / 4)),
+                                            int(np.ceil(input_shape[1] / 4))))(x)
+        # Project the low-level entry-flow features (skip1) to 48 channels
+        # before fusing with the upsampled ASPP output.
+        dec_skip1 = Conv2D(48, (1, 1), padding='same',
+                           use_bias=False, name='feature_projection0')(skip1)
+        dec_skip1 = BatchNormalization(
+            name='feature_projection0_BN', epsilon=1e-5)(dec_skip1)
+        dec_skip1 = Activation('relu')(dec_skip1)
+        x = Concatenate()([x, dec_skip1])
+        x = SepConv_BN(x, 256, 'decoder_conv0',
+                       depth_activation=True, epsilon=1e-5)
+        x = SepConv_BN(x, 256, 'decoder_conv1',
+                       depth_activation=True, epsilon=1e-5)
+
+    # you can use it with arbitary number of classes
+    # A distinct layer name for classes != 21 keeps by_name weight loading
+    # from touching the randomly-initialized custom head.
+    if classes == 21:
+        last_layer_name = 'logits_semantic'
+    else:
+        last_layer_name = 'custom_logits_semantic'
+
+    x = Conv2D(classes, (1, 1), padding='same', name=last_layer_name)(x)
+    # Upsample logits back to the full input resolution (no softmax applied).
+    x = BilinearUpsampling(output_size=(input_shape[0], input_shape[1]))(x)
+
+    # Ensure that the model takes into account
+    # any potential predecessors of `input_tensor`.
+    if input_tensor is not None:
+        inputs = get_source_inputs(input_tensor)
+    else:
+        inputs = img_input
+
+    model = Model(inputs, x, name='deeplabv3plus')
+
+    # load weights
+
+    if weights == 'pascal_voc':
+        if backbone == 'xception':
+            weights_path = get_file('deeplabv3_xception_tf_dim_ordering_tf_kernels.h5',
+                                    WEIGHTS_PATH_X,
+                                    cache_subdir='models')
+        else:
+            weights_path = get_file('deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5',
+                                    WEIGHTS_PATH_MOBILE,
+                                    cache_subdir='models')
+        model.load_weights(weights_path, by_name=True)
+    elif weights == 'cityscapes':
+        if backbone == 'xception':
+            weights_path = get_file('deeplabv3_xception_tf_dim_ordering_tf_kernels_cityscapes.h5',
+                                    WEIGHTS_PATH_X_CS,
+                                    cache_subdir='models')
+        else:
+            weights_path = get_file('deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels_cityscapes.h5',
+                                    WEIGHTS_PATH_MOBILE_CS,
+                                    cache_subdir='models')
+        model.load_weights(weights_path, by_name=True)
+    return model
+
+
+def preprocess_input(x):
+    """Preprocesses a numpy array encoding a batch of images.
+    # Arguments
+        x: a 4D numpy array consists of RGB values within [0, 255].
+    # Returns
+        Input array scaled to [-1.,1.]
+    """
+    # mode='tf' maps [0, 255] -> [-1, 1] (x / 127.5 - 1).
+    return imagenet_utils.preprocess_input(x, mode='tf')
diff --git a/models/unets.py b/models/unets.py
new file mode 100644
index 0000000000000000000000000000000000000000..2879efa5bfea642bfa6bfa30ceaecacdb29f69d3
--- /dev/null
+++ b/models/unets.py
@@ -0,0 +1,171 @@
+from keras.models import Model
+from keras.layers import Input
+from keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Concatenate, UpSampling2D
+
+
+class Unet2D:
+    """Factory for 2D U-Net style segmentation models.
+
+    Each get_* method builds a fresh Keras model from the configured input
+    size / base filter count and returns a ``(model, name)`` tuple.
+    """
+
+    def __init__(self, n_filters, input_dim_x, input_dim_y, num_channels):
+        # n_filters: number of filters at the first encoder level
+        # input_dim_x / input_dim_y: spatial size of the input images
+        # num_channels: channels of the input images (e.g. 3 for RGB)
+        self.input_dim_x = input_dim_x
+        self.input_dim_y = input_dim_y
+        self.n_filters = n_filters
+        self.num_channels = num_channels
+
+    def get_unet_model_5_levels(self):
+        """Classic 5-level U-Net; output is a 3-channel sigmoid map."""
+        unet_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))
+
+        # --- Encoder: double conv + BN, downsampling via 2x2 max pooling ---
+        conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(unet_input)
+        conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(conv1)
+        conv1 = BatchNormalization()(conv1)
+        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+        conv2 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(pool1)
+        conv2 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv2)
+        conv2 = BatchNormalization()(conv2)
+        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+        conv3 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(pool2)
+        conv3 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv3)
+        conv3 = BatchNormalization()(conv3)
+        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+        conv4 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(pool3)
+        conv4 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv4)
+        conv4 = BatchNormalization()(conv4)
+        drop4 = Dropout(0.5)(conv4)
+        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
+
+        # --- Bottleneck ---
+        conv5 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(pool4)
+        conv5 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(conv5)
+        conv5 = BatchNormalization()(conv5)
+        drop5 = Dropout(0.5)(conv5)
+
+        # --- Decoder: upsample + 2x2 conv, concat with encoder skip, double conv ---
+        up6 = Conv2D(self.n_filters*16, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5))
+        concat6 = Concatenate()([drop4, up6])
+        conv6 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(concat6)
+        conv6 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv6)
+        conv6 = BatchNormalization()(conv6)
+
+        up7 = Conv2D(self.n_filters*8, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6))
+        concat7 = Concatenate()([conv3, up7])
+        conv7 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(concat7)
+        conv7 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv7)
+        conv7 = BatchNormalization()(conv7)
+
+        up8 = Conv2D(self.n_filters*4, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv7))
+        concat8 = Concatenate()([conv2, up8])
+        conv8 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(concat8)
+        conv8 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv8)
+        conv8 = BatchNormalization()(conv8)
+
+        up9 = Conv2D(self.n_filters*2, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv8))
+        concat9 = Concatenate()([conv1, up9])
+        conv9 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(concat9)
+        conv9 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(conv9)
+        conv9 = BatchNormalization()(conv9)
+
+        # 1x1 conv head; sigmoid per pixel over 3 output channels.
+        conv10 = Conv2D(3, kernel_size=1, activation='sigmoid', padding='same')(conv9)
+
+        return Model(outputs=conv10, inputs=unet_input), 'unet_model_5_levels'
+
+
+    def get_unet_model_4_levels(self):
+        """Shallower 4-level U-Net variant; output is a 3-channel sigmoid map."""
+        unet_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))
+
+        # --- Encoder ---
+        conv1 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(unet_input)
+        conv1 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv1)
+        conv1 = BatchNormalization()(conv1)
+        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+        conv2 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(pool1)
+        conv2 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv2)
+        conv2 = BatchNormalization()(conv2)
+        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+        conv3 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(pool2)
+        conv3 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv3)
+        conv3 = BatchNormalization()(conv3)
+        drop3 = Dropout(0.5)(conv3)
+        pool3 = MaxPooling2D(pool_size=(2, 2))(drop3)
+
+        # --- Bottleneck ---
+        conv4 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(pool3)
+        conv4 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(conv4)
+        conv4 = BatchNormalization()(conv4)
+        drop4 = Dropout(0.5)(conv4)
+
+        # --- Decoder ---
+        up5 = Conv2D(self.n_filters*16, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop4))
+        concat5 = Concatenate()([drop3, up5])
+        conv5 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(concat5)
+        conv5 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv5)
+        conv5 = BatchNormalization()(conv5)
+
+        up6 = Conv2D(self.n_filters*8, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv5))
+        concat6 = Concatenate()([conv2, up6])
+        conv6 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(concat6)
+        conv6 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv6)
+        conv6 = BatchNormalization()(conv6)
+
+        up7 = Conv2D(self.n_filters*4, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6))
+        concat7 = Concatenate()([conv1, up7])
+        conv7 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(concat7)
+        conv7 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv7)
+        conv7 = BatchNormalization()(conv7)
+
+        conv9 = Conv2D(3, kernel_size=1, activation='sigmoid', padding='same')(conv7)
+
+        return Model(outputs=conv9, inputs=unet_input), 'unet_model_4_levels'
+
+
+    def get_unet_model_yuanqing(self):
+        # Model inspired by https://github.com/yuanqing811/ISIC2018
+        # Differences from the classic variant: no BatchNormalization, extra
+        # 3x3 "feature" convs on the skip connections, and a single-channel
+        # sigmoid output (binary mask).
+        unet_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))
+
+        conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(unet_input)
+        conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(conv1)
+        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
+
+        conv2 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(pool1)
+        conv2 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(conv2)
+        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
+
+        conv3 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(pool2)
+        conv3 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv3)
+        conv3 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv3)
+        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
+
+        conv4 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(pool3)
+        conv4 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv4)
+        conv4 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv4)
+        pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
+
+        conv5 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(pool4)
+        conv5 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv5)
+        conv5 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv5)
+
+        # Decoder: each skip tensor is first passed through a 3x3 "feature"
+        # conv before concatenation (unlike the plain U-Net variants above).
+        up6 = Conv2D(self.n_filters * 4, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv5))
+        feature4 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv4)
+        concat6 = Concatenate()([feature4, up6])
+        conv6 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(concat6)
+        conv6 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv6)
+
+        up7 = Conv2D(self.n_filters * 2, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6))
+        feature3 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(conv3)
+        concat7 = Concatenate()([feature3, up7])
+        conv7 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(concat7)
+        conv7 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(conv7)
+
+        up8 = Conv2D(self.n_filters * 1, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv7))
+        feature2 = Conv2D(self.n_filters * 1, kernel_size=3, activation='relu', padding='same')(conv2)
+        concat8 = Concatenate()([feature2, up8])
+        conv8 = Conv2D(self.n_filters * 1, kernel_size=3, activation='relu', padding='same')(concat8)
+        conv8 = Conv2D(self.n_filters * 1, kernel_size=3, activation='relu', padding='same')(conv8)
+
+        up9 = Conv2D(int(self.n_filters / 2), 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv8))
+        feature1 = Conv2D(int(self.n_filters / 2), kernel_size=3, activation='relu', padding='same')(conv1)
+        concat9 = Concatenate()([feature1, up9])
+        conv9 = Conv2D(int(self.n_filters / 2), kernel_size=3, activation='relu', padding='same')(concat9)
+        conv9 = Conv2D(int(self.n_filters / 2), kernel_size=3, activation='relu', padding='same')(conv9)
+        conv9 = Conv2D(3, kernel_size=3, activation='relu', padding='same')(conv9)
+        # Single-channel sigmoid head: binary segmentation mask.
+        conv10 = Conv2D(1, kernel_size=1, activation='sigmoid')(conv9)
+
+        return Model(outputs=conv10, inputs=unet_input), 'unet_model_yuanqing'
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e5c11d50d7849f326b4c3107dde831783c50d4cb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,151 @@
+absl-py==2.3.1
+aiofiles==24.1.0
+annotated-types==0.7.0
+anyio==4.10.0
+asttokens==3.0.0
+astunparse==1.6.3
+attrs==25.3.0
+beautifulsoup4==4.13.4
+blinker==1.9.0
+Brotli==1.1.0
+cachetools==5.5.2
+certifi==2025.8.3
+charset-normalizer==3.4.2
+click==8.2.1
+colorama==0.4.6
+comm==0.2.3
+ConfigArgParse==1.7.1
+contourpy==1.3.2
+cycler==0.12.1
+dash==3.2.0
+decorator==5.2.1
+exceptiongroup==1.3.0
+executing==2.2.0
+fastapi==0.116.1
+fastjsonschema==2.21.1
+ffmpy==0.6.1
+filelock==3.18.0
+Flask==3.1.1
+flatbuffers==25.2.10
+fonttools==4.59.0
+fsspec==2025.7.0
+gast==0.4.0
+google-generativeai
+gdown==5.2.0
+google-auth==2.40.3
+google-auth-oauthlib==0.4.6
+google-pasta==0.2.0
+gradio==5.41.1
+gradio_client==1.11.0
+gradio_imageslider==0.0.20
+groovy==0.1.2
+grpcio==1.74.0
+h11==0.16.0
+h5py==3.14.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.34.3
+idna==3.10
+imageio==2.37.0
+importlib_metadata==8.7.0
+ipython==8.37.0
+ipywidgets==8.1.7
+itsdangerous==2.2.0
+jedi==0.19.2
+Jinja2==3.1.6
+jsonschema==4.25.0
+jsonschema-specifications==2025.4.1
+jupyter_core==5.8.1
+jupyterlab_widgets==3.0.15
+keras==2.10.0
+Keras-Preprocessing==1.1.2
+kiwisolver==1.4.8
+lazy_loader==0.4
+libclang==18.1.1
+Markdown==3.8.2
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.10.5
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mpmath==1.3.0
+narwhals==2.0.1
+nbformat==5.10.4
+nest-asyncio==1.6.0
+networkx==3.4.2
+numpy==1.26.4
+oauthlib==3.3.1
+open3d==0.19.0
+opencv-python==4.11.0.86
+opt_einsum==3.4.0
+orjson==3.11.1
+packaging==25.0
+pandas==2.3.1
+parso==0.8.4
+pillow==11.3.0
+platformdirs==4.3.8
+plotly==6.2.0
+prompt_toolkit==3.0.51
+protobuf==3.19.6
+psutil==5.9.8
+pure_eval==0.2.3
+pyasn1==0.6.1
+pyasn1_modules==0.4.2
+pydantic==2.10.6
+pydantic_core==2.27.2
+pydub==0.25.1
+Pygments==2.19.2
+pyparsing==3.2.3
+PySocks==1.7.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+referencing==0.36.2
+requests==2.32.4
+requests-oauthlib==2.0.0
+retrying==1.4.2
+rich==14.1.0
+rpds-py==0.27.0
+rsa==4.9.1
+ruff==0.12.7
+safehttpx==0.1.6
+scikit-image==0.25.2
+scipy==1.15.3
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.7
+spaces==0.39.0
+stack-data==0.6.3
+starlette==0.47.2
+sympy==1.14.0
+tensorboard==2.10.1
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+tensorflow==2.10.1
+tensorflow-estimator==2.10.0
+tensorflow-hub==0.16.1
+tensorflow-io-gcs-filesystem==0.31.0
+termcolor==3.1.0
+tf-keras==2.15.0
+tifffile==2025.5.10
+tomlkit==0.13.3
+torch==2.8.0
+torchvision==0.23.0
+tqdm==4.67.1
+traitlets==5.14.3
+typer==0.16.0
+typing-inspection==0.4.1
+typing_extensions==4.14.1
+tzdata==2025.2
+urllib3==2.5.0
+uvicorn==0.35.0
+wcwidth==0.2.13
+websockets==15.0.1
+Werkzeug==3.1.3
+widgetsnbextension==4.0.14
+wrapt==1.17.2
+zipp==3.23.0
+transformers
\ No newline at end of file
diff --git a/temp_files/Final_workig_cpu.txt b/temp_files/Final_workig_cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..186da1027d63eb08523dbd57c115344994a0bb7a
--- /dev/null
+++ b/temp_files/Final_workig_cpu.txt
@@ -0,0 +1,1000 @@
+import glob
+import gradio as gr
+import matplotlib
+import numpy as np
+from PIL import Image
+import torch
+import tempfile
+from gradio_imageslider import ImageSlider
+import plotly.graph_objects as go
+import plotly.express as px
+import open3d as o3d
+from depth_anything_v2.dpt import DepthAnythingV2
+import os
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from tensorflow.keras.preprocessing import image as keras_image
+import base64
+from io import BytesIO
+import gdown
+import spaces
+import cv2
+
+# Import actual segmentation model components
+from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
+from utils.learning.metrics import dice_coef, precision, recall
+from utils.io.data import normalize
+
# Define path and file ID for the Depth-Anything-V2 checkpoint.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
# Google Drive file id of the ViT-L depth checkpoint (~1.3 GB) — TODO confirm id still valid.
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download once; subsequent launches reuse the cached file.
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
# Purely informational; TF picks its own device regardless.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")

# --- Load Wound Classification Model and Class Labels ---
# keras_model.h5 / labels.txt are expected in the working directory.
wound_model = load_model("keras_model.h5")
with open("labels.txt", "r") as f:
    # labels.txt lines look like "<index> <name>"; keep only the name part.
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
+
+# --- Load Actual Wound Segmentation Model ---
class WoundSegmentationModel:
    """Wrapper around the Keras DeeplabV3 wound-segmentation network.

    Loads a trained .hdf5 checkpoint at construction time and exposes
    segment_wound() which returns a binary 224x224 mask.
    """

    def __init__(self):
        # Fixed input resolution expected by the trained network.
        self.input_dim_x = 224
        self.input_dim_y = 224
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the trained wound segmentation model.

        NOTE: this method shadows the imported keras `load_model` at the class
        level; the bare `load_model(...)` calls below still resolve to the
        keras function (module global), not to this method.
        """
        try:
            # Try to load the most recent model first.
            weight_file_name = '2025-08-07_16-25-27.hdf5'
            model_path = f'./training_history/{weight_file_name}'

            # custom_objects must name every non-standard layer/metric the
            # checkpoint was saved with, or deserialization fails.
            self.model = load_model(model_path,
                                    custom_objects={
                                        'recall': recall,
                                        'precision': precision,
                                        'dice_coef': dice_coef,
                                        'relu6': relu6,
                                        'DepthwiseConv2D': DepthwiseConv2D,
                                        'BilinearUpsampling': BilinearUpsampling
                                    })
            print(f"Segmentation model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading segmentation model: {e}")
            # Fallback to the older 2019 checkpoint (URL-encoded filename).
            try:
                weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
                model_path = f'./training_history/{weight_file_name}'

                self.model = load_model(model_path,
                                        custom_objects={
                                            'recall': recall,
                                            'precision': precision,
                                            'dice_coef': dice_coef,
                                            'relu6': relu6,
                                            'DepthwiseConv2D': DepthwiseConv2D,
                                            'BilinearUpsampling': BilinearUpsampling
                                        })
                print(f"Segmentation model loaded successfully from {model_path}")
            except Exception as e2:
                # Both checkpoints failed: leave model=None; segment_wound()
                # reports the error instead of raising.
                print(f"Error loading fallback segmentation model: {e2}")
                self.model = None

    def preprocess_image(self, image):
        """Preprocess the uploaded image for model input: RGB, 224x224, float [0,1], batched."""
        if image is None:
            return None

        # Convert to RGB if needed.
        # NOTE(review): cvtColor is applied to every 3-channel input, so an
        # already-RGB image gets its channels swapped — confirm callers pass BGR.
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize to model input size.
        image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))

        # Normalize the image to [0, 1].
        image = image.astype(np.float32) / 255.0

        # Add batch dimension -> (1, 224, 224, 3).
        image = np.expand_dims(image, axis=0)

        return image

    def postprocess_prediction(self, prediction):
        """Threshold the network output into a uint8 {0, 255} binary mask."""
        # Remove batch dimension.
        prediction = prediction[0]

        # Apply threshold to get binary mask.
        threshold = 0.5
        binary_mask = (prediction > threshold).astype(np.uint8) * 255

        return binary_mask

    def segment_wound(self, input_image):
        """Segment a wound image; returns (mask_or_None, status_message)."""
        if self.model is None:
            return None, "Error: Segmentation model not loaded. Please check the model files."

        if input_image is None:
            return None, "Please upload an image."

        try:
            # Preprocess the image.
            processed_image = self.preprocess_image(input_image)

            if processed_image is None:
                return None, "Error processing image."

            # Make prediction (verbose=0 keeps the server log clean).
            prediction = self.model.predict(processed_image, verbose=0)

            # Postprocess the prediction into a binary mask.
            segmented_mask = self.postprocess_prediction(prediction)

            return segmented_mask, "Segmentation completed successfully!"

        except Exception as e:
            # Surface the failure as a status string rather than crashing the UI.
            return None, f"Error during segmentation: {str(e)}"
+
# Initialize the segmentation model once at import time (loads .hdf5 weights).
segmentation_model = WoundSegmentationModel()

# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Per-encoder architecture hyperparameters for DepthAnythingV2.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'  # must match the checkpoint downloaded above
depth_model = DepthAnythingV2(**model_configs[encoder])
state_dict = torch.load(
    f'checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
# eval() disables dropout/batch-norm updates for inference.
depth_model = depth_model.to(map_device).eval()
+
+
# --- Custom CSS for unified dark theme ---
# Injected into gr.Blocks(css=...); element ids below (#img-display-*, #download)
# are referenced by elem_id arguments in the UI section.
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""
+
+# --- Wound Classification Functions ---
def preprocess_input(img):
    """Resize a PIL image to the classifier's 224x224 input and return a
    normalized float batch of shape (1, 224, 224, 3)."""
    resized = img.resize((224, 224))
    pixels = keras_image.img_to_array(resized) / 255.0
    return pixels[np.newaxis, ...]
+
def get_reasoning_from_gemini(img, prediction):
    """Return a short clinical explanation for *prediction*.

    The Gemini API call is currently stubbed out with canned text; *img* is
    accepted for interface compatibility but not used.
    """
    try:
        canned = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues.",
        }
        fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
        return canned.get(prediction, fallback)
    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound image and return two HTML strings: (prediction card, reasoning card).

    NOTE(review): the original HTML markup in this function was stripped during
    text extraction (leaving an unterminated string literal in the None-branch);
    the cards below are a minimal reconstruction — confirm against the deployed
    version.
    """
    if img is None:
        # Fix: the original `return "No image provided` was split across lines
        # (syntax error after markup stripping); keep the two-output contract.
        return "No image provided", ""

    img_array = preprocess_input(img)
    # [0] drops the batch dimension; predictions is a per-class probability vector.
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    # Get reasoning (Gemini call is currently stubbed with canned text).
    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    # Prediction Card
    predicted_card = f"""
    <div style="background-color:#1e1e1e; border-radius:10px; padding:16px; margin-bottom:10px;">
        <h3 style="color:#ffffff; margin:0 0 8px 0;">Predicted Wound Type</h3>
        <p style="color:#4CAF50; font-size:1.4rem; font-weight:bold; margin:0;">{pred_class}</p>
    </div>
    """

    # Reasoning Card
    reasoning_card = f"""
    <div style="background-color:#1e1e1e; border-radius:10px; padding:16px;">
        <h3 style="color:#ffffff; margin:0 0 8px 0;">Reasoning</h3>
        <p style="color:#dddddd; margin:0;">{reasoning_text}</p>
    </div>
    """

    return predicted_card, reasoning_card
+
+# --- Wound Severity Estimation Functions ---
@spaces.GPU
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Compute area statistics for different depth regions.

    Returns a dict of areas (cm^2) for shallow/moderate/deep bands plus the
    deep-area ratio and the maximum depth inside the wound mask.

    NOTE(review): the 3/6 thresholds read as millimetres, but the caller
    (run_auto_severity_analysis path) passes a depth map normalized to 0-255,
    not metric depth — confirm the intended units.
    """
    # mm/pixel -> cm^2 per pixel.
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Extract only wound region (mask is a 0/255 binary image).
    wound_mask = (mask > 127)
    wound_depths = depth_map[wound_mask]
    total_area = np.sum(wound_mask) * pixel_area_cm2

    # Categorize depth regions into three disjoint bands.
    shallow = wound_depths < 3
    moderate = (wound_depths >= 3) & (wound_depths < 6)
    deep = wound_depths >= 6

    shallow_area = np.sum(shallow) * pixel_area_cm2
    moderate_area = np.sum(moderate) * pixel_area_cm2
    deep_area = np.sum(deep) * pixel_area_cm2

    # Guard against an empty mask to avoid division by zero.
    deep_ratio = deep_area / total_area if total_area > 0 else 0

    return {
        'total_area_cm2': total_area,
        'shallow_area_cm2': shallow_area,
        'moderate_area_cm2': moderate_area,
        'deep_area_cm2': deep_area,
        'deep_ratio': deep_ratio,
        'max_depth': np.max(wound_depths) if len(wound_depths) > 0 else 0
    }
+
def classify_wound_severity_by_area(depth_stats):
    """Map the statistics from compute_depth_area_statistics to a severity
    label: "Unknown", "Severe", "Moderate", or "Mild"."""
    total_area = depth_stats['total_area_cm2']
    deep_area = depth_stats['deep_area_cm2']
    moderate_area = depth_stats['moderate_area_cm2']

    # No wound pixels at all -> nothing to classify.
    if total_area == 0:
        return "Unknown"

    # Guard-clause form of the severity rules: absolute area OR coverage ratio.
    if deep_area > 2 or (deep_area / total_area) > 0.3:
        return "Severe"
    if moderate_area > 1.5 or (moderate_area / total_area) > 0.4:
        return "Moderate"
    return "Mild"
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Analyze wound severity from depth map and wound mask.

    Returns an HTML report string (or an error message prefixed with ❌).
    NOTE(review): the report template below lost its HTML tags during text
    extraction; the f-string expressions inside it are intact.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."

    # Convert wound mask to grayscale if it arrived as an RGB image.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Ensure depth map and mask have same dimensions before indexing.
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize mask to match depth map (PIL resize takes (width, height)).
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    # Compute statistics and derive the severity label.
    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Create severity report with color coding (gray for unknown labels).
    severity_color = {
        "Mild": "#4CAF50",      # Green
        "Moderate": "#FF9800",  # Orange
        "Severe": "#F44336"     # Red
    }.get(severity, "#9E9E9E")  # Gray for unknown

    report = f"""


 🩹 Wound Severity Analysis




 📏 Area Measurements


🟢 Total Area: {stats['total_area_cm2']:.2f} cm²

🟩 Shallow (0-3mm): {stats['shallow_area_cm2']:.2f} cm²

🟨 Moderate (3-6mm): {stats['moderate_area_cm2']:.2f} cm²

🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²




 📊 Depth Analysis


🔥 Deep Coverage: {stats['deep_ratio']*100:.1f}%

📏 Max Depth: {stats['max_depth']:.1f} mm

⚡ Pixel Spacing: {pixel_spacing_mm} mm




 🎯 Predicted Severity: {severity}


 {get_severity_description(severity)}


    """

    return report
+
def get_severity_description(severity):
    """Return a one-sentence human-readable explanation for a severity label."""
    if severity == "Mild":
        return "Superficial wound with minimal tissue damage. Usually heals well with basic care."
    if severity == "Moderate":
        return "Moderate tissue involvement requiring careful monitoring and proper treatment."
    if severity == "Severe":
        return "Deep tissue damage requiring immediate medical attention and specialized care."
    if severity == "Unknown":
        return "Unable to determine severity due to insufficient data."
    # Any unrecognized label falls through to a generic message.
    return "Severity assessment unavailable."
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Build a uint8 test mask: 255 inside a circle of *radius* pixels around
    *center* (defaults to the image center, given as (x, y)), 0 elsewhere."""
    height, width = image_shape[0], image_shape[1]
    if center is None:
        center = (width // 2, height // 2)

    rows, cols = np.ogrid[:height, :width]
    # Compare squared distances — identical result, no sqrt needed.
    inside = (cols - center[0]) ** 2 + (rows - center[1]) ** 2 <= radius ** 2
    return np.where(inside, 255, 0).astype(np.uint8)
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Create a more realistic wound mask with irregular shapes.

    method: 'elliptical' (noisy ellipse) or 'irregular' (circle with three
    lobes). Output is a uint8 {0, 255} mask smoothed with a closing operation.
    NOTE: the 'elliptical' branch uses np.random, so output is nondeterministic.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)

    if method == 'elliptical':
        # Create elliptical wound mask centered in the image.
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        # Standard ellipse inequality (<= 1 is inside).
        ellipse = ((x - center[0])**2 / (radius_x**2) +
                   (y - center[1])**2 / (radius_y**2)) <= 1

        # Add some noise and irregularity (~20% of pixels OR-ed in at random).
        noise = np.random.random((h, w)) > 0.8
        mask = (ellipse | noise).astype(np.uint8) * 255

    elif method == 'irregular':
        # Create irregular wound mask: a base circle plus three lobes.
        center = (w // 2, h // 2)
        radius = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius

        # Add irregular extensions at 0°, 120°, 240° around the base circle.
        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3

            ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
            extensions = extensions | ext_circle

        mask = (base_circle | extensions).astype(np.uint8) * 255

    # Apply morphological closing to smooth the mask edges
    # (for unrecognized methods this runs on the all-zero mask).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask
+
+# --- Depth Estimation Functions ---
@spaces.GPU
def predict_depth(image):
    """Run DepthAnythingV2 on an image and return the raw depth/disparity map."""
    return depth_model.infer_image(image)
+
def calculate_max_points(image):
    """Upper bound for the 3D point count: three per pixel, clamped to
    [1000, 300000]. Returns 10000 when no image has been uploaded yet."""
    if image is None:
        return 10000  # default before any image is available

    height, width = image.shape[:2]
    proposed = 3 * height * width
    # Clamp into the supported range.
    return min(max(proposed, 1000), 300000)
+
def update_slider_on_image_upload(image):
    """Rebuild the 3D-points slider so its maximum tracks the uploaded image size."""
    limit = calculate_max_points(image)
    initial = min(10000, limit // 10)  # default to ~10% of the cap
    return gr.Slider(
        minimum=1000,
        maximum=limit,
        value=initial,
        step=1000,
        label=f"Number of 3D points (max: {limit:,})",
    )
+
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Create a colored Open3D point cloud from a depth map using pinhole
    camera intrinsics, subsampled to roughly max_points (the 0.5 factor
    deliberately halves the stride, so up to ~4x max_points may survive)."""
    h, w = depth_map.shape

    # Use smaller step for higher detail (reduced downsampling).
    step = max(1, int(np.sqrt(h * w / max_points) * 0.5))

    # Create mesh grid of sampled pixel coordinates.
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Convert to camera coordinates (principal point at image center).
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y

    # Get depth values at the same sampled locations.
    depth_values = depth_map[::step, ::step]

    # Back-project: (x_cam * depth, y_cam * depth, depth).
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten to an (N, 3) point array.
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)

    # Get corresponding image colors, normalized to [0, 1] for Open3D.
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0

    # Create Open3D point cloud.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    return pcd
+
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Convert point cloud to a mesh using Poisson reconstruction with very high detail.

    Mutates *pcd* in place (normal estimation) and returns the reconstructed
    TriangleMesh; depth=12 yields a fine octree but is slow on large clouds.
    """
    # Estimate and orient normals with high precision (Poisson requires normals).
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # Create surface mesh with maximum detail (depth=12 for very high resolution).
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)

    # Return mesh without filtering low-density vertices (keeps full coverage,
    # at the cost of some spurious geometry at the cloud boundary).
    return mesh
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D Plotly scatter of the depth map using a
    pinhole-camera back-projection, colored by the source image pixels."""
    h, w = depth_map.shape

    # Downsample so roughly max_points survive (performance).
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Create mesh grid of sampled pixel coordinates.
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Convert to camera coordinates (normalized by focal length).
    focal_length = 470.4  # default focal length in pixels — TODO confirm calibration
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    # Get depth values at the same sampled locations.
    depth_values = depth_map[::step, ::step]

    # Back-project: (x_cam * depth, y_cam * depth, depth).
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten arrays for the scatter trace.
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Get corresponding image colors.
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    # NOTE(review): the hovertemplate literal was corrupted during text
    # extraction (tags stripped, leaving an unterminated string); reconstructed
    # below with the conventional Plotly "<br>" line breaks and
    # "<extra></extra>" to suppress the secondary hover box.
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Full depth pipeline for the UI: depth inference, colored preview,
    downloadable raw/gray depth PNGs, a reconstructed .ply mesh and a Plotly
    3D figure. Returns the five outputs in UI order."""
    original_image = image.copy()

    h, w = image.shape[:2]

    # Predict depth using the model (channel flip: gr.Image gives RGB,
    # model expects BGR — see comment in original).
    depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed

    # Save raw 16-bit depth for download.
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to 0-255 and colorize for display.
    norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    norm_depth = norm_depth.astype(np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Create point cloud from the *normalized* depth (relative scale only).
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)

    # Reconstruct mesh from point cloud (Poisson, slow for large clouds).
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    # Save mesh with faces as .ply for download.
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Create enhanced 3D scatter plot visualization.
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
+
+# --- Actual Wound Segmentation Functions ---
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate a binary wound mask from an image.

    Args:
        image: Input image (numpy array)
        method: Segmentation method (currently only 'deep_learning' supported)

    Returns:
        mask: Binary wound mask, or None when no image is given / segmentation fails
    """
    if image is None:
        return None

    # Every supported method value currently routes to the deep-learning
    # segmenter (the original fell back to it for unknown methods too).
    mask, _ = segmentation_model.segment_wound(image)
    return mask
+
def post_process_wound_mask(mask, min_area=100):
    """Post-process the wound mask: morphological cleanup, then drop connected
    components smaller than *min_area* pixels. Returns a uint8 {0, 255} mask
    or None when given None."""
    if mask is None:
        return None

    # Convert to uint8 if needed (findContours requires 8-bit input).
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Close small holes first, then open to remove speckle noise
    # (order matters: close-then-open preserves wound interiors).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    # Remove small objects: redraw only contours above the area threshold.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)

    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area:
            cv2.fillPoly(mask_clean, [contour], 255)

    # Final closing to fill any remaining holes.
    mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)

    return mask_clean
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Severity analysis with an automatically generated wound mask.

    Runs the segmentation model, cleans the resulting mask, and delegates to
    analyze_wound_severity(). Returns an HTML report or a ❌ error string.
    """
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    # Generate the wound mask via the deep-learning segmenter.
    auto_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if auto_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    # Clean the mask; min_area=500 drops tiny spurious detections.
    processed_mask = post_process_wound_mask(auto_mask, min_area=500)
    if processed_mask is None or not np.any(processed_mask > 0):
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm)
+
+# --- Main Gradio Interface ---
# Three-tab UI: classification -> depth estimation -> severity analysis,
# with gr.State carrying the image between tabs.
with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    # NOTE(review): the string below was corrupted during text extraction
    # (its wrapping HTML tags were stripped, splitting the literal).
    gr.HTML("Wound Analysis & Depth Estimation System
")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Shared image state (set in tab 1, consumed in tab 2).
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)

                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            # Button to pass image to depth estimation
            with gr.Row():
                pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")

            # Run classification on every upload/change.
            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])

            # Store image when uploaded, so tab 2 can pick it up via shared_image.
            wound_image_input.change(
                fn=lambda img: img,
                inputs=[wound_image_input],
                outputs=[shared_image]
            )

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")

            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            # 3D Visualization
            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Store depth map for severity analysis (filled by depth_submit handler).
            depth_map_state = gr.State()

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. 🩹 Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")

            with gr.Row():
                auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
                manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            with gr.Row():
                # Load depth map from previous tab
                load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")

            gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
+
    # Update slider cap when an image is uploaded in tab 2.
    depth_input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[depth_input_image],
        outputs=[points_slider]
    )

    # Wrapper around on_depth_submit that also stashes a normalized depth map
    # in depth_map_state for the severity tab.
    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        depth_map = None
        if image is not None:
            # NOTE(review): this re-runs the depth model, duplicating the
            # inference already done inside on_depth_submit — consider
            # returning norm_depth from on_depth_submit instead.
            depth = predict_depth(image[:, :, ::-1])  # RGB to BGR if needed
            # Normalize to 0-255 for severity analysis.
            norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
            depth_map = norm_depth.astype(np.uint8)
        return results + [depth_map]

    depth_submit.click(on_depth_submit_with_state,
                       inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                       outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])

    # Load depth map into the severity tab and auto-generate a wound mask.
    def load_depth_to_severity(depth_map, original_image):
        if depth_map is None:
            return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."

        # Auto-generate wound mask using segmentation model.
        if original_image is not None:
            auto_mask, _ = segmentation_model.segment_wound(original_image)
            if auto_mask is not None:
                # Post-process the mask (min_area=500 drops tiny detections).
                processed_mask = post_process_wound_mask(auto_mask, min_area=500)
                if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                    return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
                else:
                    return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image."
            else:
                return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image."
        else:
            return depth_map, original_image, None, "✅ Depth map loaded successfully!"

    # NOTE(review): the fourth output is an anonymous gr.HTML() created inline,
    # so the status string is never shown anywhere visible — confirm intent.
    load_depth_btn.click(
        fn=load_depth_to_severity,
        inputs=[depth_map_state, depth_input_image],
        outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
    )

    # Automatic severity analysis: regenerate the mask, then analyze.
    def run_auto_severity_analysis(image, depth_map, pixel_spacing):
        if depth_map is None:
            return "❌ Please load depth map from Tab 2 first."

        # Generate automatic wound mask using the actual model.
        auto_mask = create_automatic_wound_mask(image, method='deep_learning')

        if auto_mask is None:
            return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

        # Post-process the mask with fixed minimum area.
        processed_mask = post_process_wound_mask(auto_mask, min_area=500)

        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

        # Analyze severity using the automatic mask.
        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)

    # Manual severity analysis: use whatever mask is currently displayed.
    def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
        if depth_map is None:
            return "❌ Please load depth map from Tab 2 first."
        if wound_mask is None:
            return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."

        return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)

    # Connect event handlers for both analysis buttons.
    auto_severity_button.click(
        fn=run_auto_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider],
        outputs=[severity_output]
    )

    manual_severity_button.click(
        fn=run_manual_severity_analysis,
        inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider],
        outputs=[severity_output]
    )
+
+
+
    # Auto-generate a mask whenever an image is uploaded to the severity tab.
    def auto_generate_mask_on_image_upload(image):
        if image is None:
            return None, "❌ No image uploaded."

        # Generate automatic wound mask using segmentation model.
        auto_mask, _ = segmentation_model.segment_wound(image)
        if auto_mask is not None:
            # Post-process the mask (min_area=500 drops tiny detections).
            processed_mask = post_process_wound_mask(auto_mask, min_area=500)
            if processed_mask is not None and np.sum(processed_mask > 0) > 0:
                return processed_mask, "✅ Wound mask auto-generated using deep learning model!"
            else:
                return None, "✅ Image uploaded but no wound detected. Try uploading a different image."
        else:
            return None, "✅ Image uploaded but segmentation failed. Try uploading a different image."

    # Load the image stored by tab 1 (PIL) into tab 2 (numpy).
    def load_shared_image(shared_img):
        if shared_img is None:
            return gr.Image(), "❌ No image available from classification tab"

        # Convert PIL image to numpy array for depth estimation.
        if hasattr(shared_img, 'convert'):
            # It's a PIL image, convert to numpy.
            img_array = np.array(shared_img)
            return img_array, "✅ Image loaded from classification tab"
        else:
            # Already numpy array.
            return shared_img, "✅ Image loaded from classification tab"

    # NOTE(review): both wirings below output the status string into an
    # anonymous inline gr.HTML(), so it is never visible — confirm intent.
    severity_input_image.change(
        fn=auto_generate_mask_on_image_upload,
        inputs=[severity_input_image],
        outputs=[wound_mask_input, gr.HTML()]
    )

    load_shared_btn.click(
        fn=load_shared_image,
        inputs=[shared_image],
        outputs=[depth_input_image, gr.HTML()]
    )

    # Pass image to depth tab: only reports status; the actual transfer
    # happens via shared_image + the "Load Image from Classification" button.
    def pass_image_to_depth(img):
        if img is None:
            return "❌ No image uploaded in classification tab"
        return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"

    pass_to_depth_btn.click(
        fn=pass_image_to_depth,
        inputs=[shared_image],
        outputs=[pass_status]
    )
+
if __name__ == '__main__':
    # Listen on all interfaces (container-friendly) at port 7860 with a
    # public Gradio share link; queue() serializes GPU-bound requests.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
\ No newline at end of file
diff --git a/temp_files/README.md b/temp_files/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8b38f5541ad8472e1343f988edda78b9491014af
--- /dev/null
+++ b/temp_files/README.md
@@ -0,0 +1,12 @@
+---
+title: Wound Analysis V22
+emoji: 📉
+colorFrom: purple
+colorTo: green
+sdk: gradio
+sdk_version: 5.41.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/temp_files/fw2.txt b/temp_files/fw2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a9835bcb758f150c00ec9dcf322a5e92f0923c81
--- /dev/null
+++ b/temp_files/fw2.txt
@@ -0,0 +1,1175 @@
+import glob
+import gradio as gr
+import matplotlib
+import numpy as np
+from PIL import Image
+import torch
+import tempfile
+from gradio_imageslider import ImageSlider
+import plotly.graph_objects as go
+import plotly.express as px
+import open3d as o3d
+from depth_anything_v2.dpt import DepthAnythingV2
+import os
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from tensorflow.keras.preprocessing import image as keras_image
+import base64
+from io import BytesIO
+import gdown
+import spaces
+import cv2
+
+# Import actual segmentation model components
+from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
+from utils.learning.metrics import dice_coef, precision, recall
+from utils.io.data import normalize
+
+# Define path and file ID
+checkpoint_dir = "checkpoints"
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
+gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"
+
+# Download the Depth-Anything-V2 ViT-L weights once; later runs reuse the
+# cached file. NOTE(review): no checksum is verified after download — confirm
+# integrity checking is not needed.
+if not os.path.exists(model_file):
+ print("Downloading model from Google Drive...")
+ gdown.download(gdrive_url, model_file, quiet=False)
+
+# --- TensorFlow: Check GPU Availability ---
+# Informational only: nothing below changes device placement based on this.
+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+ print("TensorFlow is using GPU")
+else:
+ print("TensorFlow is using CPU")
+
+# --- Load Wound Classification Model and Class Labels ---
+wound_model = load_model("keras_model.h5")
+# labels.txt lines look like "<index> <label>"; keep only the label part.
+with open("labels.txt", "r") as f:
+ class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
+
+# --- Load Actual Wound Segmentation Model ---
+class WoundSegmentationModel:
+ # Wraps the Keras wound-segmentation network: loads the newest checkpoint
+ # (with a 2019 fallback), and turns an input photo into a 224x224 binary
+ # wound mask via segment_wound().
+ def __init__(self):
+ self.input_dim_x = 224
+ self.input_dim_y = 224
+ self.model = None
+ self.load_model()
+
+ def load_model(self):
+ """Load the trained wound segmentation model"""
+ # Note: inside this method, the bare name `load_model` still resolves to
+ # the imported keras function (module global), not to this method.
+ try:
+ # Try to load the most recent model
+ weight_file_name = '2025-08-07_16-25-27.hdf5'
+ model_path = f'./training_history/{weight_file_name}'
+
+ self.model = load_model(model_path,
+ custom_objects={
+ 'recall': recall,
+ 'precision': precision,
+ 'dice_coef': dice_coef,
+ 'relu6': relu6,
+ 'DepthwiseConv2D': DepthwiseConv2D,
+ 'BilinearUpsampling': BilinearUpsampling
+ })
+ print(f"Segmentation model loaded successfully from {model_path}")
+ except Exception as e:
+ print(f"Error loading segmentation model: {e}")
+ # Fallback to the older model
+ try:
+ weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
+ model_path = f'./training_history/{weight_file_name}'
+
+ self.model = load_model(model_path,
+ custom_objects={
+ 'recall': recall,
+ 'precision': precision,
+ 'dice_coef': dice_coef,
+ 'relu6': relu6,
+ 'DepthwiseConv2D': DepthwiseConv2D,
+ 'BilinearUpsampling': BilinearUpsampling
+ })
+ print(f"Segmentation model loaded successfully from {model_path}")
+ except Exception as e2:
+ print(f"Error loading fallback segmentation model: {e2}")
+ # Leave model as None; segment_wound reports the failure to the UI.
+ self.model = None
+
+ def preprocess_image(self, image):
+ """Preprocess the uploaded image for model input"""
+ # Returns a (1, 224, 224, 3) float32 batch in [0, 1], or None.
+ if image is None:
+ return None
+
+ # Convert to RGB if needed
+ if len(image.shape) == 3 and image.shape[2] == 3:
+ # Convert BGR to RGB if needed
+ # NOTE(review): this swap runs for EVERY 3-channel image; Gradio
+ # already supplies RGB, so this may invert channels — confirm the
+ # expected input colour order.
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ # Resize to model input size
+ image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
+
+ # Normalize the image
+ image = image.astype(np.float32) / 255.0
+
+ # Add batch dimension
+ image = np.expand_dims(image, axis=0)
+
+ return image
+
+ def postprocess_prediction(self, prediction):
+ """Postprocess the model prediction"""
+ # Threshold the sigmoid output at 0.5 into a 0/255 uint8 mask.
+ # Remove batch dimension
+ prediction = prediction[0]
+
+ # Apply threshold to get binary mask
+ threshold = 0.5
+ binary_mask = (prediction > threshold).astype(np.uint8) * 255
+
+ return binary_mask
+
+ def segment_wound(self, input_image):
+ """Main function to segment wound from uploaded image"""
+ # Returns (mask, status-message); mask is None on any failure.
+ if self.model is None:
+ return None, "Error: Segmentation model not loaded. Please check the model files."
+
+ if input_image is None:
+ return None, "Please upload an image."
+
+ try:
+ # Preprocess the image
+ processed_image = self.preprocess_image(input_image)
+
+ if processed_image is None:
+ return None, "Error processing image."
+
+ # Make prediction
+ prediction = self.model.predict(processed_image, verbose=0)
+
+ # Postprocess the prediction
+ segmented_mask = self.postprocess_prediction(prediction)
+
+ return segmented_mask, "Segmentation completed successfully!"
+
+ except Exception as e:
+ return None, f"Error during segmentation: {str(e)}"
+
+# Initialize the segmentation model
+segmentation_model = WoundSegmentationModel()
+
+# --- PyTorch: Set Device and Load Depth Model ---
+map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
+print(f"Using PyTorch device: {map_device}")
+
+# Per-encoder DPT head configuration for Depth-Anything-V2.
+model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+}
+encoder = 'vitl'
+depth_model = DepthAnythingV2(**model_configs[encoder])
+state_dict = torch.load(
+ f'checkpoints/depth_anything_v2_{encoder}.pth',
+ map_location=map_device
+)
+depth_model.load_state_dict(state_dict)
+# eval() disables dropout/batch-norm updates; inference only.
+depth_model = depth_model.to(map_device).eval()
+
+
+# --- Custom CSS for unified dark theme ---
+# Injected via gr.Blocks(css=css): dark background, rounded buttons, wrapped
+# HTML cards, capped image-panel heights, and tab colouring.
+css = """
+.gradio-container {
+ font-family: 'Segoe UI', sans-serif;
+ background-color: #121212;
+ color: #ffffff;
+ padding: 20px;
+}
+.gr-button {
+ background-color: #2c3e50;
+ color: white;
+ border-radius: 10px;
+}
+.gr-button:hover {
+ background-color: #34495e;
+}
+.gr-html, .gr-html div {
+ white-space: normal !important;
+ overflow: visible !important;
+ text-overflow: unset !important;
+ word-break: break-word !important;
+}
+#img-display-container {
+ max-height: 100vh;
+}
+#img-display-input {
+ max-height: 80vh;
+}
+#img-display-output {
+ max-height: 80vh;
+}
+#download {
+ height: 62px;
+}
+h1 {
+ text-align: center;
+ font-size: 3rem;
+ font-weight: bold;
+ margin: 2rem 0;
+ color: #ffffff;
+}
+h2 {
+ color: #ffffff;
+ text-align: center;
+ margin: 1rem 0;
+}
+.gr-tabs {
+ background-color: #1e1e1e;
+ border-radius: 10px;
+ padding: 10px;
+}
+.gr-tab-nav {
+ background-color: #2c3e50;
+ border-radius: 8px;
+}
+.gr-tab-nav button {
+ color: #ffffff !important;
+}
+.gr-tab-nav button.selected {
+ background-color: #34495e !important;
+}
+"""
+
# --- Wound Classification Functions ---
def preprocess_input(img):
    """Resize a PIL image to 224x224, scale pixels to [0, 1], and add a batch axis."""
    resized = img.resize((224, 224))
    pixels = keras_image.img_to_array(resized) / 255.0
    return np.expand_dims(pixels, axis=0)
+
def get_reasoning_from_gemini(img, prediction):
    """Return a canned clinical explanation for the predicted wound class.

    The Gemini call is intentionally stubbed out; `img` is accepted for API
    compatibility but unused. Unknown classes get a generic referral message.
    """
    try:
        canned = dict([
            ("Abrasion", "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury."),
            ("Burn", "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure."),
            ("Laceration", "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma."),
            ("Puncture", "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects."),
            ("Ulcer", "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."),
        ])
        fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
        return canned.get(prediction, fallback)
    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound photo and return (prediction_html, reasoning_html).

    Runs the Keras classifier on a preprocessed copy of `img`, picks the
    top-probability class, and wraps the label plus a canned explanation in
    HTML cards for the UI.
    """
    if img is None:
        # FIX: this literal was split across two lines in the source (stripped
        # markup left it unterminated) and did not parse; restored as one
        # valid string.
        return "No image provided", ""

    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    # Get reasoning (currently a static lookup; see get_reasoning_from_gemini)
    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    # Prediction Card
    # NOTE(review): the surrounding HTML markup of these cards was stripped
    # from the source; only the text content survives.
    predicted_card = f"""


 Predicted Wound Type


 {pred_class}


    """

    # Reasoning Card
    reasoning_card = f"""


 Reasoning


 {reasoning_text}


    """

    return predicted_card, reasoning_card
+
# --- Enhanced Wound Severity Estimation Functions ---
@spaces.GPU
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Enhanced depth analysis with calibration to millimetres.

    Depth bands used throughout:
      - Superficial: 0-2mm (epidermis only)
      - Partial thickness: 2-4mm (epidermis + partial dermis)
      - Full thickness: 4-6mm (epidermis + full dermis)
      - Deep: >6mm (involving subcutaneous tissue)

    Args:
        depth_map: 2D relative depth map aligned with `mask`.
        mask: wound mask; values > 127 count as wound.
        pixel_spacing_mm: physical edge length of one pixel, in mm.
        depth_calibration_mm: assumed real depth spanned by the map's range.

    Returns:
        dict of per-band areas (cm²), depth stats (mm), volume (cm³),
        percentiles, and quality metrics. All keys are present even when the
        wound region is empty.
    """
    pixel_spacing_mm = float(pixel_spacing_mm)

    # Pixel area in cm² (mm -> cm before squaring)
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Binary wound mask, lightly closed to remove pinholes
    wound_mask = (mask > 127).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)

    # Raw (uncalibrated) depths are only needed for the emptiness check;
    # all statistics below use the calibrated values.
    wound_depths = depth_map[wound_mask > 0]

    if len(wound_depths) == 0:
        # FIX: the quality keys are now included here too — the non-empty
        # path returns them and the downstream report indexes them
        # unconditionally, so the old dict raised KeyError for empty wounds.
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0},
            'analysis_quality': 'Low',
            'depth_consistency': 'Low',
            'wound_pixel_count': 0
        }

    # Rescale relative depth values to millimetres
    calibrated_depth_map = calibrate_depth_map(depth_map, reference_depth_mm=depth_calibration_mm)
    wound_depths_mm = calibrated_depth_map[wound_mask > 0]

    # Medical depth classification per pixel
    superficial_mask = wound_depths_mm < 2.0
    partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
    full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
    deep_mask = wound_depths_mm >= 6.0

    # Areas per depth band
    total_pixels = np.sum(wound_mask > 0)
    total_area_cm2 = total_pixels * pixel_area_cm2

    superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
    partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
    full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
    deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2

    # Depth statistics over the wound region
    mean_depth_mm = np.mean(wound_depths_mm)
    max_depth_mm = np.max(wound_depths_mm)
    depth_std_mm = np.std(wound_depths_mm)

    depth_percentiles = {
        '25': np.percentile(wound_depths_mm, 25),
        '50': np.percentile(wound_depths_mm, 50),
        '75': np.percentile(wound_depths_mm, 75)
    }

    # Approximate volume = area * mean depth (mm -> cm)
    wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)

    # Fraction of wound area involving deep tissue
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0

    # Heuristic quality metrics: more samples -> higher quality,
    # lower depth spread -> more consistent.
    wound_pixel_count = len(wound_depths_mm)
    analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"
    depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"

    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_depth_mm,
        'max_depth_mm': max_depth_mm,
        'depth_std_mm': depth_std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': depth_percentiles,
        'analysis_quality': analysis_quality,
        'depth_consistency': depth_consistency,
        'wound_pixel_count': wound_pixel_count
    }
+
def classify_wound_severity_by_enhanced_metrics(depth_stats):
    """
    Classify wound severity from the metrics dict produced by
    compute_enhanced_depth_statistics.

    Five weighted criteria — max depth, mean depth, deep-tissue ratio,
    total area, volume — are summed into a score mapped to one of
    "Superficial", "Mild", "Moderate", "Severe", "Very Severe";
    "Unknown" is returned when no wound area was measured.
    """
    if depth_stats['total_area_cm2'] == 0:
        return "Unknown"

    # Extract the metrics the score actually uses (the unused deep-area and
    # full-thickness-area extractions were removed).
    total_area = depth_stats['total_area_cm2']
    mean_depth = depth_stats['mean_depth_mm']
    max_depth = depth_stats['max_depth_mm']
    wound_volume = depth_stats['wound_volume_cm3']
    deep_ratio = depth_stats['deep_ratio']

    severity_score = 0

    # Criterion 1: maximum depth (mm)
    if max_depth >= 10.0:
        severity_score += 3  # Very severe
    elif max_depth >= 6.0:
        severity_score += 2  # Severe
    elif max_depth >= 4.0:
        severity_score += 1  # Moderate

    # Criterion 2: mean depth (mm)
    if mean_depth >= 5.0:
        severity_score += 2
    elif mean_depth >= 3.0:
        severity_score += 1

    # Criterion 3: deep tissue involvement ratio
    if deep_ratio >= 0.5:
        severity_score += 3  # more than 50% deep tissue
    elif deep_ratio >= 0.25:
        severity_score += 2  # 25-50% deep tissue
    elif deep_ratio >= 0.1:
        severity_score += 1  # 10-25% deep tissue

    # Criterion 4: total wound area (cm²)
    if total_area >= 10.0:
        severity_score += 2  # large wound (>10 cm²)
    elif total_area >= 5.0:
        severity_score += 1  # medium wound (5-10 cm²)

    # Criterion 5: wound volume (cm³)
    if wound_volume >= 5.0:
        severity_score += 2  # high volume
    elif wound_volume >= 2.0:
        severity_score += 1  # medium volume

    # Map the aggregate score to a severity label
    if severity_score >= 8:
        return "Very Severe"
    elif severity_score >= 6:
        return "Severe"
    elif severity_score >= 4:
        return "Moderate"
    elif severity_score >= 2:
        return "Mild"
    else:
        return "Superficial"
+
+def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
+ """Enhanced wound severity analysis with medical-grade metrics"""
+ # Produces the HTML severity report shown in the UI; `image` is only used
+ # for the presence check, the metrics come from depth_map + wound_mask.
+ if image is None or depth_map is None or wound_mask is None:
+ return "❌ Please upload image, depth map, and wound mask."
+
+ # Convert wound mask to grayscale if needed
+ if len(wound_mask.shape) == 3:
+ wound_mask = np.mean(wound_mask, axis=2)
+
+ # Ensure depth map and mask have same dimensions
+ if depth_map.shape[:2] != wound_mask.shape[:2]:
+ # Resize mask to match depth map
+ from PIL import Image
+ mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
+ mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
+ wound_mask = np.array(mask_pil)
+
+ # Compute enhanced statistics
+ stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm)
+ severity = classify_wound_severity_by_enhanced_metrics(stats)
+
+ # Enhanced severity color coding
+ # NOTE(review): severity_color is computed but does not appear in the
+ # surviving report text below — the HTML markup that used it was stripped
+ # from this file; confirm against the original template.
+ severity_color = {
+ "Superficial": "#4CAF50", # Green
+ "Mild": "#8BC34A", # Light Green
+ "Moderate": "#FF9800", # Orange
+ "Severe": "#F44336", # Red
+ "Very Severe": "#9C27B0" # Purple
+ }.get(severity, "#9E9E9E") # Gray for unknown
+
+ # Create comprehensive medical report
+ # NOTE(review): the f-string below lost its HTML tags during extraction
+ # (some lines even lost their diff '+' prefix); it is preserved verbatim
+ # here — reconstruct the markup from the original template before reuse.
+ report = f"""
+
+
+ 🩹 Enhanced Wound Severity Analysis
+
+
+
+
+
+ 📏 Tissue Involvement Analysis
+
+
+
🟢 Superficial (0-2mm): {stats['superficial_area_cm2']:.2f} cm²
+
🟡 Partial Thickness (2-4mm): {stats['partial_thickness_area_cm2']:.2f} cm²
+
🟠 Full Thickness (4-6mm): {stats['full_thickness_area_cm2']:.2f} cm²
+
🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²
+
📊 Total Area: {stats['total_area_cm2']:.2f} cm²
+
+
+
+
+
+ 📊 Depth Statistics
+
+
+
📏 Mean Depth: {stats['mean_depth_mm']:.1f} mm
+
📐 Max Depth: {stats['max_depth_mm']:.1f} mm
+
📊 Depth Std Dev: {stats['depth_std_mm']:.1f} mm
+
📦 Wound Volume: {stats['wound_volume_cm3']:.2f} cm³
+
🔥 Deep Tissue Ratio: {stats['deep_ratio']*100:.1f}%
+
+
+
+
+
+
+ 📈 Depth Percentiles & Quality Metrics
+
+
+
+
📊 25th Percentile: {stats['depth_percentiles']['25']:.1f} mm
+
📊 Median (50th): {stats['depth_percentiles']['50']:.1f} mm
+
📊 75th Percentile: {stats['depth_percentiles']['75']:.1f} mm
+
+
+
🔍 Analysis Quality: {stats['analysis_quality']}
+
📏 Depth Consistency: {stats['depth_consistency']}
+
📊 Data Points: {stats['wound_pixel_count']:,}
+
+
+
+
+
+
+ 🎯 Medical Severity Assessment: {severity}
+
+
+ {get_enhanced_severity_description(severity)}
+
+
+
+ """
+
+ return report
+
def calibrate_depth_map(depth_map, reference_depth_mm=10.0):
    """Rescale a relative depth map so its full range spans [0, reference_depth_mm] mm.

    The map's minimum maps to 0 and its maximum to reference_depth_mm.
    The input is returned unchanged when it is None or perfectly flat
    (no range to normalise).
    """
    if depth_map is None:
        return depth_map

    lo = np.min(depth_map)
    hi = np.max(depth_map)
    if hi == lo:
        # Flat map: nothing to scale.
        return depth_map

    # Min-max normalise, then stretch to the reference depth.
    return (depth_map - lo) / (hi - lo) * reference_depth_mm
+
def get_enhanced_severity_description(severity):
    """Map a severity label to its one-sentence clinical summary."""
    for label, text in (
        ("Superficial", "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care."),
        ("Mild", "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care."),
        ("Moderate", "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques."),
        ("Severe", "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention."),
        ("Very Severe", "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care."),
        ("Unknown", "Unable to determine severity due to insufficient data or poor image quality."),
    ):
        if label == severity:
            return text
    # Unrecognised label: generic fallback.
    return "Severity assessment unavailable."
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Return a uint8 test mask: 255 inside a filled circle, 0 elsewhere.

    `center` is (x, y) and defaults to the image centre.
    """
    rows, cols = image_shape[0], image_shape[1]
    cx, cy = center if center is not None else (cols // 2, rows // 2)

    yy, xx = np.ogrid[:rows, :cols]
    distances = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)

    disk = np.zeros((rows, cols), dtype=np.uint8)
    disk[distances <= radius] = 255
    return disk
+
+def create_realistic_wound_mask(image_shape, method='elliptical'):
+ """Create a more realistic wound mask with irregular shapes"""
+ # method: 'elliptical' (noisy ellipse) or 'irregular' (circle with three
+ # lobes); any other value yields an all-zero mask after smoothing.
+ h, w = image_shape[:2]
+ mask = np.zeros((h, w), dtype=np.uint8)
+
+ if method == 'elliptical':
+ # Create elliptical wound mask
+ center = (w // 2, h // 2)
+ radius_x = min(w, h) // 3
+ radius_y = min(w, h) // 4
+
+ y, x = np.ogrid[:h, :w]
+ # Add some irregularity to make it more realistic
+ ellipse = ((x - center[0])**2 / (radius_x**2) +
+ (y - center[1])**2 / (radius_y**2)) <= 1
+
+ # Add some noise and irregularity
+ # NOTE(review): np.random is unseeded, so this branch is
+ # non-deterministic — confirm reproducibility is not required.
+ noise = np.random.random((h, w)) > 0.8
+ mask = (ellipse | noise).astype(np.uint8) * 255
+
+ elif method == 'irregular':
+ # Create irregular wound mask
+ center = (w // 2, h // 2)
+ radius = min(w, h) // 4
+
+ y, x = np.ogrid[:h, :w]
+ base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius
+
+ # Add irregular extensions: three lobes at 120° intervals
+ extensions = np.zeros_like(base_circle)
+ for i in range(3):
+ angle = i * 2 * np.pi / 3
+ ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
+ ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
+ ext_radius = radius // 3
+
+ ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
+ extensions = extensions | ext_circle
+
+ mask = (base_circle | extensions).astype(np.uint8) * 255
+
+ # Apply morphological operations to smooth the mask
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
+
+ return mask
+
+# --- Depth Estimation Functions ---
+@spaces.GPU
+def predict_depth(image):
+ # Thin wrapper so the ZeroGPU decorator fronts the model call. Callers
+ # pass image[:, :, ::-1] (see on_depth_submit) — presumably the channel
+ # order infer_image expects; confirm against DepthAnythingV2 docs.
+ return depth_model.infer_image(image)
+
def calculate_max_points(image):
    """Upper bound for the 3D point count: 3x the pixel count, clamped to [1000, 300000].

    Returns 10000 when no image is loaded.
    """
    if image is None:
        return 10000  # default before any upload
    height, width = image.shape[:2]
    requested = height * width * 3
    # Clamp into a sane range for the slider.
    return min(max(requested, 1000), 300000)
+
def update_slider_on_image_upload(image):
    """Rebuild the point-count slider so its maximum tracks the uploaded image size."""
    upper = calculate_max_points(image)
    initial = min(10000, upper // 10)  # start at 10% of the maximum
    return gr.Slider(minimum=1000, maximum=upper, value=initial, step=1000,
                     label=f"Number of 3D points (max: {upper:,})")
+
+@spaces.GPU
+def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
+ """Create a point cloud from depth map using camera intrinsics with high detail"""
+ # Returns an Open3D PointCloud; colours come from the RGB image pixels.
+ h, w = depth_map.shape
+
+ # Use smaller step for higher detail (reduced downsampling)
+ # NOTE(review): the 0.5 factor roughly quadruples the point budget past
+ # max_points — confirm that is intended.
+ step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail
+
+ # Create mesh grid for camera coordinates
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
+
+ # Convert to camera coordinates (normalized by focal length)
+ x_cam = (x_coords - w / 2) / focal_length_x
+ y_cam = (y_coords - h / 2) / focal_length_y
+
+ # Get depth values
+ depth_values = depth_map[::step, ::step]
+
+ # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
+ x_3d = x_cam * depth_values
+ y_3d = y_cam * depth_values
+ z_3d = depth_values
+
+ # Flatten arrays
+ points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
+
+ # Get corresponding image colors, scaled to [0, 1] as Open3D expects
+ image_colors = image[::step, ::step, :]
+ colors = image_colors.reshape(-1, 3) / 255.0
+
+ # Create Open3D point cloud
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+
+ return pcd
+
+@spaces.GPU
+def reconstruct_surface_mesh_from_point_cloud(pcd):
+ """Convert point cloud to a mesh using Poisson reconstruction with very high detail."""
+ # Estimate and orient normals with high precision (Poisson requires them)
+ pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
+ pcd.orient_normals_consistent_tangent_plane(k=50)
+
+ # Create surface mesh with maximum detail (depth=12 for very high resolution)
+ # NOTE(review): octree depth 12 is expensive in time and memory for large
+ # clouds — confirm it is acceptable on the deployment hardware.
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
+
+ # Return mesh without filtering low-density vertices
+ return mesh
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an interactive Plotly 3D point cloud using a pinhole camera projection.

    Points are subsampled to roughly `max_points` and coloured with the
    corresponding image pixels.
    """
    h, w = depth_map.shape

    # Downsample to avoid too many points for performance
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Pixel grid at the subsampled locations
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Back-project to camera coordinates (normalized by focal length)
    focal_length = 470.4  # default focal length, in pixels
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    # Depth values at the sampled pixels
    depth_values = depth_map[::step, ::step]

    # 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Matching image colours for each point
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        # FIX: this template was an unterminated string in the source (the
        # HTML tags were stripped). Restored with <br> line breaks and an
        # empty <extra></extra> to hide the secondary hover box.
        hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
+
+def on_depth_submit(image, num_points, focal_x, focal_y):
+ # Full depth pipeline for one image. Returns:
+ # [(original, colored_depth) slider pair, grayscale PNG path,
+ # raw 16-bit PNG path, mesh .ply path, plotly 3D figure].
+ original_image = image.copy()
+
+ # NOTE(review): h and w are never used below.
+ h, w = image.shape[:2]
+
+ # Predict depth using the model
+ depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed
+
+ # Save raw 16-bit depth
+ raw_depth = Image.fromarray(depth.astype('uint16'))
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ raw_depth.save(tmp_raw_depth.name)
+
+ # Normalize and convert to grayscale for display
+ norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ norm_depth = norm_depth.astype(np.uint8)
+ colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)
+
+ gray_depth = Image.fromarray(norm_depth)
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ gray_depth.save(tmp_gray_depth.name)
+
+ # Create point cloud
+ # NOTE(review): the cloud and the 3D plot below are built from the
+ # 0-255-normalised uint8 depth rather than the raw metric depth — confirm
+ # the "(meters)" axis labels in the plot are still appropriate.
+ pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
+
+ # Reconstruct mesh from point cloud
+ mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
+
+ # Save mesh with faces as .ply
+ tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
+ o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)
+
+ # Create enhanced 3D scatter plot visualization
+ depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)
+
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
+
# --- Actual Wound Segmentation Functions ---
def create_automatic_wound_mask(image, method='deep_learning'):
    """
    Automatically generate a wound mask with the deep learning segmentation model.

    Args:
        image: Input image (numpy array), or None.
        method: Kept for API compatibility; every value currently routes to
            the deep learning model, which is the only implemented backend.

    Returns:
        Binary wound mask, or None when no image is given or segmentation fails.
    """
    if image is None:
        return None

    # FIX: both the 'deep_learning' branch and the fallback branch made the
    # identical call, so the redundant if/else was collapsed.
    mask, _ = segmentation_model.segment_wound(image)
    return mask
+
def post_process_wound_mask(mask, min_area=100):
    """Clean a binary wound mask: smooth it morphologically and drop connected
    components smaller than `min_area` pixels. Returns None for None input."""
    if mask is None:
        return None

    binary = mask if mask.dtype == np.uint8 else mask.astype(np.uint8)

    # Smooth: close small gaps, then open to remove speckle noise.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, ellipse)

    # Keep only sufficiently large blobs.
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cleaned = np.zeros_like(binary)
    for blob in contours:
        if cv2.contourArea(blob) >= min_area:
            cv2.fillPoly(cleaned, [blob], 255)

    # Final closing pass fills any holes left inside the kept blobs.
    return cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, ellipse)
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'):
    """Run the fully automatic pipeline: segment the wound, clean the mask,
    then produce the HTML severity report (or an error message string)."""
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    # Step 1: model-generated mask
    raw_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if raw_mask is None:
        return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."

    # Step 2: denoise and drop tiny regions
    cleaned_mask = post_process_wound_mask(raw_mask, min_area=500)
    if cleaned_mask is None or not np.any(cleaned_mask > 0):
        return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."

    # Step 3: severity report from the cleaned mask
    return analyze_wound_severity(image, depth_map, cleaned_mask, pixel_spacing_mm)
+
+# --- Main Gradio Interface ---
+with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
+ gr.HTML("Wound Analysis & Depth Estimation System
")
+ gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")
+
+ # Shared image state
+ shared_image = gr.State()
+
+ with gr.Tabs():
+ # Tab 1: Wound Classification
+ with gr.Tab("1. Wound Classification"):
+ gr.Markdown("### Step 1: Upload and classify your wound image")
+ gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
+
+ with gr.Column(scale=1):
+ wound_prediction_box = gr.HTML()
+ wound_reasoning_box = gr.HTML()
+
+ # Button to pass image to depth estimation
+ with gr.Row():
+ pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
+ pass_status = gr.HTML("")
+
+ wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
+ outputs=[wound_prediction_box, wound_reasoning_box])
+
+ # Store image when uploaded for classification
+ wound_image_input.change(
+ fn=lambda img: img,
+ inputs=[wound_image_input],
+ outputs=[shared_image]
+ )
+
+ # Tab 2: Depth Estimation
+ with gr.Tab("2. Depth Estimation & 3D Visualization"):
+ gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
+ gr.Markdown("This module creates depth maps and 3D point clouds from your images.")
+
+ with gr.Row():
+ depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
+
+ with gr.Row():
+ depth_submit = gr.Button(value="Compute Depth", variant="primary")
+ load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
+ # Max/default are placeholders; update_slider_on_image_upload below
+ # rescales the range to the uploaded image's pixel count.
+ points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
+ label="Number of 3D points (upload image to update max)")
+
+ with gr.Row():
+ # Camera intrinsics used to back-project depth into 3D (pinhole model).
+ focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
+ label="Focal Length X (pixels)")
+ focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
+ label="Focal Length Y (pixels)")
+
+ with gr.Row():
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
+ point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
+
+ # 3D Visualization
+ gr.Markdown("### 3D Point Cloud Visualization")
+ gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
+ depth_3d_plot = gr.Plot(label="3D Point Cloud")
+
+ # Store depth map for severity analysis
+ # (gr.State carries the normalized uint8 depth across tabs)
+ depth_map_state = gr.State()
+
+ # Tab 3: Wound Severity Analysis
+ with gr.Tab("3. 🩹 Wound Severity Analysis"):
+ gr.Markdown("### Step 3: Analyze wound severity using depth maps")
+ gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")
+
+ with gr.Row():
+ severity_input_image = gr.Image(label="Original Image", type='numpy')
+ severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')
+
+ with gr.Row():
+ wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy')
+ severity_output = gr.HTML(label="Severity Analysis Report")
+
+ gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.")
+
+ with gr.Row():
+ auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg")
+ manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
+ pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
+ label="Pixel Spacing (mm/pixel)")
+ depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0,
+ label="Depth Calibration (mm)",
+ info="Adjust based on expected maximum wound depth")
+
+ gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")
+ gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm, for deep wounds use 15-30mm.")
+
+ with gr.Row():
+ # Load depth map from previous tab
+ load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
+
+ gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.")
+
+ # Update slider when image is uploaded
+ depth_input_image.change(
+ fn=update_slider_on_image_upload,
+ inputs=[depth_input_image],
+ outputs=[points_slider]
+ )
+
+ # Modified depth submit function to store depth map
+ # Wraps on_depth_submit and appends a normalized uint8 depth map for
+ # depth_map_state so Tab 3 can reuse it.
+ # NOTE(review): predict_depth is invoked here a second time — on_depth_submit
+ # already ran the depth model on the same image, so every click pays for two
+ # GPU inferences. Consider having on_depth_submit return its norm_depth.
+ def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
+ results = on_depth_submit(image, num_points, focal_x, focal_y)
+ # Extract depth map from results for severity analysis
+ depth_map = None
+ if image is not None:
+ depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed
+ # Normalize depth for severity analysis
+ # Min-max scale to 0-255 (will divide by zero on a constant depth map —
+ # same caveat as on_depth_submit).
+ norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ depth_map = norm_depth.astype(np.uint8)
+ return results + [depth_map]
+
+ depth_submit.click(on_depth_submit_with_state,
+ inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
+ outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
+
+ # Load depth map to severity tab and auto-generate mask
+ # Returns (depth_map, original_image, mask_or_None, status_html).
+ def load_depth_to_severity(depth_map, original_image):
+ if depth_map is None:
+ return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
+
+ # Auto-generate wound mask using segmentation model
+ if original_image is not None:
+ auto_mask, _ = segmentation_model.segment_wound(original_image)
+ if auto_mask is not None:
+ # Post-process the mask
+ # min_area=500 px: drop small spurious components.
+ processed_mask = post_process_wound_mask(auto_mask, min_area=500)
+ if processed_mask is not None and np.sum(processed_mask > 0) > 0:
+ return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!"
+ else:
+ return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image."
+ else:
+ return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image."
+ else:
+ return depth_map, original_image, None, "✅ Depth map loaded successfully!"
+
+ # BUG(review): gr.HTML() constructed inline inside `outputs` creates a brand
+ # new component that was never placed in the layout, so the 4th return value
+ # (the status message) is never visible to the user. Declare a status HTML
+ # component in the Tab 3 layout and reference it here instead.
+ load_depth_btn.click(
+ fn=load_depth_to_severity,
+ inputs=[depth_map_state, depth_input_image],
+ outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()]
+ )
+
+ # Automatic severity analysis function
+ # Segments the wound automatically, then runs the severity report.
+ # Returns an HTML report string (or an error message string).
+ def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration):
+ if depth_map is None:
+ return "❌ Please load depth map from Tab 2 first."
+
+ # Generate automatic wound mask using the actual model
+ auto_mask = create_automatic_wound_mask(image, method='deep_learning')
+
+ if auto_mask is None:
+ return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded."
+
+ # Post-process the mask with fixed minimum area
+ processed_mask = post_process_wound_mask(auto_mask, min_area=500)
+
+ if processed_mask is None or np.sum(processed_mask > 0) == 0:
+ return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask."
+
+ # Analyze severity using the automatic mask
+ # NOTE(review): called with 5 args here; the analyze_wound_severity copy in
+ # temp_files/test1.txt takes only 4 (no depth_calibration) — confirm app.py's
+ # definition accepts the calibration parameter.
+ return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration)
+
+ # Manual severity analysis function
+ # Same as above but uses the user-supplied/auto-filled mask component.
+ def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing, depth_calibration):
+ if depth_map is None:
+ return "❌ Please load depth map from Tab 2 first."
+ if wound_mask is None:
+ return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
+
+ return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing, depth_calibration)
+
+ # Connect event handlers
+ auto_severity_button.click(
+ fn=run_auto_severity_analysis,
+ inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider],
+ outputs=[severity_output]
+ )
+
+ manual_severity_button.click(
+ fn=run_manual_severity_analysis,
+ inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider, depth_calibration_slider],
+ outputs=[severity_output]
+ )
+
+
+
+
+ # Auto-generate mask when image is uploaded
+ # Returns (mask_or_None, status_html_string).
+ def auto_generate_mask_on_image_upload(image):
+ if image is None:
+ return None, "❌ No image uploaded."
+
+ # Generate automatic wound mask using segmentation model
+ auto_mask, _ = segmentation_model.segment_wound(image)
+ if auto_mask is not None:
+ # Post-process the mask
+ processed_mask = post_process_wound_mask(auto_mask, min_area=500)
+ if processed_mask is not None and np.sum(processed_mask > 0) > 0:
+ return processed_mask, "✅ Wound mask auto-generated using deep learning model!"
+ else:
+ return None, "✅ Image uploaded but no wound detected. Try uploading a different image."
+ else:
+ return None, "✅ Image uploaded but segmentation failed. Try uploading a different image."
+
+ # Load shared image from classification tab
+ # Converts the stored PIL image (Tab 1 uses type="pil") to the numpy array
+ # Tab 2 expects (type='numpy').
+ def load_shared_image(shared_img):
+ if shared_img is None:
+ return gr.Image(), "❌ No image available from classification tab"
+
+ # Convert PIL image to numpy array for depth estimation
+ if hasattr(shared_img, 'convert'):
+ # It's a PIL image, convert to numpy
+ img_array = np.array(shared_img)
+ return img_array, "✅ Image loaded from classification tab"
+ else:
+ # Already numpy array
+ return shared_img, "✅ Image loaded from classification tab"
+
+ # Auto-generate mask when image is uploaded to severity tab
+ # BUG(review): as with load_depth_btn, the inline gr.HTML() in `outputs`
+ # is an unrendered component — the status text is silently discarded.
+ severity_input_image.change(
+ fn=auto_generate_mask_on_image_upload,
+ inputs=[severity_input_image],
+ outputs=[wound_mask_input, gr.HTML()]
+ )
+
+ # Same unrendered-gr.HTML() issue applies here.
+ load_shared_btn.click(
+ fn=load_shared_image,
+ inputs=[shared_image],
+ outputs=[depth_input_image, gr.HTML()]
+ )
+
+ # Pass image to depth tab function
+ # Only emits a status string — the actual hand-off happens via shared_image.
+ def pass_image_to_depth(img):
+ if img is None:
+ return "❌ No image uploaded in classification tab"
+ return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"
+
+ pass_to_depth_btn.click(
+ fn=pass_image_to_depth,
+ inputs=[shared_image],
+ outputs=[pass_status]
+ )
+
+# Script entry point: enable the request queue (required for @spaces.GPU
+# handlers) and serve on all interfaces with a public share link.
+if __name__ == '__main__':
+ demo.queue().launch(
+ server_name="0.0.0.0",
+ server_port=7860,
+ share=True
+ )
\ No newline at end of file
diff --git a/temp_files/predict.py b/temp_files/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..a55295b68ea55691886e8ca57ad65cebf95bc08d
--- /dev/null
+++ b/temp_files/predict.py
@@ -0,0 +1,64 @@
+# Batch-prediction script: loads a trained segmentation model and writes
+# RGB prediction masks for the Medetec test split.
+import cv2
+from keras.models import load_model
+from keras.utils.generic_utils import CustomObjectScope
+
+from models.unets import Unet2D
+from models.deeplab import Deeplabv3, relu6, BilinearUpsampling, DepthwiseConv2D
+from models.FCN import FCN_Vgg16_16s
+
+from utils.learning.metrics import dice_coef, precision, recall
+from utils.BilinearUpSampling import BilinearUpSampling2D
+from utils.io.data import load_data, save_results, save_rgb_results, save_history, load_test_images, DataGen
+
+
+# settings
+input_dim_x = 224
+input_dim_y = 224
+color_space = 'rgb'
+path = './data/Medetec_foot_ulcer_224/'
+# '%3A' is a URL-encoded ':' kept literally in the checkpoint filename.
+weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
+pred_save_path = '2019-12-19 01%3A53%3A15.480800/'
+
+# split_ratio=0.0: no validation split, generator serves the test set only.
+data_gen = DataGen(path, split_ratio=0.0, x=input_dim_x, y=input_dim_y, color_space=color_space)
+x_test, test_label_filenames_list = load_test_images(path)
+
+# ### get unet model
+# unet2d = Unet2D(n_filters=64, input_dim_x=input_dim_x, input_dim_y=input_dim_y, num_channels=3)
+# model = unet2d.get_unet_model_yuanqing()
+# model = load_model('./azh_wound_care_center_diabetic_foot_training_history/' + weight_file_name
+# , custom_objects={'recall':recall,
+# 'precision':precision,
+# 'dice_coef': dice_coef,
+# 'relu6':relu6,
+# 'DepthwiseConv2D':DepthwiseConv2D,
+# 'BilinearUpsampling':BilinearUpsampling})
+
+# ### get separable unet model
+# sep_unet = Separable_Unet2D(n_filters=64, input_dim_x=input_dim_x, input_dim_y=input_dim_y, num_channels=3)
+# model, model_name = sep_unet.get_sep_unet_v2()
+# model = load_model('./azh_wound_care_center_diabetic_foot_training_history/' + weight_file_name
+# , custom_objects={'dice_coef': dice_coef,
+# 'relu6':relu6,
+# 'DepthwiseConv2D':DepthwiseConv2D,
+# 'BilinearUpsampling':BilinearUpsampling})
+
+# ### get VGG16 model
+# model, model_name = FCN_Vgg16_16s(input_shape=(input_dim_x, input_dim_y, 3))
+# with CustomObjectScope({'BilinearUpSampling2D':BilinearUpSampling2D}):
+# model = load_model('./azh_wound_care_center_diabetic_foot_training_history/' + weight_file_name
+# , custom_objects={'dice_coef': dice_coef})
+
+# ### get mobilenetv2 model
+# NOTE(review): this Deeplabv3 instance is dead code — load_model on the next
+# statement rebuilds the full model from the HDF5 file and rebinds `model`.
+model = Deeplabv3(input_shape=(input_dim_x, input_dim_y, 3), classes=1)
+model = load_model('./training_history/' + weight_file_name
+ , custom_objects={'recall':recall,
+ 'precision':precision,
+ 'dice_coef': dice_coef,
+ 'relu6':relu6,
+ 'DepthwiseConv2D':DepthwiseConv2D,
+ 'BilinearUpsampling':BilinearUpsampling})
+
+# Single pass: batch_size=len(x_test) yields the whole test set at once,
+# so we predict, save, and break out of the (infinite) generator loop.
+for image_batch, label_batch in data_gen.generate_data(batch_size=len(x_test), test=True):
+ prediction = model.predict(image_batch, verbose=1)
+ save_results(prediction, 'rgb', path + 'test/predictions/' + pred_save_path, test_label_filenames_list)
+ break
diff --git a/temp_files/requirements.txt b/temp_files/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8d2f4d0195e0af4bb2dd200924865b738c3475c1
--- /dev/null
+++ b/temp_files/requirements.txt
@@ -0,0 +1,109 @@
+aiofiles
+annotated-types
+anyio
+asttokens
+attrs
+blinker
+certifi
+charset-normalizer
+click
+colorama
+comm
+ConfigArgParse
+contourpy
+cycler
+dash
+decorator
+executing
+fastapi
+fastjsonschema
+ffmpy
+filelock
+Flask
+fonttools
+fsspec
+gdown
+gradio
+gradio_client
+gradio_imageslider
+groovy
+h11
+httpcore
+httpx
+huggingface-hub
+idna
+importlib_metadata
+itsdangerous
+jedi
+Jinja2
+jsonschema
+jsonschema-specifications
+jupyter_core
+jupyterlab_widgets
+kiwisolver
+markdown-it-py
+MarkupSafe
+matplotlib
+matplotlib-inline
+mdurl
+mpmath
+narwhals
+nbformat
+nest-asyncio
+networkx
+numpy<2
+open3d
+opencv-python
+orjson
+packaging
+pandas
+parso
+pillow
+platformdirs
+plotly
+prompt_toolkit
+pure_eval
+pydantic_core
+pydub
+Pygments
+pyparsing
+python-dateutil
+python-multipart
+pytz
+PyYAML
+referencing
+requests
+retrying
+rich
+rpds-py
+ruff
+safehttpx
+scikit-image
+semantic-version
+setuptools
+shellingham
+six
+sniffio
+stack-data
+starlette
+sympy
+tensorflow<2.11
+tensorflow_hub
+tomlkit
+torch
+torchvision
+tqdm
+traitlets
+typer
+typing-inspection
+typing_extensions
+tzdata
+urllib3
+uvicorn
+wcwidth
+websockets
+Werkzeug
+wheel
+widgetsnbextension
+zipp
+pydantic==2.10.6
\ No newline at end of file
diff --git a/temp_files/run_gradio_app.py b/temp_files/run_gradio_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b3dec33e4b79363537cb8d5d9ae25a4dee9b9ba
--- /dev/null
+++ b/temp_files/run_gradio_app.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+"""
+Simple launcher for the Wound Segmentation Gradio App
+"""
+
+import sys
+import os
+
+def check_dependencies():
+ """Check if required dependencies are installed"""
+ # 'cv2' is the import name of the opencv-python package, hence the
+ # special-case below (though __import__('cv2') would work equally).
+ required_packages = ['gradio', 'tensorflow', 'cv2', 'numpy']
+ missing_packages = []
+
+ for package in required_packages:
+ try:
+ if package == 'cv2':
+ import cv2
+ else:
+ __import__(package)
+ except ImportError:
+ missing_packages.append(package)
+
+ if missing_packages:
+ print("❌ Missing required packages:")
+ for package in missing_packages:
+ print(f" - {package}")
+ print("\n📦 Install missing packages with:")
+ print(" pip install -r requirements.txt")
+ return False
+
+ print("✅ All required packages are installed!")
+ return True
+
+def check_model_files():
+ """Check if model files exist"""
+ # Known checkpoint candidates; at least one must be present.
+ # NOTE(review): run_gradio_app lists '2025-08-07_12-30-43.hdf5' while
+ # .gitattributes tracks '2025-08-07_16-25-27.hdf5' — confirm which
+ # checkpoint actually ships.
+ model_files = [
+ 'training_history/2025-08-07_12-30-43.hdf5',
+ 'training_history/2019-12-19 01%3A53%3A15.480800.hdf5'
+ ]
+
+ existing_models = []
+ for model_file in model_files:
+ if os.path.exists(model_file):
+ existing_models.append(model_file)
+
+ if not existing_models:
+ print("❌ No model files found!")
+ print(" Please ensure you have trained models in the training_history/ directory")
+ return False
+
+ print(f"✅ Found {len(existing_models)} model file(s):")
+ for model in existing_models:
+ print(f" - {model}")
+ return True
+
+def main():
+ """Main function to launch the Gradio app"""
+ # Fail fast on missing dependencies or checkpoints before importing the
+ # (heavy) app module.
+ print("🚀 Starting Wound Segmentation Gradio App...")
+ print("=" * 50)
+
+ # Check dependencies
+ if not check_dependencies():
+ sys.exit(1)
+
+ # Check model files
+ if not check_model_files():
+ sys.exit(1)
+
+ print("\n🎯 Launching Gradio interface...")
+ print(" The app will be available at: http://localhost:7860")
+ print(" Press Ctrl+C to stop the server")
+ print("=" * 50)
+
+ try:
+ # Import and run the Gradio app
+ # NOTE(review): imports from a module named 'gradio_app', but the
+ # create_gradio_interface factory in this changeset lives in
+ # temp_files/segmentation_app.py — verify the module name resolves.
+ from gradio_app import create_gradio_interface
+
+ interface = create_gradio_interface()
+ interface.launch(
+ server_name="0.0.0.0",
+ server_port=7860,
+ share=True,
+ show_error=True
+ )
+ except KeyboardInterrupt:
+ print("\n👋 Gradio app stopped by user")
+ except Exception as e:
+ print(f"\n❌ Error launching Gradio app: {e}")
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/temp_files/segmentation_app.py b/temp_files/segmentation_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e690f1a117d637234f7d6487138f3e4b584f5f9
--- /dev/null
+++ b/temp_files/segmentation_app.py
@@ -0,0 +1,222 @@
+import gradio as gr
+import cv2
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from keras.models import load_model
+from keras.utils.generic_utils import CustomObjectScope
+
+# Import custom modules
+from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
+from utils.learning.metrics import dice_coef, precision, recall
+from utils.io.data import normalize
+
+class WoundSegmentationApp:
+ """Wraps the Keras wound-segmentation model: loading, preprocessing,
+ prediction, and overlay rendering for the Gradio UI."""
+ def __init__(self):
+ # Fixed model input resolution (DeepLab was trained on 224x224).
+ self.input_dim_x = 224
+ self.input_dim_y = 224
+ self.model = None
+ self.load_model()
+
+ def load_model(self):
+ """Load the trained wound segmentation model"""
+ # Try the newest checkpoint first, then fall back to the 2019 one;
+ # on total failure self.model stays None and segment_wound reports it.
+ try:
+ # Load the model with custom objects
+ weight_file_name = '2025-08-07_12-30-43.hdf5' # Use the most recent model
+ model_path = f'./training_history/{weight_file_name}'
+
+ self.model = load_model(model_path,
+ custom_objects={
+ 'recall': recall,
+ 'precision': precision,
+ 'dice_coef': dice_coef,
+ 'relu6': relu6,
+ 'DepthwiseConv2D': DepthwiseConv2D,
+ 'BilinearUpsampling': BilinearUpsampling
+ })
+ print(f"Model loaded successfully from {model_path}")
+ except Exception as e:
+ print(f"Error loading model: {e}")
+ # Fallback to the older model if the newer one fails
+ try:
+ weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5'
+ model_path = f'./training_history/{weight_file_name}'
+
+ self.model = load_model(model_path,
+ custom_objects={
+ 'recall': recall,
+ 'precision': precision,
+ 'dice_coef': dice_coef,
+ 'relu6': relu6,
+ 'DepthwiseConv2D': DepthwiseConv2D,
+ 'BilinearUpsampling': BilinearUpsampling
+ })
+ print(f"Model loaded successfully from {model_path}")
+ except Exception as e2:
+ print(f"Error loading fallback model: {e2}")
+ self.model = None
+
+ def preprocess_image(self, image):
+ """Preprocess the uploaded image for model input"""
+ # Returns a float32 batch of shape (1, 224, 224, 3) in [0, 1],
+ # or None when no image was given.
+ if image is None:
+ return None
+
+ # Convert to RGB if needed
+ if len(image.shape) == 3 and image.shape[2] == 3:
+ # Convert BGR to RGB if needed
+ # NOTE(review): Gradio numpy images are already RGB, so this
+ # cvtColor actually swaps to BGR — confirm which channel order
+ # the model was trained on before removing/keeping this.
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ # Resize to model input size
+ image = cv2.resize(image, (self.input_dim_x, self.input_dim_y))
+
+ # Normalize the image
+ image = image.astype(np.float32) / 255.0
+
+ # Add batch dimension
+ image = np.expand_dims(image, axis=0)
+
+ return image
+
+ def postprocess_prediction(self, prediction):
+ """Postprocess the model prediction"""
+ # Thresholds the sigmoid output into a binary 0/255 mask and expands
+ # it to 3 channels for display.
+ # Remove batch dimension
+ prediction = prediction[0]
+
+ # Apply threshold to get binary mask
+ threshold = 0.5
+ binary_mask = (prediction > threshold).astype(np.uint8) * 255
+
+ # Convert to 3-channel image for visualization
+ mask_rgb = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2RGB)
+
+ return mask_rgb
+
+ def segment_wound(self, input_image):
+ """Main function to segment wound from uploaded image"""
+ # Returns (mask_rgb, overlay) on success.
+ # NOTE(review): on failure this returns (None, <error string>) — the
+ # second element changes type, so callers must check the first element
+ # for None before treating the second as an image.
+ if self.model is None:
+ return None, "Error: Model not loaded. Please check the model files."
+
+ if input_image is None:
+ return None, "Please upload an image."
+
+ try:
+ # Preprocess the image
+ processed_image = self.preprocess_image(input_image)
+
+ if processed_image is None:
+ return None, "Error processing image."
+
+ # Make prediction
+ prediction = self.model.predict(processed_image, verbose=0)
+
+ # Postprocess the prediction
+ segmented_mask = self.postprocess_prediction(prediction)
+
+ # Create overlay image (original image with segmentation overlay)
+ original_resized = cv2.resize(input_image, (self.input_dim_x, self.input_dim_y))
+ if len(original_resized.shape) == 3:
+ original_resized = cv2.cvtColor(original_resized, cv2.COLOR_RGB2BGR)
+
+ # Create overlay with red segmentation
+ overlay = original_resized.copy()
+ mask_red = np.zeros_like(original_resized)
+ mask_red[:, :, 2] = segmented_mask[:, :, 0] # Red channel
+
+ # Blend overlay with original image
+ # alpha weights the mask; 1-alpha weights the original image.
+ alpha = 0.6
+ overlay = cv2.addWeighted(overlay, 1-alpha, mask_red, alpha, 0)
+
+ return segmented_mask, overlay
+
+ except Exception as e:
+ return None, f"Error during segmentation: {str(e)}"
+
+def create_gradio_interface():
+ """Create and return the Gradio interface"""
+ # Builds a Blocks UI around a single WoundSegmentationApp instance.
+
+ # Initialize the app
+ app = WoundSegmentationApp()
+
+ # Define the interface
+ with gr.Blocks(title="Wound Segmentation Tool", theme=gr.themes.Soft()) as interface:
+ gr.Markdown(
+ """
+ # 🩹 Wound Segmentation Tool
+
+ Upload an image of a wound to get an automated segmentation mask.
+ The model will identify and highlight the wound area in the image.
+
+ **Instructions:**
+ 1. Upload an image of a wound
+ 2. Click "Segment Wound" to process the image
+ 3. View the segmentation mask and overlay results
+ """
+ )
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(
+ label="Upload Wound Image",
+ type="numpy",
+ height=400
+ )
+
+ segment_btn = gr.Button(
+ "🔍 Segment Wound",
+ variant="primary",
+ size="lg"
+ )
+
+ with gr.Column():
+ mask_output = gr.Image(
+ label="Segmentation Mask",
+ height=400
+ )
+
+ overlay_output = gr.Image(
+ label="Overlay Result",
+ height=400
+ )
+
+ # Status message
+ status_msg = gr.Textbox(
+ label="Status",
+ interactive=False,
+ placeholder="Ready to process images..."
+ )
+
+ # Example images
+ gr.Markdown("### 📸 Example Images")
+ gr.Markdown("You can test the tool with wound images from the dataset.")
+
+ # Connect the button to the segmentation function
+ # BUG(review): on failure `overlay` holds the error STRING and is
+ # returned into the overlay Image output (with status left None) —
+ # the message belongs in status_msg, e.g.
+ # `return None, None, overlay` -> `return None, None, str(overlay)`
+ # routed to the Textbox only. Verify intended wiring.
+ def process_image(image):
+ mask, overlay = app.segment_wound(image)
+ if mask is None:
+ return None, None, overlay # overlay contains error message
+ return mask, overlay, "Segmentation completed successfully!"
+
+ segment_btn.click(
+ fn=process_image,
+ inputs=[input_image],
+ outputs=[mask_output, overlay_output, status_msg]
+ )
+
+ # Auto-process when image is uploaded
+ # NOTE: with auto-processing on .change, the button re-runs the same
+ # inference a second time when clicked after an upload.
+ input_image.change(
+ fn=process_image,
+ inputs=[input_image],
+ outputs=[mask_output, overlay_output, status_msg]
+ )
+
+ return interface
+
+# Script entry point: build the UI and serve it publicly on port 7860.
+if __name__ == "__main__":
+ # Create and launch the interface
+ interface = create_gradio_interface()
+ interface.launch(
+ server_name="0.0.0.0",
+ server_port=7860,
+ share=True,
+ show_error=True
+ )
\ No newline at end of file
diff --git a/temp_files/test1.txt b/temp_files/test1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c37b7fe936a455becac41a6bd7eb58e3275d229f
--- /dev/null
+++ b/temp_files/test1.txt
@@ -0,0 +1,843 @@
+import glob
+import gradio as gr
+import matplotlib
+import numpy as np
+from PIL import Image
+import torch
+import tempfile
+from gradio_imageslider import ImageSlider
+import plotly.graph_objects as go
+import plotly.express as px
+import open3d as o3d
+from depth_anything_v2.dpt import DepthAnythingV2
+import os
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from tensorflow.keras.preprocessing import image as keras_image
+import base64
+from io import BytesIO
+import gdown
+import spaces
+import cv2
+from skimage import filters, morphology, measure
+from skimage.segmentation import clear_border
+
+# --- LINEAR INITIALIZATION - NO MODULAR FUNCTIONS ---
+# (temp_files/test1.txt is a scratch copy of an earlier app.py; kept for
+# reference. Module-level init runs eagerly for ZeroGPU compatibility.)
+print("Starting linear initialization for ZeroGPU compatibility...")
+
+# Define path and file ID
+checkpoint_dir = "checkpoints"
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
+gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"
+
+# Download if not already present
+if not os.path.exists(model_file):
+ print("Downloading model from Google Drive...")
+ gdown.download(gdrive_url, model_file, quiet=False)
+
+# --- TensorFlow: Check GPU Availability ---
+gpus = tf.config.list_physical_devices('GPU')
+if gpus:
+ print("TensorFlow is using GPU")
+else:
+ print("TensorFlow is using CPU")
+
+# --- Load Wound Classification Model and Class Labels ---
+# labels.txt lines look like "<index> <label>"; split(maxsplit=1)[1]
+# keeps only the label text.
+wound_model = load_model("/home/user/app/keras_model.h5")
+with open("/home/user/app/labels.txt", "r") as f:
+ class_labels = [line.strip().split(maxsplit=1)[1] for line in f]
+
+# --- PyTorch: Set Device and Load Depth Model ---
+print("Initializing PyTorch device...")
+map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
+print(f"Using PyTorch device: {map_device}")
+
+# Encoder configs for the four DepthAnythingV2 backbone sizes.
+model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
+}
+encoder = 'vitl'
+depth_model = DepthAnythingV2(**model_configs[encoder])
+# NOTE(review): weights are loaded from the hard-coded /home/user/app path,
+# not from the model_file downloaded above — confirm the two paths agree.
+state_dict = torch.load(
+ f'/home/user/app/checkpoints/depth_anything_v2_{encoder}.pth',
+ map_location=map_device
+)
+depth_model.load_state_dict(state_dict)
+depth_model = depth_model.to(map_device).eval()
+
+# --- Custom CSS for unified dark theme ---
+# Raw CSS injected into gr.Blocks(css=...); targets Gradio's generated
+# class names (.gr-button, .gr-tabs, ...) plus the elem_ids used below.
+css = """
+.gradio-container {
+ font-family: 'Segoe UI', sans-serif;
+ background-color: #121212;
+ color: #ffffff;
+ padding: 20px;
+}
+.gr-button {
+ background-color: #2c3e50;
+ color: white;
+ border-radius: 10px;
+}
+.gr-button:hover {
+ background-color: #34495e;
+}
+.gr-html, .gr-html div {
+ white-space: normal !important;
+ overflow: visible !important;
+ text-overflow: unset !important;
+ word-break: break-word !important;
+}
+#img-display-container {
+ max-height: 100vh;
+}
+#img-display-input {
+ max-height: 80vh;
+}
+#img-display-output {
+ max-height: 80vh;
+}
+#download {
+ height: 62px;
+}
+h1 {
+ text-align: center;
+ font-size: 3rem;
+ font-weight: bold;
+ margin: 2rem 0;
+ color: #ffffff;
+}
+h2 {
+ color: #ffffff;
+ text-align: center;
+ margin: 1rem 0;
+}
+.gr-tabs {
+ background-color: #1e1e1e;
+ border-radius: 10px;
+ padding: 10px;
+}
+.gr-tab-nav {
+ background-color: #2c3e50;
+ border-radius: 8px;
+}
+.gr-tab-nav button {
+ color: #ffffff !important;
+}
+.gr-tab-nav button.selected {
+ background-color: #34495e !important;
+}
+"""
+
+# --- LINEAR FUNCTION DEFINITIONS (NO MODULAR CALLS) ---
+
+# Wound Classification Functions
+# Resize to the classifier's 224x224 input, scale to [0,1], add batch dim.
+def preprocess_input(img):
+ img = img.resize((224, 224))
+ arr = keras_image.img_to_array(img)
+ arr = arr / 255.0
+ return np.expand_dims(arr, axis=0)
+
+# Despite the name, no Gemini call happens here: returns a canned
+# explanation per predicted class (img is unused).
+def get_reasoning_from_gemini(img, prediction):
+ try:
+ explanations = {
+ "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
+ "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
+ "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
+ "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
+ "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues."
+ }
+ return explanations.get(prediction, f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.")
+ except Exception as e:
+ return f"(Reasoning unavailable: {str(e)})"
+
+# NOTE(review): the rest of this file is GARBLED — the HTML markup inside the
+# string literals (div/span tags, <br>) was stripped when this scratch copy
+# was saved, leaving strings split across lines. The code below is NOT valid
+# Python as-is; recover the markup from app.py before reusing it.
+@spaces.GPU
+def classify_wound_image(img):
+ if img is None:
+ return "No image provided
", ""
+
+ img_array = preprocess_input(img)
+ predictions = wound_model.predict(img_array, verbose=0)[0]
+ pred_idx = int(np.argmax(predictions))
+ pred_class = class_labels[pred_idx]
+
+ reasoning_text = get_reasoning_from_gemini(img, pred_class)
+
+ predicted_card = f"""
+
+
+ Predicted Wound Type
+
+
+ {pred_class}
+
+
+ """
+
+ reasoning_card = f"""
+
+
+ Reasoning
+
+
+ {reasoning_text}
+
+
+ """
+
+ return predicted_card, reasoning_card
+
+# Depth Estimation Functions
+# Runs DepthAnythingV2 on a BGR image; returns a float depth/disparity map.
+@spaces.GPU
+def predict_depth(image):
+ return depth_model.infer_image(image)
+
+# Upper bound for the 3D-points slider: 3x the pixel count, clamped to
+# [1000, 300000]; falls back to 10000 when no image is loaded.
+def calculate_max_points(image):
+ if image is None:
+ return 10000
+ h, w = image.shape[:2]
+ max_points = h * w * 3
+ return max(1000, min(max_points, 300000))
+
+# Rebuilds the slider with an image-appropriate range and a default of
+# one tenth of the maximum.
+def update_slider_on_image_upload(image):
+ max_points = calculate_max_points(image)
+ default_value = min(10000, max_points // 10)
+ return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
+ label=f"Number of 3D points (max: {max_points:,})")
+
+# Back-projects the depth map into a colored Open3D point cloud using a
+# pinhole camera model centered on the image.
+@spaces.GPU
+def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
+ h, w = depth_map.shape
+ # Subsample on a regular grid; the 0.5 factor roughly quadruples density
+ # relative to an exact max_points budget.
+ step = max(1, int(np.sqrt(h * w / max_points) * 0.5))
+
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
+ x_cam = (x_coords - w / 2) / focal_length_x
+ y_cam = (y_coords - h / 2) / focal_length_y
+ depth_values = depth_map[::step, ::step]
+
+ x_3d = x_cam * depth_values
+ y_3d = y_cam * depth_values
+ z_3d = depth_values
+
+ points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
+ image_colors = image[::step, ::step, :]
+ colors = image_colors.reshape(-1, 3) / 255.0
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+
+ return pcd
+
+# Poisson surface reconstruction from the point cloud; normals are
+# estimated and consistently oriented first (required by Poisson).
+@spaces.GPU
+def reconstruct_surface_mesh_from_point_cloud(pcd):
+ pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
+ pcd.orient_normals_consistent_tangent_plane(k=50)
+ # densities (per-vertex sample density) is discarded — no low-density trim.
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)
+ return mesh
+
+# Interactive Plotly scatter of the back-projected points, colored with the
+# source image pixels. Uses a fixed focal length (470.4) rather than the
+# user-adjustable sliders.
+# NOTE(review): the hovertemplate string below is garbled (stripped <br>
+# tags split it across lines) — recover the original markup from app.py.
+@spaces.GPU
+def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
+ h, w = depth_map.shape
+ step = max(1, int(np.sqrt(h * w / max_points)))
+
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
+ focal_length = 470.4
+ x_cam = (x_coords - w / 2) / focal_length
+ y_cam = (y_coords - h / 2) / focal_length
+ depth_values = depth_map[::step, ::step]
+
+ x_3d = x_cam * depth_values
+ y_3d = y_cam * depth_values
+ z_3d = depth_values
+
+ x_flat = x_3d.flatten()
+ y_flat = y_3d.flatten()
+ z_flat = z_3d.flatten()
+
+ image_colors = image[::step, ::step, :]
+ colors_flat = image_colors.reshape(-1, 3)
+
+ fig = go.Figure(data=[go.Scatter3d(
+ x=x_flat,
+ y=y_flat,
+ z=z_flat,
+ mode='markers',
+ marker=dict(
+ size=1.5,
+ color=colors_flat,
+ opacity=0.9
+ ),
+ hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
' +
+ 'Depth: %{z:.2f}
' +
+ ''
+ )])
+
+ fig.update_layout(
+ title="3D Point Cloud Visualization (Camera Projection)",
+ scene=dict(
+ xaxis_title="X (meters)",
+ yaxis_title="Y (meters)",
+ zaxis_title="Z (meters)",
+ camera=dict(
+ eye=dict(x=2.0, y=2.0, z=2.0),
+ center=dict(x=0, y=0, z=0),
+ up=dict(x=0, y=0, z=1)
+ ),
+ aspectmode='data'
+ ),
+ width=700,
+ height=600
+ )
+
+ return fig
+
+# Main depth pipeline: predict depth, save grayscale/16-bit outputs, build a
+# point cloud + Poisson mesh (.ply), and create the Plotly 3D figure.
+# Returns [(original, colored_depth), gray_path, raw_path, ply_path, figure].
+def on_depth_submit(image, num_points, focal_x, focal_y):
+ original_image = image.copy()
+ h, w = image.shape[:2]
+
+ # Model expects BGR; Gradio delivers RGB, hence the channel flip.
+ depth = predict_depth(image[:, :, ::-1])
+
+ raw_depth = Image.fromarray(depth.astype('uint16'))
+ # delete=False: files must outlive this call for Gradio's download links.
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ raw_depth.save(tmp_raw_depth.name)
+
+ # Min-max normalize to 0-255 (divides by zero on a constant depth map).
+ norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ norm_depth = norm_depth.astype(np.uint8)
+ colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)
+
+ gray_depth = Image.fromarray(norm_depth)
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+ gray_depth.save(tmp_gray_depth.name)
+
+ pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
+ mesh = reconstruct_surface_mesh_from_point_cloud(pcd)
+
+ # NOTE: despite the variable name, a triangle MESH is written, not a cloud.
+ tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
+ o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)
+
+ depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)
+
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
+
+# Wound Severity Analysis Functions
@spaces.GPU
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Summarize wound area per depth band.

    Args:
        depth_map: per-pixel depth values; band thresholds below assume mm —
            TODO confirm units against the depth pipeline.
        mask: wound mask; pixels > 127 are treated as wound.
        pixel_spacing_mm: physical edge length of one pixel in mm.

    Returns:
        dict with total / shallow / moderate / deep areas in cm², the deep
        fraction of the total, and the maximum depth inside the wound.
    """
    cm2_per_pixel = (pixel_spacing_mm / 10.0) ** 2
    in_wound = mask > 127
    depths = depth_map[in_wound]
    total_area = np.count_nonzero(in_wound) * cm2_per_pixel

    # Depth bands: <3, [3, 6), >=6 (mm, presumably — see docstring).
    shallow_area = np.count_nonzero(depths < 3) * cm2_per_pixel
    moderate_area = np.count_nonzero((depths >= 3) & (depths < 6)) * cm2_per_pixel
    deep_area = np.count_nonzero(depths >= 6) * cm2_per_pixel

    # Guard against an empty mask (no wound pixels at all).
    deep_ratio = deep_area / total_area if total_area > 0 else 0

    return {
        'total_area_cm2': total_area,
        'shallow_area_cm2': shallow_area,
        'moderate_area_cm2': moderate_area,
        'deep_area_cm2': deep_area,
        'deep_ratio': deep_ratio,
        'max_depth': np.max(depths) if len(depths) > 0 else 0
    }
+
def classify_wound_severity_by_area(depth_stats):
    """Map area statistics to a "Mild"/"Moderate"/"Severe" label.

    "Unknown" is returned when the wound covers no area at all. Each level
    triggers on either an absolute area or a fractional-coverage threshold.
    """
    total = depth_stats['total_area_cm2']
    deep = depth_stats['deep_area_cm2']
    moderate = depth_stats['moderate_area_cm2']

    if total == 0:
        return "Unknown"
    if deep > 2 or deep / total > 0.3:
        return "Severe"
    if moderate > 1.5 or moderate / total > 0.4:
        return "Moderate"
    return "Mild"
+
def get_severity_description(severity):
    """Return a one-sentence explanation for a severity label."""
    return {
        "Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.",
        "Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.",
        "Severe": "Deep tissue damage requiring immediate medical attention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data.",
    }.get(severity, "Severity assessment unavailable.")
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Build an HTML severity report from a depth map and a wound mask.

    Args:
        image: original image; only checked for presence here.
        depth_map: per-pixel depth array.
        wound_mask: binary mask (white/>127 marks wound); RGB masks are
            collapsed by channel mean, and the mask is resized to the depth
            map's resolution when shapes differ.
        pixel_spacing_mm: physical pixel edge length in mm.

    Returns:
        HTML report string, or an error message when an input is missing.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."

    # Collapse an RGB mask to one channel before thresholding.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Resize the mask to the depth map's resolution (PIL default resampling).
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Accent color per severity level; gray for "Unknown".
    severity_color = {
        "Mild": "#4CAF50",
        "Moderate": "#FF9800",
        "Severe": "#F44336"
    }.get(severity, "#9E9E9E")

    # NOTE(review): the HTML tags of this report template appear stripped in
    # this revision (severity_color is no longer referenced inside it) —
    # restore the original markup before shipping.
    report = f"""


    🩹 Wound Severity Analysis




    📏 Area Measurements


🟢 Total Area: {stats['total_area_cm2']:.2f} cm²

🟩 Shallow (0-3mm): {stats['shallow_area_cm2']:.2f} cm²

🟨 Moderate (3-6mm): {stats['moderate_area_cm2']:.2f} cm²

🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²




    📊 Depth Analysis


🔥 Deep Coverage: {stats['deep_ratio']*100:.1f}%

📏 Max Depth: {stats['max_depth']:.1f} mm

⚡ Pixel Spacing: {pixel_spacing_mm} mm




    🎯 Predicted Severity: {severity}


    {get_severity_description(severity)}


    """

    return report
+
+# Automatic Wound Mask Generation Functions
def create_automatic_wound_mask(image, method='adaptive'):
    """Segment a wound region automatically using the chosen strategy.

    Supported methods: 'adaptive', 'otsu', 'color', 'combined'; anything else
    falls back to adaptive thresholding. Returns None when image is None.
    """
    if image is None:
        return None

    # Intensity-based methods work on a grayscale copy.
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image.copy()

    if method == 'otsu':
        return otsu_threshold_segmentation(gray)
    if method == 'color':
        return color_based_segmentation(image)
    if method == 'combined':
        return combined_segmentation(image, gray)
    # 'adaptive' and unrecognized methods use adaptive thresholding.
    return adaptive_threshold_segmentation(gray)
+
def adaptive_threshold_segmentation(gray):
    """Adaptive Gaussian thresholding plus morphology.

    Keeps only connected blobs larger than 1000 px; returns a 0/255 mask.
    """
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)
    binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
    )

    # Close holes, then drop speckle with an opening.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Re-draw only the sufficiently large contours.
    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(cleaned)
    for c in contours:
        if cv2.contourArea(c) > 1000:
            cv2.fillPoly(result, [c], 255)
    return result
+
def otsu_threshold_segmentation(gray):
    """Otsu global thresholding plus morphology.

    Keeps only connected blobs larger than 800 px; returns a 0/255 mask.
    """
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(cleaned)
    for c in contours:
        if cv2.contourArea(c) > 800:
            cv2.fillPoly(result, [c], 255)
    return result
+
def color_based_segmentation(image):
    """Segment wound-like tissue by hue in HSV space.

    Targets reds (blood/granulation), yellows (slough) and browns (eschar).
    Returns a 0/255 uint8 mask of blobs larger than 600 px.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # Red wraps around hue 0, so it needs two ranges.
    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([160, 30, 30])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    # BUGFIX: numpy '+' on uint8 masks wraps around (255 + 255 -> 254) where
    # ranges overlap; bitwise OR keeps the union strictly 0/255.
    red_mask = cv2.bitwise_or(mask1, mask2)

    lower_yellow = np.array([15, 30, 30])
    upper_yellow = np.array([35, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)

    lower_brown = np.array([10, 50, 20])
    upper_brown = np.array([20, 255, 200])
    brown_mask = cv2.inRange(hsv, lower_brown, upper_brown)

    # The yellow/brown/red hue ranges overlap, so OR (not '+') is required.
    color_mask = cv2.bitwise_or(red_mask, cv2.bitwise_or(yellow_mask, brown_mask))

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)

    # Keep only sufficiently large connected regions.
    contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(color_mask)
    for contour in contours:
        if cv2.contourArea(contour) > 600:
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean
+
def combined_segmentation(image, gray):
    """Union of the adaptive, Otsu and color masks, smoothed and size-filtered.

    Falls back to a synthetic elliptical mask if nothing at all is detected.
    """
    merged = cv2.bitwise_or(adaptive_threshold_segmentation(gray),
                            otsu_threshold_segmentation(gray))
    merged = cv2.bitwise_or(merged, color_based_segmentation(image))

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    merged = cv2.morphologyEx(merged, cv2.MORPH_CLOSE, kernel)

    contours, _ = cv2.findContours(merged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros_like(merged)
    for c in contours:
        if cv2.contourArea(c) > 500:
            cv2.fillPoly(result, [c], 255)

    # Last resort: never hand back an all-black mask.
    if np.sum(result) == 0:
        result = create_realistic_wound_mask(merged.shape, method='elliptical')

    return result
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Synthesize a plausible wound mask.

    'elliptical': centered ellipse ORed with random speckle over the whole
    frame; 'irregular': centered circle with three lobes. Any other method
    yields an empty mask (before the final smoothing). Returns uint8 0/255.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    cx, cy = w // 2, h // 2

    if method == 'elliptical':
        rx = min(w, h) // 3
        ry = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        inside = ((x - cx) ** 2 / (rx ** 2) +
                  (y - cy) ** 2 / (ry ** 2)) <= 1

        # Random speckle across the frame adds irregularity (non-deterministic).
        speckle = np.random.random((h, w)) > 0.8
        mask = (inside | speckle).astype(np.uint8) * 255

    elif method == 'irregular':
        r = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        core = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) <= r

        # Three lobes spaced 120° apart around the core circle.
        lobes = np.zeros_like(core)
        for k in range(3):
            theta = k * 2 * np.pi / 3
            lx = int(cx + r * 0.8 * np.cos(theta))
            ly = int(cy + r * 0.8 * np.sin(theta))
            lobes = lobes | (np.sqrt((x - lx) ** 2 + (y - ly) ** 2) <= r // 3)

        mask = (core | lobes).astype(np.uint8) * 255

    # Small closing to smooth jagged edges.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
+
def post_process_wound_mask(mask, min_area=100):
    """Smooth a binary mask and drop components smaller than min_area px.

    Returns None when mask is None; otherwise a uint8 0/255 mask.
    """
    if mask is None:
        return None
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    smoothed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    smoothed = cv2.morphologyEx(smoothed, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(smoothed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filtered = np.zeros_like(smoothed)
    for c in contours:
        if cv2.contourArea(c) >= min_area:
            cv2.fillPoly(filtered, [c], 255)

    # Final close fuses components the filtering may have left adjacent.
    return cv2.morphologyEx(filtered, cv2.MORPH_CLOSE, kernel)
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Return a filled circle (value 255) as a quick test mask.

    center is (x, y); defaults to the image center.
    """
    h, w = image_shape[:2]
    if center is None:
        center = (w // 2, h // 2)

    yy, xx = np.ogrid[:h, :w]
    inside = (xx - center[0]) ** 2 + (yy - center[1]) ** 2 <= radius ** 2

    mask = np.zeros((h, w), dtype=np.uint8)
    mask[inside] = 255
    return mask
+
# --- MAIN GRADIO INTERFACE (LINEAR EXECUTION) ---
print("Creating Gradio interface...")

with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    # NOTE(review): the HTML markup of this banner literal appears stripped
    # in this revision (string broken across lines) — restore before shipping.
    gr.HTML("Wound Analysis & Depth Estimation System
")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Image handed from the classification tab to the depth tab.
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)

                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            with gr.Row():
                pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")

            # Classify on upload/change.
            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])

            # Mirror the uploaded image into the shared state for Tab 2.
            wound_image_input.change(
                fn=lambda img: img,
                inputs=[wound_image_input],
                outputs=[shared_image]
            )

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")

            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Normalized uint8 depth map stashed for the severity tab.
            depth_map_state = gr.State()

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. 🩹 Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")

            with gr.Row():
                auto_severity_button = gr.Button("🤖 Auto-Analyze Severity", variant="primary", size="lg")
                manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            with gr.Row():
                segmentation_method = gr.Dropdown(
                    choices=["combined", "adaptive", "otsu", "color"],
                    value="combined",
                    label="Segmentation Method",
                    info="Choose automatic segmentation method"
                )
                min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
                                            label="Minimum Area (pixels)",
                                            info="Minimum wound area to detect")

            with gr.Row():
                load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
                sample_mask_btn = gr.Button("🎯 Generate Sample Mask", variant="secondary")
                realistic_mask_btn = gr.Button("🏥 Generate Realistic Mask", variant="secondary")
                preview_mask_btn = gr.Button("👁️ Preview Auto Mask", variant="secondary")

            gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")
+ # Event handlers
+ def generate_sample_mask(image):
+ if image is None:
+ return None, "❌ Please load an image first."
+ sample_mask = create_sample_wound_mask(image.shape)
+ return sample_mask, "✅ Sample circular wound mask generated!"
+
+ def generate_realistic_mask(image):
+ if image is None:
+ return None, "❌ Please load an image first."
+ realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
+ return realistic_mask, "✅ Realistic elliptical wound mask generated!"
+
+ def load_depth_to_severity(depth_map, original_image):
+ if depth_map is None:
+ return None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
+ return depth_map, original_image, "✅ Depth map loaded successfully!"
+
+ def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
+ if depth_map is None:
+ return "❌ Please load depth map from Tab 2 first."
+
+ def post_process_with_area(mask):
+ return post_process_wound_mask(mask, min_area=min_area)
+
+ auto_mask = create_automatic_wound_mask(image, method=seg_method)
+
+ if auto_mask is None:
+ return "❌ Failed to generate automatic wound mask."
+
+ processed_mask = post_process_with_area(auto_mask)
+
+ if processed_mask is None or np.sum(processed_mask > 0) == 0:
+ return "❌ No wound region detected. Try adjusting segmentation parameters or use manual mask."
+
+ return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)
+
+ def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
+ if depth_map is None:
+ return "❌ Please load depth map from Tab 2 first."
+ if wound_mask is None:
+ return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
+ return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)
+
+ def preview_auto_mask(image, seg_method, min_area):
+ if image is None:
+ return None, "❌ Please load an image first."
+ auto_mask = create_automatic_wound_mask(image, method=seg_method)
+ if auto_mask is None:
+ return None, "❌ Failed to generate automatic wound mask."
+ processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
+ if processed_mask is None or np.sum(processed_mask > 0) == 0:
+ return None, "❌ No wound region detected. Try adjusting parameters."
+ return processed_mask, f"✅ Auto mask generated using {seg_method} method!"
+
+ def load_shared_image(shared_img):
+ if shared_img is None:
+ return gr.Image(), "❌ No image available from classification tab"
+ if hasattr(shared_img, 'convert'):
+ img_array = np.array(shared_img)
+ return img_array, "✅ Image loaded from classification tab"
+ else:
+ return shared_img, "✅ Image loaded from classification tab"
+
+ def pass_image_to_depth(img):
+ if img is None:
+ return "❌ No image uploaded in classification tab"
+ return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"
+
+ def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
+ results = on_depth_submit(image, num_points, focal_x, focal_y)
+ depth_map = None
+ if image is not None:
+ depth = predict_depth(image[:, :, ::-1])
+ norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ depth_map = norm_depth.astype(np.uint8)
+ return results + [depth_map]
+
    # Connect all event handlers.
    # NOTE(review): several handlers below target a freshly constructed
    # gr.HTML() as an output — those status messages have no dedicated,
    # pre-placed component; confirm they actually render where intended.
    sample_mask_btn.click(fn=generate_sample_mask, inputs=[severity_input_image], outputs=[wound_mask_input, gr.HTML()])
    realistic_mask_btn.click(fn=generate_realistic_mask, inputs=[severity_input_image], outputs=[wound_mask_input, gr.HTML()])
    # Rescale the point-count slider whenever a new image is uploaded.
    depth_input_image.change(fn=update_slider_on_image_upload, inputs=[depth_input_image], outputs=[points_slider])
    depth_submit.click(on_depth_submit_with_state, inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])
    load_depth_btn.click(fn=load_depth_to_severity, inputs=[depth_map_state, depth_input_image], outputs=[severity_depth_map, severity_input_image, gr.HTML()])
    auto_severity_button.click(fn=run_auto_severity_analysis, inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, segmentation_method, min_area_slider], outputs=[severity_output])
    manual_severity_button.click(fn=run_manual_severity_analysis, inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider], outputs=[severity_output])
    preview_mask_btn.click(fn=preview_auto_mask, inputs=[severity_input_image, segmentation_method, min_area_slider], outputs=[wound_mask_input, gr.HTML()])
    load_shared_btn.click(fn=load_shared_image, inputs=[shared_image], outputs=[depth_input_image, gr.HTML()])
    pass_to_depth_btn.click(fn=pass_image_to_depth, inputs=[shared_image], outputs=[pass_status])

print("Gradio interface created successfully!")
+
if __name__ == '__main__':
    print("Launching app...")
    # Bind on all interfaces so the app is reachable inside the container.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
diff --git a/temp_files/test2.txt b/temp_files/test2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bd3ddd83a996c7875d6db6b80321e70abb10495f
--- /dev/null
+++ b/temp_files/test2.txt
@@ -0,0 +1,1063 @@
+import glob
+import gradio as gr
+import matplotlib
+import numpy as np
+from PIL import Image
+import torch
+import tempfile
+from gradio_imageslider import ImageSlider
+import plotly.graph_objects as go
+import plotly.express as px
+import open3d as o3d
+from depth_anything_v2.dpt import DepthAnythingV2
+import os
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from tensorflow.keras.preprocessing import image as keras_image
+import base64
+from io import BytesIO
+import gdown
+import spaces
+
# Define path and file ID for the DepthAnythingV2 vitl checkpoint.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
# Google Drive file id of the checkpoint.
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download if not already present (skipped on warm restarts).
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

# --- TensorFlow: Check GPU Availability ---
# Informational only; nothing here pins TF to a specific device.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("TensorFlow is using GPU")
else:
    print("TensorFlow is using CPU")
+
# --- Load Wound Classification Model and Class Labels ---
wound_model = load_model("/home/user/app/keras_model.h5")
# labels.txt presumably holds "<index> <name>" per line — TODO confirm;
# split(maxsplit=1)[1] keeps only the class name.
with open("/home/user/app/labels.txt", "r") as f:
    class_labels = [line.strip().split(maxsplit=1)[1] for line in f]

# --- PyTorch: Set Device and Load Depth Model ---
map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")
print(f"Using PyTorch device: {map_device}")

# Per-encoder DPT configurations (feature widths / output channels).
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder = 'vitl'
depth_model = DepthAnythingV2(**model_configs[encoder])
# NOTE(review): the download section above writes to the relative
# "checkpoints/" dir while this load uses an absolute path — verify both
# resolve to the same location in the deployment environment.
state_dict = torch.load(
    f'/home/user/app/checkpoints/depth_anything_v2_{encoder}.pth',
    map_location=map_device
)
depth_model.load_state_dict(state_dict)
depth_model = depth_model.to(map_device).eval()
+
+
# --- Custom CSS for unified dark theme ---
# Consumed by gr.Blocks(css=css) when the interface is built.
css = """
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
    background-color: #121212;
    color: #ffffff;
    padding: 20px;
}
.gr-button {
    background-color: #2c3e50;
    color: white;
    border-radius: 10px;
}
.gr-button:hover {
    background-color: #34495e;
}
.gr-html, .gr-html div {
    white-space: normal !important;
    overflow: visible !important;
    text-overflow: unset !important;
    word-break: break-word !important;
}
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #ffffff;
}
h2 {
    color: #ffffff;
    text-align: center;
    margin: 1rem 0;
}
.gr-tabs {
    background-color: #1e1e1e;
    border-radius: 10px;
    padding: 10px;
}
.gr-tab-nav {
    background-color: #2c3e50;
    border-radius: 8px;
}
.gr-tab-nav button {
    color: #ffffff !important;
}
.gr-tab-nav button.selected {
    background-color: #34495e !important;
}
"""
+
+# --- Wound Classification Functions ---
def preprocess_input(img):
    """Resize a PIL image to 224x224, scale to [0, 1], and add a batch axis."""
    resized = img.resize((224, 224))
    arr = keras_image.img_to_array(resized) / 255.0
    return np.expand_dims(arr, axis=0)
+
def get_reasoning_from_gemini(img, prediction):
    """Return a canned clinical explanation for the predicted wound class.

    Placeholder for a real Gemini call; `img` is currently unused. Falls back
    to a generic consult-a-professional message for unknown classes.
    """
    try:
        canned = {
            "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.",
            "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.",
            "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.",
            "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.",
            "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues.",
        }
        fallback = f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment."
        return canned.get(prediction, fallback)
    except Exception as e:
        return f"(Reasoning unavailable: {str(e)})"
+
@spaces.GPU
def classify_wound_image(img):
    """Classify a wound photo; returns (prediction card, reasoning card) HTML.

    NOTE(review): the HTML markup of the card templates and of the
    no-image message appears stripped in this revision (string literals
    broken across lines) — restore the original markup before shipping.
    """
    if img is None:
        return "No image provided
", ""

    # Single forward pass; predictions is the class-probability vector.
    img_array = preprocess_input(img)
    predictions = wound_model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    pred_class = class_labels[pred_idx]

    # Get reasoning from Gemini (currently a canned-text placeholder).
    reasoning_text = get_reasoning_from_gemini(img, pred_class)

    # Prediction Card
    predicted_card = f"""


    Predicted Wound Type


    {pred_class}


    """

    # Reasoning Card
    reasoning_card = f"""


    Reasoning


    {reasoning_text}


    """

    return predicted_card, reasoning_card
+
+# --- Wound Severity Estimation Functions ---
@spaces.GPU
def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5):
    """Compute area statistics for different depth regions.

    Args:
        depth_map: per-pixel depth values; band thresholds assume mm — TODO confirm.
        mask: wound mask; pixels > 127 count as wound.
        pixel_spacing_mm: physical pixel edge length in mm.

    Returns:
        dict of total/shallow/moderate/deep areas (cm²), deep ratio, max depth.
    """
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Extract only wound region
    wound_mask = (mask > 127)
    wound_depths = depth_map[wound_mask]
    total_area = np.sum(wound_mask) * pixel_area_cm2

    # Categorize depth regions (<3, [3, 6), >=6)
    shallow = wound_depths < 3
    moderate = (wound_depths >= 3) & (wound_depths < 6)
    deep = wound_depths >= 6

    shallow_area = np.sum(shallow) * pixel_area_cm2
    moderate_area = np.sum(moderate) * pixel_area_cm2
    deep_area = np.sum(deep) * pixel_area_cm2

    # Avoid division by zero for an empty mask.
    deep_ratio = deep_area / total_area if total_area > 0 else 0

    return {
        'total_area_cm2': total_area,
        'shallow_area_cm2': shallow_area,
        'moderate_area_cm2': moderate_area,
        'deep_area_cm2': deep_area,
        'deep_ratio': deep_ratio,
        'max_depth': np.max(wound_depths) if len(wound_depths) > 0 else 0
    }
+
def classify_wound_severity_by_area(depth_stats):
    """Classify wound severity based on area and depth distribution.

    Returns "Unknown" for a zero-area wound; otherwise "Severe", "Moderate"
    or "Mild" using absolute-area or fractional-coverage thresholds.
    """
    total = depth_stats['total_area_cm2']
    deep = depth_stats['deep_area_cm2']
    moderate = depth_stats['moderate_area_cm2']

    if total == 0:
        return "Unknown"
    if deep > 2 or deep / total > 0.3:
        return "Severe"
    if moderate > 1.5 or moderate / total > 0.4:
        return "Moderate"
    return "Mild"
+
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Analyze wound severity from depth map and wound mask.

    Returns an HTML report string, or an error message when an input is
    missing. RGB masks are collapsed by channel mean and resized to the
    depth map's resolution when shapes differ.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."

    # Convert wound mask to grayscale if needed
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Ensure depth map and mask have same dimensions
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        # Resize mask to match depth map (PIL default resampling)
        from PIL import Image
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]))
        wound_mask = np.array(mask_pil)

    # Compute statistics
    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Create severity report with color coding
    severity_color = {
        "Mild": "#4CAF50",      # Green
        "Moderate": "#FF9800",  # Orange
        "Severe": "#F44336"     # Red
    }.get(severity, "#9E9E9E")  # Gray for unknown

    # NOTE(review): the HTML tags of this template appear stripped in this
    # revision (severity_color is no longer referenced inside it) — restore
    # the original markup before shipping.
    report = f"""


    🩹 Wound Severity Analysis




    📏 Area Measurements


🟢 Total Area: {stats['total_area_cm2']:.2f} cm²

🟩 Shallow (0-3mm): {stats['shallow_area_cm2']:.2f} cm²

🟨 Moderate (3-6mm): {stats['moderate_area_cm2']:.2f} cm²

🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²




    📊 Depth Analysis


🔥 Deep Coverage: {stats['deep_ratio']*100:.1f}%

📏 Max Depth: {stats['max_depth']:.1f} mm

⚡ Pixel Spacing: {pixel_spacing_mm} mm




    🎯 Predicted Severity: {severity}


    {get_severity_description(severity)}


    """

    return report
+
def get_severity_description(severity):
    """Get description for severity level."""
    return {
        "Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.",
        "Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.",
        "Severe": "Deep tissue damage requiring immediate medical attention and specialized care.",
        "Unknown": "Unable to determine severity due to insufficient data.",
    }.get(severity, "Severity assessment unavailable.")
+
def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Create a sample circular wound mask for testing.

    center is (x, y); defaults to the image center. Returns uint8 0/255.
    """
    h, w = image_shape[:2]
    if center is None:
        center = (w // 2, h // 2)

    yy, xx = np.ogrid[:h, :w]
    inside = (xx - center[0]) ** 2 + (yy - center[1]) ** 2 <= radius ** 2

    mask = np.zeros((h, w), dtype=np.uint8)
    mask[inside] = 255
    return mask
+
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Create a more realistic wound mask with irregular shapes.

    Args:
        image_shape: (h, w[, channels]) of the target image.
        method: 'elliptical' (ellipse plus random speckle) or 'irregular'
            (circle with three lobes); other values yield an empty mask.

    Returns:
        uint8 mask with wound pixels set to 255.
    """
    # BUGFIX: cv2 is used below but is never imported at module level in this
    # file (see the import block at the top), which raised NameError at call
    # time; import it locally so the function is self-contained.
    import cv2

    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)

    if method == 'elliptical':
        # Create elliptical wound mask
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        ellipse = ((x - center[0])**2 / (radius_x**2) +
                   (y - center[1])**2 / (radius_y**2)) <= 1

        # Random speckle across the frame adds irregularity (non-deterministic).
        noise = np.random.random((h, w)) > 0.8
        mask = (ellipse | noise).astype(np.uint8) * 255

    elif method == 'irregular':
        # Circle with three lobes spaced 120° apart.
        center = (w // 2, h // 2)
        radius = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius

        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3

            ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
            extensions = extensions | ext_circle

        mask = (base_circle | extensions).astype(np.uint8) * 255

    # Apply morphological operations to smooth the mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask
+
+# --- Depth Estimation Functions ---
@spaces.GPU
def predict_depth(image):
    """Run DepthAnythingV2 on an image array and return the depth map.

    Callers in this file pass image[:, :, ::-1] (channel-reversed, RGB->BGR)
    — presumably the model expects BGR input; confirm against the model docs.
    """
    return depth_model.infer_image(image)
+
def calculate_max_points(image):
    """Upper bound for the 3D-point slider: 3x the pixel count, clamped.

    Returns 10000 when no image is loaded; otherwise a value in
    [1000, 300000].
    """
    if image is None:
        return 10000
    h, w = image.shape[:2]
    # Clamp 3x the pixel count into the allowed slider range.
    return min(max(h * w * 3, 1000), 300000)
+
def update_slider_on_image_upload(image):
    """Rescale the 3D-points slider to the uploaded image's pixel budget."""
    ceiling = calculate_max_points(image)
    # Default to 10% of the maximum (never above the ceiling itself).
    default_value = min(10000, ceiling // 10)
    return gr.Slider(minimum=1000, maximum=ceiling, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {ceiling:,})")
+
@spaces.GPU
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000):
    """Create a point cloud from depth map using camera intrinsics with high detail.

    Args:
        image: RGB image aligned with depth_map; supplies per-point colors.
        depth_map: 2-D depth array.
        focal_length_x / focal_length_y: pinhole intrinsics in pixels.
        max_points: point budget used to choose the sampling stride.

    Returns:
        open3d.geometry.PointCloud with colors scaled to [0, 1].
    """
    h, w = depth_map.shape

    # Use smaller step for higher detail (reduced downsampling).
    # NOTE(review): the 0.5 factor roughly quadruples the point count past
    # max_points — presumably intentional for detail; confirm.
    step = max(1, int(np.sqrt(h * w / max_points) * 0.5))

    # Create mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Convert to camera coordinates (normalized by focal length)
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y

    # Get depth values
    depth_values = depth_map[::step, ::step]

    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten arrays into an (N, 3) point matrix
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)

    # Get corresponding image colors, normalized to [0, 1] for Open3D
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0

    # Create Open3D point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    return pcd
+
@spaces.GPU
def reconstruct_surface_mesh_from_point_cloud(pcd):
    """Convert point cloud to a mesh using Poisson reconstruction with very high detail.

    Returns the raw Poisson mesh; low-density vertices are deliberately kept
    (densities are computed but unused).
    """
    # Estimate and orient normals with high precision
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50))
    pcd.orient_normals_consistent_tangent_plane(k=50)

    # Create surface mesh with maximum detail (depth=12 for very high resolution)
    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12)

    # Return mesh without filtering low-density vertices
    return mesh
+
@spaces.GPU
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create a Plotly 3D scatter of the depth map using pinhole back-projection.

    Args:
        image: RGB image as a uint8 numpy array, same H x W as depth_map.
        depth_map: 2D numpy array of depth values.
        max_points: Approximate number of points to plot (controls downsampling).

    Returns:
        plotly.graph_objects.Figure with one Scatter3d trace.
    """
    h, w = depth_map.shape

    # Downsample to avoid too many points for performance
    step = max(1, int(np.sqrt(h * w / max_points)))

    # Create mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]

    # Convert to camera coordinates (normalized by focal length)
    focal_length = 470.4  # Default focal length (pixels)
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length

    # Get depth values
    depth_values = depth_map[::step, ::step]

    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values

    # Flatten arrays
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()

    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)

    # FIX: the hover template previously contained raw newlines inside the
    # string literals (a syntax error — the '<br>' tags were lost when the
    # file was mangled); restored Plotly's '<br>' line breaks and the
    # '<extra></extra>' suffix that suppresses the secondary hover box.
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>'
    )])

    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )

    return fig
+
def on_depth_submit(image, num_points, focal_x, focal_y):
    """Run the full depth pipeline for the Gradio depth tab.

    Args:
        image: Input RGB image as a numpy array (H, W, 3).
        num_points: Target number of points for the 3D outputs.
        focal_x, focal_y: Camera focal lengths in pixels.

    Returns:
        [(original, colored_depth), gray_png_path, raw16_png_path,
         mesh_ply_path, plotly_figure]
    """
    original_image = image.copy()
    h, w = image.shape[:2]

    # predict_depth expects BGR channel order; Gradio supplies RGB.
    depth = predict_depth(image[:, :, ::-1])

    # Save raw 16-bit depth (delete=False so Gradio can serve the file path).
    raw_depth = Image.fromarray(depth.astype('uint16'))
    tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp_raw_depth.name)

    # Normalize to [0, 255] for display. FIX: guard against a constant depth
    # map, which previously divided by zero and produced NaNs.
    depth_range = depth.max() - depth.min()
    if depth_range > 0:
        norm_depth = ((depth - depth.min()) / depth_range * 255.0).astype(np.uint8)
    else:
        norm_depth = np.zeros_like(depth, dtype=np.uint8)
    colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8)

    gray_depth = Image.fromarray(norm_depth)
    tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    gray_depth.save(tmp_gray_depth.name)

    # Build the point cloud, then reconstruct a triangle mesh from it.
    pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points)
    mesh = reconstruct_surface_mesh_from_point_cloud(pcd)

    # Save mesh with faces as .ply
    tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
    o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh)

    # Interactive 3D scatter for the in-browser plot.
    depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points)

    return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
+
+# --- Automatic Wound Mask Generation Functions ---
+import cv2
+from skimage import filters, morphology, measure
+from skimage.segmentation import clear_border
+
def create_automatic_wound_mask(image, method='adaptive'):
    """Generate a binary wound mask from an image without user interaction.

    Args:
        image: Input image as a numpy array (RGB or grayscale), or None.
        method: One of 'adaptive', 'otsu', 'color', 'combined'; any other
            value falls back to 'adaptive'.

    Returns:
        Binary uint8 mask (255 = wound candidate), or None when image is None.
    """
    if image is None:
        return None

    # Threshold-based methods operate on a single-channel image.
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image.copy()

    # Dispatch table; lambdas defer the call until a method is chosen.
    segmenters = {
        'adaptive': lambda: adaptive_threshold_segmentation(gray),
        'otsu': lambda: otsu_threshold_segmentation(gray),
        'color': lambda: color_based_segmentation(image),
        'combined': lambda: combined_segmentation(image, gray),
    }
    return segmenters.get(method, segmenters['adaptive'])()
+
def adaptive_threshold_segmentation(gray):
    """Segment candidate wound regions with local (adaptive) thresholding.

    Args:
        gray: Single-channel uint8 image.

    Returns:
        Binary uint8 mask (0/255) keeping only components larger than 1000 px.
    """
    # Smooth first so the local threshold is not driven by pixel noise.
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)

    # Inverted adaptive threshold: darker-than-surroundings pixels become 255.
    binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
    )

    # Close then open with an elliptical kernel to remove speckle.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Re-draw only the contours whose area clears the minimum threshold.
    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filtered = np.zeros_like(cleaned)
    for outline in contours:
        if cv2.contourArea(outline) > 1000:
            cv2.fillPoly(filtered, [outline], 255)

    return filtered
+
def otsu_threshold_segmentation(gray):
    """Segment candidate wound regions via global Otsu thresholding.

    Args:
        gray: Single-channel uint8 image.

    Returns:
        Binary uint8 mask (0/255) keeping only components larger than 800 px.
    """
    smoothed = cv2.GaussianBlur(gray, (15, 15), 0)

    # Otsu picks the global threshold; INV makes darker regions foreground.
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Elliptical close/open pass to remove small holes and speckle.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, ellipse)

    # Keep only sufficiently large connected components.
    contours, _ = cv2.findContours(cleaned, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filtered = np.zeros_like(cleaned)
    for outline in contours:
        if cv2.contourArea(outline) > 800:
            cv2.fillPoly(filtered, [outline], 255)

    return filtered
+
def color_based_segmentation(image):
    """Detect wound-like regions by HSV color thresholds (reds, yellows, browns).

    Args:
        image: RGB image as a uint8 numpy array.

    Returns:
        Binary uint8 mask keeping color-matched components larger than 600 px.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # Red hue wraps around 0 in HSV, so two ranges are needed.
    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([160, 30, 30])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    # FIX: combine binary masks with bitwise OR rather than '+'. uint8
    # addition wraps modulo 256 where masks overlap (255 + 255 -> 254),
    # leaving non-canonical values in what should be a 0/255 mask.
    red_mask = cv2.bitwise_or(mask1, mask2)

    # Yellowish wound colors - broader range
    lower_yellow = np.array([15, 30, 30])
    upper_yellow = np.array([35, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)

    # Brownish wound colors (hue band overlaps the red/yellow ranges)
    lower_brown = np.array([10, 50, 20])
    upper_brown = np.array([20, 255, 200])
    brown_mask = cv2.inRange(hsv, lower_brown, upper_brown)

    # Union of all color matches (bitwise OR, see note above).
    color_mask = cv2.bitwise_or(red_mask, cv2.bitwise_or(yellow_mask, brown_mask))

    # Clean up the mask with larger kernels
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)

    # Keep only sufficiently large connected components.
    contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(color_mask)
    for contour in contours:
        if cv2.contourArea(contour) > 600:  # Minimum area threshold
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean
+
def combined_segmentation(image, gray):
    """Union of the adaptive, Otsu, and color segmentations, cleaned up.

    Falls back to a synthetic elliptical mask when nothing survives the
    area filter, so callers always receive a non-empty mask.

    Args:
        image: RGB image as a uint8 numpy array.
        gray: The same image as a single channel.

    Returns:
        Binary uint8 mask (0/255).
    """
    # Union of all three strategies (OR keeps anything any method found).
    union = cv2.bitwise_or(adaptive_threshold_segmentation(gray),
                           otsu_threshold_segmentation(gray))
    union = cv2.bitwise_or(union, color_based_segmentation(image))

    # Large closing kernel merges nearby fragments into one region.
    closing_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    union = cv2.morphologyEx(union, cv2.MORPH_CLOSE, closing_kernel)

    # Keep only sufficiently large connected components.
    contours, _ = cv2.findContours(union, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    keep = np.zeros_like(union)
    for contour in contours:
        if cv2.contourArea(contour) > 500:
            cv2.fillPoly(keep, [contour], 255)

    # Nothing cleared the filter: synthesize a plausible wound mask instead.
    if np.sum(keep) == 0:
        keep = create_realistic_wound_mask(union.shape, method='elliptical')

    return keep
+
def post_process_wound_mask(mask, min_area=100):
    """Denoise a binary wound mask.

    Applies morphological clean-up, drops connected components smaller than
    min_area, and closes remaining holes.

    Args:
        mask: Binary mask as a numpy array, or None.
        min_area: Smallest component area (px) to keep.

    Returns:
        Cleaned uint8 mask, or None when mask is None.
    """
    if mask is None:
        return None

    # findContours requires uint8 input.
    binary = mask if mask.dtype == np.uint8 else mask.astype(np.uint8)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, ellipse)

    # Drop components below the area floor by re-drawing only the large ones.
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filtered = np.zeros_like(binary)
    for blob in contours:
        if cv2.contourArea(blob) >= min_area:
            cv2.fillPoly(filtered, [blob], 255)

    # Final closing pass fills interior holes left by the contour fill.
    return cv2.morphologyEx(filtered, cv2.MORPH_CLOSE, ellipse)
+
def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='combined'):
    """Automatically segment the wound, then delegate to analyze_wound_severity.

    Args:
        image: RGB image as a numpy array, or None.
        depth_map: Depth map as a numpy array, or None.
        pixel_spacing_mm: Physical size of one pixel in millimetres.
        segmentation_method: Method name forwarded to create_automatic_wound_mask.

    Returns:
        HTML/markdown report string, or an error message starting with '❌'.
    """
    if image is None or depth_map is None:
        return "❌ Please provide both image and depth map."

    raw_mask = create_automatic_wound_mask(image, method=segmentation_method)
    if raw_mask is None:
        return "❌ Failed to generate automatic wound mask."

    # Remove noise and sub-threshold blobs before measuring anything.
    refined = post_process_wound_mask(raw_mask, min_area=500)
    if refined is None or not np.any(refined > 0):
        return "❌ No wound region detected. Try adjusting segmentation parameters or upload a manual mask."

    return analyze_wound_severity(image, depth_map, refined, pixel_spacing_mm)
+
+# --- Main Gradio Interface ---
+with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
+ gr.HTML("Wound Analysis & Depth Estimation System
")
+ gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Cross-tab state: the image uploaded in Tab 1, reused by Tab 2.
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)

                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            # Button to pass image to depth estimation
            with gr.Row():
                pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")

            # Classify on upload; fills the prediction and reasoning panels.
            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])

            # Store image when uploaded for classification
            # (second .change handler on the same component: copies it into shared state)
            wound_image_input.change(
                fn=lambda img: img,
                inputs=[wound_image_input],
                outputs=[shared_image]
            )

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")

            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            # Downloadable artifacts produced by the depth pipeline.
            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            # 3D Visualization
            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Store depth map for severity analysis (consumed in Tab 3)
            depth_map_state = gr.State()
+
        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. 🩹 Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")

            with gr.Row():
                auto_severity_button = gr.Button("🤖 Auto-Analyze Severity", variant="primary", size="lg")
                manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
                # mm-per-pixel scale used to convert pixel measurements to physical units
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            # Controls for automatic segmentation (method + noise floor).
            with gr.Row():
                segmentation_method = gr.Dropdown(
                    choices=["combined", "adaptive", "otsu", "color"],
                    value="combined",
                    label="Segmentation Method",
                    info="Choose automatic segmentation method"
                )
                min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
                                            label="Minimum Area (pixels)",
                                            info="Minimum wound area to detect")

            with gr.Row():
                # Load depth map from previous tab
                load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
                sample_mask_btn = gr.Button("🎯 Generate Sample Mask", variant="secondary")
                realistic_mask_btn = gr.Button("🏥 Generate Realistic Mask", variant="secondary")
                preview_mask_btn = gr.Button("👁️ Preview Auto Mask", variant="secondary")

            gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")
+
+ # Generate sample mask function
+ def generate_sample_mask(image):
+ if image is None:
+ return None, "❌ Please load an image first."
+
+ sample_mask = create_sample_wound_mask(image.shape)
+ return sample_mask, "✅ Sample circular wound mask generated!"
+
+ # Generate realistic mask function
+ def generate_realistic_mask(image):
+ if image is None:
+ return None, "❌ Please load an image first."
+
+ realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
+ return realistic_mask, "✅ Realistic elliptical wound mask generated!"
+
            # Both mask generators write into the manual-mask widget plus a status line.
            sample_mask_btn.click(
                fn=generate_sample_mask,
                inputs=[severity_input_image],
                outputs=[wound_mask_input, gr.HTML()]
            )

            realistic_mask_btn.click(
                fn=generate_realistic_mask,
                inputs=[severity_input_image],
                outputs=[wound_mask_input, gr.HTML()]
            )

            # Update slider when image is uploaded
            # (rescales the max point count to the image resolution)
            depth_input_image.change(
                fn=update_slider_on_image_upload,
                inputs=[depth_input_image],
                outputs=[points_slider]
            )
+
+ # Modified depth submit function to store depth map
+ def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
+ results = on_depth_submit(image, num_points, focal_x, focal_y)
+ # Extract depth map from results for severity analysis
+ depth_map = None
+ if image is not None:
+ depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed
+ # Normalize depth for severity analysis
+ norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+ depth_map = norm_depth.astype(np.uint8)
+ return results + [depth_map]
+
            # Wire the depth computation; the trailing depth_map_state output
            # is what Tab 3 later reads.
            depth_submit.click(on_depth_submit_with_state,
                               inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                               outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state])

            # Load depth map to severity tab
            def load_depth_to_severity(depth_map, original_image):
                """Copy the stored depth map and input image into the Tab-3 widgets."""
                if depth_map is None:
                    return None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
                return depth_map, original_image, "✅ Depth map loaded successfully!"

            load_depth_btn.click(
                fn=load_depth_to_severity,
                inputs=[depth_map_state, depth_input_image],
                outputs=[severity_depth_map, severity_input_image, gr.HTML()]
            )
+
+ # Automatic severity analysis function
+ def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
+ if depth_map is None:
+ return "❌ Please load depth map from Tab 2 first."
+
+ # Update post-processing with user-defined minimum area
+ def post_process_with_area(mask):
+ return post_process_wound_mask(mask, min_area=min_area)
+
+ # Generate automatic wound mask
+ auto_mask = create_automatic_wound_mask(image, method=seg_method)
+
+ if auto_mask is None:
+ return "❌ Failed to generate automatic wound mask."
+
+ # Post-process the mask
+ processed_mask = post_process_with_area(auto_mask)
+
+ if processed_mask is None or np.sum(processed_mask > 0) == 0:
+ return "❌ No wound region detected. Try adjusting segmentation parameters or use manual mask."
+
+ # Analyze severity using the automatic mask
+ return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)
+
+ # Manual severity analysis function
+ def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
+ if depth_map is None:
+ return "❌ Please load depth map from Tab 2 first."
+ if wound_mask is None:
+ return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
+
+ return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)
+
+ # Preview automatic mask function
+ def preview_auto_mask(image, seg_method, min_area):
+ if image is None:
+ return None, "❌ Please load an image first."
+
+ # Generate automatic wound mask
+ auto_mask = create_automatic_wound_mask(image, method=seg_method)
+
+ if auto_mask is None:
+ return None, "❌ Failed to generate automatic wound mask."
+
+ # Post-process the mask
+ processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
+
+ if processed_mask is None or np.sum(processed_mask > 0) == 0:
+ return None, "❌ No wound region detected. Try adjusting parameters."
+
+ return processed_mask, f"✅ Auto mask generated using {seg_method} method!"
+
            # Connect event handlers
            auto_severity_button.click(
                fn=run_auto_severity_analysis,
                inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider,
                        segmentation_method, min_area_slider],
                outputs=[severity_output]
            )

            manual_severity_button.click(
                fn=run_manual_severity_analysis,
                inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider],
                outputs=[severity_output]
            )

            # Preview writes the generated mask back into the manual-mask widget.
            preview_mask_btn.click(
                fn=preview_auto_mask,
                inputs=[severity_input_image, segmentation_method, min_area_slider],
                outputs=[wound_mask_input, gr.HTML()]
            )
+
+ # Load shared image from classification tab
+ def load_shared_image(shared_img):
+ if shared_img is None:
+ return gr.Image(), "❌ No image available from classification tab"
+
+ # Convert PIL image to numpy array for depth estimation
+ if hasattr(shared_img, 'convert'):
+ # It's a PIL image, convert to numpy
+ img_array = np.array(shared_img)
+ return img_array, "✅ Image loaded from classification tab"
+ else:
+ # Already numpy array
+ return shared_img, "✅ Image loaded from classification tab"
+
            # Tab-2 button: pull the image previously stored by Tab 1.
            load_shared_btn.click(
                fn=load_shared_image,
                inputs=[shared_image],
                outputs=[depth_input_image, gr.HTML()]
            )
+
+ # Pass image to depth tab function
+ def pass_image_to_depth(img):
+ if img is None:
+ return "❌ No image uploaded in classification tab"
+ return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"
+
            # Status-only handler: the image itself travels via shared_image state.
            pass_to_depth_btn.click(
                fn=pass_image_to_depth,
                inputs=[shared_image],
                outputs=[pass_status]
            )

if __name__ == '__main__':
    # Listen on all interfaces; share=True also opens a public Gradio tunnel.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
\ No newline at end of file
diff --git a/temp_files/train.py b/temp_files/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac4e808b034f1397392664df5cb0494c37da0ad
--- /dev/null
+++ b/temp_files/train.py
@@ -0,0 +1,69 @@
+from tensorflow.keras.optimizers import Adam
+from keras.callbacks import EarlyStopping
+from keras.models import load_model
+from keras.utils.generic_utils import CustomObjectScope
+
+from models.unets import Unet2D
+from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling
+from models.FCN import FCN_Vgg16_16s
+from models.SegNet import SegNet
+
+from utils.learning.metrics import dice_coef, precision, recall
+from utils.learning.losses import dice_coef_loss
+from utils.io.data import DataGen, save_results, save_history, load_data
+
+
# manually set cuda 10.0 path
#os.system('export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}')
#os.system('export PATH=/usr/local/cuda-10.0/bin:/usr/local/cuda-10.0/NsightCompute-1.0${PATH:+:${PATH}}')

# Variables and data generator (224x224 inputs, 80/20 train/val split)
input_dim_x=224
input_dim_y=224
n_filters = 32
dataset = 'Medetec_foot_ulcer_224'
data_gen = DataGen('./data/' + dataset + '/', split_ratio=0.2, x=input_dim_x, y=input_dim_y)

######### Get the deep learning models #########

######### Unet ##########
# unet2d = Unet2D(n_filters=n_filters, input_dim_x=None, input_dim_y=None, num_channels=3)
# model, model_name = unet2d.get_unet_model_yuanqing()

######### MobilenetV2 ##########
# NOTE(review): this freshly built model is immediately replaced by the
# load_model call below — the Deeplabv3(...) construction is dead work
# unless training from scratch; confirm intent.
model = Deeplabv3(input_shape=(input_dim_x, input_dim_y, 3), classes=1)
model_name = 'MobilenetV2'
with CustomObjectScope({'relu6': relu6,'DepthwiseConv2D': DepthwiseConv2D, 'BilinearUpsampling': BilinearUpsampling}):
    model = load_model('training_history/2019-12-19 01%3A53%3A15.480800.hdf5'
    , custom_objects={'dice_coef': dice_coef, 'precision':precision, 'recall':recall})

######### Vgg16 ##########
# model, model_name = FCN_Vgg16_16s(input_shape=(input_dim_x, input_dim_y, 3))

######### SegNet ##########
# segnet = SegNet(n_filters, input_dim_x, input_dim_y, num_channels=3)
# model, model_name = segnet.get_SegNet()

# plot_model(model, to_file=model_name+'.png')

# training hyperparameters
batch_size = 2
epochs = 2000
learning_rate = 1e-4
loss = 'binary_crossentropy'

# Stop when validation Dice stops improving for 200 epochs; restore best weights.
es = EarlyStopping(monitor='val_dice_coef', patience=200, mode='max', restore_best_weights=True)
#training_history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs
#                             , validation_split=0.2, verbose=1, callbacks=[])

model.summary()
# NOTE(review): Adam(lr=...) and fit_generator are legacy Keras 2.x APIs
# (TF2 uses learning_rate= and fit()); kept as-is for the pinned environment.
model.compile(optimizer=Adam(lr=learning_rate), loss=loss, metrics=[dice_coef, precision, recall])
training_history = model.fit_generator(data_gen.generate_data(batch_size=batch_size, train=True),
                       steps_per_epoch=int(data_gen.get_num_data_points(train=True) / batch_size),
                       callbacks=[es],
                       validation_data=data_gen.generate_data(batch_size=batch_size, val=True),
                       validation_steps=int(data_gen.get_num_data_points(val=True) / batch_size),
                       epochs=epochs)
### save the model weight file and its training history
save_history(model, model_name, training_history, dataset, n_filters, epochs, learning_rate, loss, color_space='RGB',
             path='./training_history/')
\ No newline at end of file
diff --git a/training_history/2019-12-19 01%3A53%3A15.480800.hdf5 b/training_history/2019-12-19 01%3A53%3A15.480800.hdf5
new file mode 100644
index 0000000000000000000000000000000000000000..05a8764ec6b27ffa6284797028e98b2e4f3dc060
--- /dev/null
+++ b/training_history/2019-12-19 01%3A53%3A15.480800.hdf5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6eedcbb8201def963e45e36721b9d23bb604c2f1a2c39331f80e0541187dd287
+size 26126696
diff --git a/training_history/2025-08-07_16-25-27.hdf5 b/training_history/2025-08-07_16-25-27.hdf5
new file mode 100644
index 0000000000000000000000000000000000000000..90356f735ba02cbcdf6eb0801b5623efce4fbe2f
--- /dev/null
+++ b/training_history/2025-08-07_16-25-27.hdf5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7cbd78f8b4dfeb43af1f6ba8167b2650acb4cec2cca7c8a3c71000dcfd63908
+size 26362704
diff --git a/training_history/2025-08-07_16-25-27.json b/training_history/2025-08-07_16-25-27.json
new file mode 100644
index 0000000000000000000000000000000000000000..dd1b5a8e7e884a644657c9b483f2c952943e8068
--- /dev/null
+++ b/training_history/2025-08-07_16-25-27.json
@@ -0,0 +1,3946 @@
+{
+ "loss": [
+ 0.30647748708724976,
+ 0.1750967651605606,
+ 0.1348150670528412,
+ 0.10996372997760773,
+ 0.0961669534444809,
+ 0.08511081337928772,
+ 0.07905993610620499,
+ 0.07364209741353989,
+ 0.06904656440019608,
+ 0.0664418488740921,
+ 0.0637221485376358,
+ 0.06133847311139107,
+ 0.05951695144176483,
+ 0.05743888020515442,
+ 0.05550319701433182,
+ 0.05368385836482048,
+ 0.05071243271231651,
+ 0.053955163806676865,
+ 0.05913974717259407,
+ 0.05481218546628952,
+ 0.053216658532619476,
+ 0.051806069910526276,
+ 0.05040912330150604,
+ 0.04853715002536774,
+ 0.04627740755677223,
+ 0.044616397470235825,
+ 0.043374769389629364,
+ 0.04218907281756401,
+ 0.04144686087965965,
+ 0.04063361510634422,
+ 0.04707714170217514,
+ 0.04899223521351814,
+ 0.04533304274082184,
+ 0.043027348816394806,
+ 0.041465017944574356,
+ 0.040123045444488525,
+ 0.03773369640111923,
+ 0.03661240264773369,
+ 0.03559977933764458,
+ 0.034923210740089417,
+ 0.035055261105298996,
+ 0.03570236638188362,
+ 0.03435955196619034,
+ 0.03357187658548355,
+ 0.0329667367041111,
+ 0.03242901340126991,
+ 0.03186548501253128,
+ 0.03149624541401863,
+ 0.031010165810585022,
+ 0.03060271590948105,
+ 0.030267074704170227,
+ 0.03771795332431793,
+ 0.03534615784883499,
+ 0.033918071538209915,
+ 0.03289610520005226,
+ 0.032201554626226425,
+ 0.0314098559319973,
+ 0.03094932995736599,
+ 0.030538897961378098,
+ 0.029941964894533157,
+ 0.029633009806275368,
+ 0.02943609468638897,
+ 0.028660012409090996,
+ 0.028602395206689835,
+ 0.028005452826619148,
+ 0.02752712368965149,
+ 0.027221843600273132,
+ 0.02680036798119545,
+ 0.026418132707476616,
+ 0.02618495002388954,
+ 0.025956127792596817,
+ 0.02560962177813053,
+ 0.025426078587770462,
+ 0.02505660429596901,
+ 0.024684013798832893,
+ 0.024436764419078827,
+ 0.024140072986483574,
+ 0.023672718554735184,
+ 0.023578155785799026,
+ 0.023109160363674164,
+ 0.02283332496881485,
+ 0.02255292609333992,
+ 0.022325903177261353,
+ 0.021962279453873634,
+ 0.021811343729496002,
+ 0.02172011323273182,
+ 0.02129892073571682,
+ 0.021038774400949478,
+ 0.020723635330796242,
+ 0.02045297622680664,
+ 0.0214977003633976,
+ 0.021355675533413887,
+ 0.021544352173805237,
+ 0.020594464614987373,
+ 0.020267758518457413,
+ 0.021381760016083717,
+ 0.020386535674333572,
+ 0.019776318222284317,
+ 0.019427292048931122,
+ 0.019123896956443787,
+ 0.02659258246421814,
+ 0.03618019074201584,
+ 0.027225933969020844,
+ 0.025075973942875862,
+ 0.02390117757022381,
+ 0.023283500224351883,
+ 0.022889306768774986,
+ 0.022487878799438477,
+ 0.02206781506538391,
+ 0.021616896614432335,
+ 0.02290518395602703,
+ 0.02203931286931038,
+ 0.021298978477716446,
+ 0.02082144469022751,
+ 0.020476309582591057,
+ 0.020369667559862137,
+ 0.020366981625556946,
+ 0.020146096125245094,
+ 0.019704287871718407,
+ 0.01959238573908806,
+ 0.01938951388001442,
+ 0.019204184412956238,
+ 0.018987014889717102,
+ 0.018942462280392647,
+ 0.018814796581864357,
+ 0.0185571126639843,
+ 0.01836971566081047,
+ 0.018259204924106598,
+ 0.018205296248197556,
+ 0.018134118989109993,
+ 0.017913157120347023,
+ 0.017672620713710785,
+ 0.017590902745723724,
+ 0.017439354211091995,
+ 0.01727035455405712,
+ 0.017223920673131943,
+ 0.018535830080509186,
+ 0.01764857769012451,
+ 0.017293628305196762,
+ 0.017154958099126816,
+ 0.016896503046154976,
+ 0.016868583858013153,
+ 0.016872286796569824,
+ 0.016433054581284523,
+ 0.016220392659306526,
+ 0.01603561080992222,
+ 0.015778077766299248,
+ 0.01577850617468357,
+ 0.015843817964196205,
+ 0.015809863805770874,
+ 0.015860065817832947,
+ 0.015477662906050682,
+ 0.01531931571662426,
+ 0.015125595033168793,
+ 0.01502202544361353,
+ 0.014976686798036098,
+ 0.015033239498734474,
+ 0.015951665118336678,
+ 0.01540245022624731,
+ 0.014854921028017998,
+ 0.014961468987166882,
+ 0.015013501979410648,
+ 0.01517416164278984,
+ 0.01449889037758112,
+ 0.014303860254585743,
+ 0.014187045395374298,
+ 0.014804028905928135,
+ 0.015959901735186577,
+ 0.020786097273230553,
+ 0.01760375313460827,
+ 0.015779541805386543,
+ 0.015848513692617416,
+ 0.015986243262887,
+ 0.014567948877811432,
+ 0.013230666518211365,
+ 0.01453038863837719,
+ 0.013637007214128971,
+ 0.013031593523919582,
+ 0.01634998433291912,
+ 0.02193347178399563,
+ 0.016191652044653893,
+ 0.01508353091776371,
+ 0.015239519998431206,
+ 0.015201173722743988,
+ 0.014139600098133087,
+ 0.013658924028277397,
+ 0.013358766213059425,
+ 0.013011476956307888,
+ 0.012659418396651745,
+ 0.012393197976052761,
+ 0.012266334146261215,
+ 0.011977892369031906,
+ 0.011970474384725094,
+ 0.011703734286129475,
+ 0.01161114126443863,
+ 0.01172381266951561,
+ 0.011709300801157951,
+ 0.012194477021694183,
+ 0.013468567281961441,
+ 0.017431264743208885,
+ 0.015250993892550468,
+ 0.013966160826385021,
+ 0.012959204614162445,
+ 0.012426641769707203,
+ 0.012011424638330936,
+ 0.012029320932924747,
+ 0.012592227198183537,
+ 0.011803402565419674,
+ 0.011517093516886234,
+ 0.011947079561650753,
+ 0.012068829499185085,
+ 0.01242718007415533,
+ 0.012240450829267502,
+ 0.011861618608236313,
+ 0.012352905236184597,
+ 0.014822003431618214,
+ 0.019554194062948227,
+ 0.021305328235030174,
+ 0.019731786102056503,
+ 0.01814408227801323,
+ 0.01708049699664116,
+ 0.016473757103085518,
+ 0.016690775752067566,
+ 0.02231513150036335,
+ 0.01974739506840706,
+ 0.017199721187353134,
+ 0.016205912455916405,
+ 0.01561315730214119,
+ 0.01506352610886097,
+ 0.014831586740911007,
+ 0.014372607693076134,
+ 0.014147626236081123,
+ 0.013727042824029922,
+ 0.01352244708687067,
+ 0.01320627424865961,
+ 0.013022632338106632,
+ 0.012837468646466732,
+ 0.012507705949246883,
+ 0.01239242497831583,
+ 0.012630939483642578,
+ 0.012593026272952557,
+ 0.012325061485171318,
+ 0.011994406580924988,
+ 0.011810517869889736,
+ 0.012759208679199219,
+ 0.011985184624791145,
+ 0.011915159411728382,
+ 0.013752062804996967,
+ 0.01360330730676651,
+ 0.012521203607320786,
+ 0.012182645499706268,
+ 0.011708719655871391,
+ 0.01152396947145462,
+ 0.011431052349507809,
+ 0.01117919571697712,
+ 0.011751057580113411,
+ 0.011490956880152225,
+ 0.011127873323857784,
+ 0.010874041356146336,
+ 0.0110195092856884,
+ 0.01136202085763216,
+ 0.010967769660055637,
+ 0.011810488998889923,
+ 0.021257253363728523,
+ 0.014899597503244877,
+ 0.013231162913143635,
+ 0.01252503041177988,
+ 0.011877506040036678,
+ 0.011600157245993614,
+ 0.011508181691169739,
+ 0.011495032347738743,
+ 0.011990023776888847,
+ 0.011151742190122604,
+ 0.011049785651266575,
+ 0.010565276257693768,
+ 0.013197552412748337,
+ 0.01115148514509201,
+ 0.029514271765947342,
+ 0.0241073127835989,
+ 0.01918421871960163,
+ 0.017714563757181168,
+ 0.01667606830596924,
+ 0.015957504510879517,
+ 0.015396581031382084,
+ 0.014803696423768997,
+ 0.013729263097047806,
+ 0.013278660364449024,
+ 0.012890578247606754,
+ 0.012571698985993862,
+ 0.0128385741263628,
+ 0.012980268336832523,
+ 0.012489119544625282,
+ 0.012236335314810276,
+ 0.011874557472765446,
+ 0.011515924707055092,
+ 0.01126138586550951,
+ 0.011987834237515926,
+ 0.011403180658817291,
+ 0.011201435700058937,
+ 0.010884991846978664,
+ 0.010778272524476051,
+ 0.010546007193624973,
+ 0.010356712155044079,
+ 0.010204695165157318,
+ 0.010189410299062729,
+ 0.010785996913909912,
+ 0.011168578639626503,
+ 0.011988673359155655,
+ 0.013514974154531956,
+ 0.011950574815273285,
+ 0.015517404302954674,
+ 0.012958980165421963,
+ 0.012304777279496193,
+ 0.011725402437150478,
+ 0.01116760354489088,
+ 0.010796580463647842,
+ 0.011244084686040878,
+ 0.011185579001903534,
+ 0.010852617211639881,
+ 0.010554634034633636,
+ 0.010332541540265083,
+ 0.010245597921311855,
+ 0.010003700852394104,
+ 0.009881113655865192,
+ 0.009757373481988907,
+ 0.00969171617180109,
+ 0.00992400012910366,
+ 0.009720757603645325,
+ 0.009623757563531399,
+ 0.009466061368584633,
+ 0.009499983862042427,
+ 0.009426280856132507,
+ 0.009482208639383316,
+ 0.013372893445193768,
+ 0.021394873037934303,
+ 0.015852460637688637,
+ 0.014674946665763855,
+ 0.014396724291145802,
+ 0.01334143802523613,
+ 0.012487740255892277,
+ 0.011898113414645195,
+ 0.011559720151126385,
+ 0.011196346022188663,
+ 0.010986356064677238,
+ 0.010869316756725311,
+ 0.010591063648462296,
+ 0.01035912986844778,
+ 0.010249403305351734,
+ 0.010015510022640228,
+ 0.010098502971231937,
+ 0.010486401617527008,
+ 0.010132399387657642,
+ 0.010006209835410118,
+ 0.017435988411307335,
+ 0.017399724572896957,
+ 0.013613355346024036,
+ 0.01278015412390232,
+ 0.012173213995993137,
+ 0.011733791790902615,
+ 0.011255122721195221,
+ 0.010943155735731125,
+ 0.01070241816341877,
+ 0.010416466742753983,
+ 0.01040695607662201,
+ 0.010443135164678097,
+ 0.009866783395409584,
+ 0.009707099758088589,
+ 0.009678167290985584,
+ 0.009503369219601154,
+ 0.00959171261638403,
+ 0.009591631591320038,
+ 0.009851268492639065,
+ 0.00967353954911232,
+ 0.009600138291716576,
+ 0.03234858438372612,
+ 0.020381923764944077,
+ 0.017983023077249527,
+ 0.01669255830347538,
+ 0.015920067206025124,
+ 0.016266971826553345,
+ 0.014160492457449436,
+ 0.013295168988406658,
+ 0.012696294113993645,
+ 0.01237404253333807,
+ 0.01230611838400364,
+ 0.012054872699081898,
+ 0.011769881471991539,
+ 0.011479108594357967,
+ 0.0114207211881876,
+ 0.011093338951468468,
+ 0.010895829647779465,
+ 0.010720976628363132,
+ 0.011077550239861012,
+ 0.010639906860888004,
+ 0.010437337681651115,
+ 0.010433749295771122,
+ 0.010163159109652042,
+ 0.010187039151787758,
+ 0.009886456653475761,
+ 0.009900178760290146,
+ 0.009881138801574707,
+ 0.009746776893734932,
+ 0.009791776537895203,
+ 0.010411535389721394,
+ 0.010251682251691818,
+ 0.00973040796816349,
+ 0.014108603820204735,
+ 0.012864779680967331,
+ 0.011442308314144611,
+ 0.011878883466124535,
+ 0.01187659241259098,
+ 0.011125029996037483,
+ 0.010407320223748684,
+ 0.010011264123022556,
+ 0.019328519701957703,
+ 0.021474624052643776,
+ 0.01828955113887787,
+ 0.015556896105408669,
+ 0.014325638301670551,
+ 0.013566575944423676,
+ 0.012937955558300018,
+ 0.012533580884337425,
+ 0.012252584099769592,
+ 0.012734328396618366,
+ 0.012158636003732681,
+ 0.013013201765716076,
+ 0.011682095937430859,
+ 0.02510238066315651,
+ 0.018895352259278297,
+ 0.015548864379525185,
+ 0.014414222911000252,
+ 0.013655620627105236,
+ 0.013227966614067554,
+ 0.012928091920912266,
+ 0.01250703725963831,
+ 0.012381622567772865,
+ 0.01197823416441679,
+ 0.011751786805689335,
+ 0.01158289797604084,
+ 0.011270741000771523,
+ 0.011268547736108303,
+ 0.011189980432391167,
+ 0.011638487689197063,
+ 0.011096851900219917,
+ 0.010878192260861397,
+ 0.010670908726751804,
+ 0.010411555878818035,
+ 0.010200803168118,
+ 0.010321483016014099,
+ 0.010585369542241096,
+ 0.010327773168683052,
+ 0.010099920444190502,
+ 0.00985438097268343,
+ 0.009732011705636978,
+ 0.009680816903710365,
+ 0.009557006880640984,
+ 0.009599557146430016,
+ 0.009605410508811474,
+ 0.009399175643920898,
+ 0.009391609579324722,
+ 0.009294605813920498,
+ 0.00959872081875801,
+ 0.009268030524253845,
+ 0.009913635440170765,
+ 0.010886119678616524,
+ 0.010026099160313606,
+ 0.010056180879473686,
+ 0.009460094384849072,
+ 0.009112786501646042,
+ 0.024201123043894768,
+ 0.01934361457824707,
+ 0.013264663517475128,
+ 0.011855002492666245,
+ 0.011098966002464294,
+ 0.010679544880986214,
+ 0.012001180090010166,
+ 0.014273331500589848,
+ 0.013267090544104576,
+ 0.0121291633695364,
+ 0.011487667448818684,
+ 0.011008239351212978,
+ 0.010625632479786873,
+ 0.010453431867063046,
+ 0.010083537548780441,
+ 0.009872745722532272,
+ 0.009798270650207996,
+ 0.009803017601370811,
+ 0.009521342813968658,
+ 0.010055589489638805,
+ 0.009551863186061382,
+ 0.009455163963139057
+ ],
+ "dice_coef": [
+ 0.8271756768226624,
+ 0.8616277575492859,
+ 0.8727266192436218,
+ 0.8806064128875732,
+ 0.8851526379585266,
+ 0.8895143270492554,
+ 0.8929189443588257,
+ 0.8961383700370789,
+ 0.8982627987861633,
+ 0.9007869362831116,
+ 0.9030752778053284,
+ 0.9055670499801636,
+ 0.9069859385490417,
+ 0.9094352722167969,
+ 0.9106618165969849,
+ 0.9123764634132385,
+ 0.9143651723861694,
+ 0.9116659164428711,
+ 0.9102592468261719,
+ 0.914546549320221,
+ 0.9164059162139893,
+ 0.918033242225647,
+ 0.9189513921737671,
+ 0.9205445051193237,
+ 0.9218775033950806,
+ 0.9237352609634399,
+ 0.9249120354652405,
+ 0.9262439608573914,
+ 0.9274923205375671,
+ 0.9286374449729919,
+ 0.9185131788253784,
+ 0.9172464609146118,
+ 0.921724796295166,
+ 0.9248894453048706,
+ 0.9270980954170227,
+ 0.9278378486633301,
+ 0.9289146661758423,
+ 0.9303903579711914,
+ 0.9316245317459106,
+ 0.9325858950614929,
+ 0.9317297339439392,
+ 0.930961012840271,
+ 0.9337752461433411,
+ 0.935239851474762,
+ 0.9366147518157959,
+ 0.9373242259025574,
+ 0.9382657408714294,
+ 0.9389757513999939,
+ 0.9396278858184814,
+ 0.9405460953712463,
+ 0.9411396980285645,
+ 0.9279007315635681,
+ 0.9338310360908508,
+ 0.9359795451164246,
+ 0.9377782940864563,
+ 0.9386602640151978,
+ 0.9395512342453003,
+ 0.9404653310775757,
+ 0.9410924315452576,
+ 0.9420122504234314,
+ 0.9424967169761658,
+ 0.9431310296058655,
+ 0.9438883662223816,
+ 0.9442669749259949,
+ 0.9446057081222534,
+ 0.9453078508377075,
+ 0.946151852607727,
+ 0.9468348622322083,
+ 0.9474738240242004,
+ 0.9479974508285522,
+ 0.9481825232505798,
+ 0.9486918449401855,
+ 0.9491387009620667,
+ 0.9495266079902649,
+ 0.9503999948501587,
+ 0.9512287378311157,
+ 0.9517524838447571,
+ 0.9522303938865662,
+ 0.9523171186447144,
+ 0.9529317617416382,
+ 0.9539192914962769,
+ 0.954371988773346,
+ 0.9549437165260315,
+ 0.9555351138114929,
+ 0.9559608697891235,
+ 0.9559311866760254,
+ 0.9569058418273926,
+ 0.9574870467185974,
+ 0.9579500555992126,
+ 0.9585270285606384,
+ 0.9570472240447998,
+ 0.9570488929748535,
+ 0.9575705528259277,
+ 0.9589021801948547,
+ 0.9596315026283264,
+ 0.9579778909683228,
+ 0.9597327709197998,
+ 0.9607507586479187,
+ 0.9615135788917542,
+ 0.9620363116264343,
+ 0.9527720212936401,
+ 0.9445807933807373,
+ 0.9498860239982605,
+ 0.9536462426185608,
+ 0.9545617699623108,
+ 0.9558765888214111,
+ 0.9553714990615845,
+ 0.956228494644165,
+ 0.9569122791290283,
+ 0.9576302170753479,
+ 0.9567261338233948,
+ 0.957889974117279,
+ 0.9587068557739258,
+ 0.9591506719589233,
+ 0.9596977233886719,
+ 0.9597998261451721,
+ 0.9601463079452515,
+ 0.9599983096122742,
+ 0.960551917552948,
+ 0.9609813094139099,
+ 0.961475670337677,
+ 0.9618078470230103,
+ 0.9621520042419434,
+ 0.9621374607086182,
+ 0.962533712387085,
+ 0.9629316329956055,
+ 0.963316023349762,
+ 0.9636129140853882,
+ 0.9639310836791992,
+ 0.9640538692474365,
+ 0.9640945792198181,
+ 0.9645101428031921,
+ 0.9647799730300903,
+ 0.965129554271698,
+ 0.9654141664505005,
+ 0.9656355381011963,
+ 0.9649770855903625,
+ 0.9656959176063538,
+ 0.9662606716156006,
+ 0.966301441192627,
+ 0.9666382074356079,
+ 0.9668326377868652,
+ 0.9668439030647278,
+ 0.9674142599105835,
+ 0.9679332375526428,
+ 0.9681134819984436,
+ 0.9686338901519775,
+ 0.9687585234642029,
+ 0.968749463558197,
+ 0.9687435626983643,
+ 0.9692564606666565,
+ 0.9698631763458252,
+ 0.9701429009437561,
+ 0.9703571200370789,
+ 0.9706502556800842,
+ 0.9709163904190063,
+ 0.9707404375076294,
+ 0.9700429439544678,
+ 0.97032630443573,
+ 0.9712384343147278,
+ 0.9708642959594727,
+ 0.9709155559539795,
+ 0.9712903499603271,
+ 0.9723801016807556,
+ 0.972836971282959,
+ 0.9731794595718384,
+ 0.9727415442466736,
+ 0.9710344076156616,
+ 0.9692134261131287,
+ 0.9714216589927673,
+ 0.9724838137626648,
+ 0.9720187187194824,
+ 0.9717462658882141,
+ 0.972689688205719,
+ 0.9736983776092529,
+ 0.9719834923744202,
+ 0.9729598164558411,
+ 0.9742265343666077,
+ 0.9694671034812927,
+ 0.9642630815505981,
+ 0.9688175916671753,
+ 0.9700709581375122,
+ 0.9708917140960693,
+ 0.970836341381073,
+ 0.9719638228416443,
+ 0.9728350043296814,
+ 0.9731382727622986,
+ 0.9736285209655762,
+ 0.9742369651794434,
+ 0.9747639894485474,
+ 0.9749345779418945,
+ 0.9753822088241577,
+ 0.9758213758468628,
+ 0.9761431217193604,
+ 0.976209282875061,
+ 0.9762759804725647,
+ 0.9761236310005188,
+ 0.9762524366378784,
+ 0.973239004611969,
+ 0.9702832698822021,
+ 0.9724173545837402,
+ 0.9745905995368958,
+ 0.9755451083183289,
+ 0.9759634733200073,
+ 0.9766412973403931,
+ 0.976182222366333,
+ 0.9765947461128235,
+ 0.977360188961029,
+ 0.977577269077301,
+ 0.9766439199447632,
+ 0.976081132888794,
+ 0.9756816625595093,
+ 0.9761340022087097,
+ 0.9771096110343933,
+ 0.9764305353164673,
+ 0.9725992679595947,
+ 0.9646629691123962,
+ 0.9655504822731018,
+ 0.9679058790206909,
+ 0.9697939157485962,
+ 0.9714179039001465,
+ 0.9719633460044861,
+ 0.9701056480407715,
+ 0.959028959274292,
+ 0.9631887674331665,
+ 0.9672734141349792,
+ 0.9687532782554626,
+ 0.9695519208908081,
+ 0.9703459739685059,
+ 0.970940887928009,
+ 0.9715447425842285,
+ 0.9717562198638916,
+ 0.9725369811058044,
+ 0.9730168581008911,
+ 0.9734640717506409,
+ 0.9738755226135254,
+ 0.9740811586380005,
+ 0.9746923446655273,
+ 0.9748530983924866,
+ 0.9743151068687439,
+ 0.9750608205795288,
+ 0.9753448963165283,
+ 0.9757680892944336,
+ 0.9758749604225159,
+ 0.9755315184593201,
+ 0.9764375686645508,
+ 0.9760685563087463,
+ 0.9738132953643799,
+ 0.9744699001312256,
+ 0.9759675860404968,
+ 0.9765239953994751,
+ 0.9767493009567261,
+ 0.9770404100418091,
+ 0.9775190949440002,
+ 0.9777635335922241,
+ 0.9770656824111938,
+ 0.9773379564285278,
+ 0.9780363440513611,
+ 0.9783704280853271,
+ 0.9783034324645996,
+ 0.9781065583229065,
+ 0.9786196351051331,
+ 0.9776173233985901,
+ 0.968602180480957,
+ 0.9737133383750916,
+ 0.975898265838623,
+ 0.9767024517059326,
+ 0.9774192571640015,
+ 0.9777238368988037,
+ 0.9777422547340393,
+ 0.977294385433197,
+ 0.9767982363700867,
+ 0.9779852032661438,
+ 0.9782487154006958,
+ 0.9789919257164001,
+ 0.9749283194541931,
+ 0.9774854779243469,
+ 0.957599401473999,
+ 0.9600788354873657,
+ 0.9646042585372925,
+ 0.9667049646377563,
+ 0.9683660864830017,
+ 0.9693924784660339,
+ 0.9704417586326599,
+ 0.9712101221084595,
+ 0.9723083972930908,
+ 0.9729383587837219,
+ 0.9738287925720215,
+ 0.9742386341094971,
+ 0.9735625386238098,
+ 0.9744496941566467,
+ 0.9751554131507874,
+ 0.9753341674804688,
+ 0.9757531881332397,
+ 0.9760571718215942,
+ 0.9765751361846924,
+ 0.9760973453521729,
+ 0.9768602848052979,
+ 0.9770902991294861,
+ 0.9773600697517395,
+ 0.9775993227958679,
+ 0.9781467318534851,
+ 0.9784204959869385,
+ 0.9787243008613586,
+ 0.9789025783538818,
+ 0.9783837199211121,
+ 0.9773847460746765,
+ 0.9759023785591125,
+ 0.9755476713180542,
+ 0.977429986000061,
+ 0.9732048511505127,
+ 0.9762001037597656,
+ 0.9772343039512634,
+ 0.9779145121574402,
+ 0.9783735275268555,
+ 0.9787331819534302,
+ 0.9782876968383789,
+ 0.9788635969161987,
+ 0.9793126583099365,
+ 0.9796106815338135,
+ 0.9797661304473877,
+ 0.9799916744232178,
+ 0.9803319573402405,
+ 0.9804670214653015,
+ 0.9807540774345398,
+ 0.980849027633667,
+ 0.9807034730911255,
+ 0.9808654189109802,
+ 0.9810879826545715,
+ 0.9812144637107849,
+ 0.9812115430831909,
+ 0.9814638495445251,
+ 0.981597363948822,
+ 0.9785327315330505,
+ 0.9670518636703491,
+ 0.9723511338233948,
+ 0.973235011100769,
+ 0.9736719727516174,
+ 0.9750344157218933,
+ 0.9761777520179749,
+ 0.9768736958503723,
+ 0.9774056077003479,
+ 0.9779927730560303,
+ 0.9782809019088745,
+ 0.9787391424179077,
+ 0.9790400266647339,
+ 0.9794601798057556,
+ 0.9793281555175781,
+ 0.9797995090484619,
+ 0.9799940586090088,
+ 0.9790902137756348,
+ 0.9799678921699524,
+ 0.9799731373786926,
+ 0.9702537655830383,
+ 0.9701027274131775,
+ 0.9748572111129761,
+ 0.9759320020675659,
+ 0.976621687412262,
+ 0.9774580597877502,
+ 0.977945864200592,
+ 0.9785280823707581,
+ 0.978829562664032,
+ 0.9791204333305359,
+ 0.9791039228439331,
+ 0.9786739945411682,
+ 0.9796222448348999,
+ 0.9800112247467041,
+ 0.9802066683769226,
+ 0.9802090525627136,
+ 0.9800991415977478,
+ 0.980035126209259,
+ 0.9800930619239807,
+ 0.9805180430412292,
+ 0.9804971814155579,
+ 0.9669134020805359,
+ 0.9731391668319702,
+ 0.9752886295318604,
+ 0.9763974547386169,
+ 0.9770158529281616,
+ 0.9739793539047241,
+ 0.97658771276474,
+ 0.9770677089691162,
+ 0.9779804944992065,
+ 0.9784687161445618,
+ 0.9784383773803711,
+ 0.9788373708724976,
+ 0.9792552590370178,
+ 0.979379415512085,
+ 0.979070782661438,
+ 0.9796465635299683,
+ 0.9798751473426819,
+ 0.9802157878875732,
+ 0.9796494245529175,
+ 0.980462372303009,
+ 0.980598509311676,
+ 0.9807146787643433,
+ 0.9805933833122253,
+ 0.9807178378105164,
+ 0.9807981252670288,
+ 0.9807450175285339,
+ 0.9807879328727722,
+ 0.9811874628067017,
+ 0.9808743596076965,
+ 0.9802898168563843,
+ 0.9803329706192017,
+ 0.9811297059059143,
+ 0.976192057132721,
+ 0.9774067997932434,
+ 0.9790591597557068,
+ 0.9782618880271912,
+ 0.978374719619751,
+ 0.9797973036766052,
+ 0.98039710521698,
+ 0.9809061288833618,
+ 0.9706665277481079,
+ 0.9683108925819397,
+ 0.9707096219062805,
+ 0.9731748104095459,
+ 0.9748567342758179,
+ 0.9759149551391602,
+ 0.9767023324966431,
+ 0.9775484204292297,
+ 0.9776878952980042,
+ 0.9763371348381042,
+ 0.9780442118644714,
+ 0.9775347113609314,
+ 0.9785792231559753,
+ 0.9640551209449768,
+ 0.9695783853530884,
+ 0.9730770587921143,
+ 0.9742680788040161,
+ 0.9752517938613892,
+ 0.9759684205055237,
+ 0.976382315158844,
+ 0.9770227074623108,
+ 0.9772208333015442,
+ 0.9777687191963196,
+ 0.9781318306922913,
+ 0.9782090187072754,
+ 0.9786636233329773,
+ 0.9788519144058228,
+ 0.9788020253181458,
+ 0.9782240390777588,
+ 0.9789761900901794,
+ 0.9791821241378784,
+ 0.9795763492584229,
+ 0.979709267616272,
+ 0.9801487326622009,
+ 0.9801005125045776,
+ 0.9800381660461426,
+ 0.9801545143127441,
+ 0.9805911779403687,
+ 0.980867862701416,
+ 0.981088399887085,
+ 0.9812235832214355,
+ 0.9811208844184875,
+ 0.981157124042511,
+ 0.9814913272857666,
+ 0.9817042350769043,
+ 0.9816209077835083,
+ 0.9817667007446289,
+ 0.9815982580184937,
+ 0.982085645198822,
+ 0.9803016185760498,
+ 0.9796533584594727,
+ 0.9806585907936096,
+ 0.9811287522315979,
+ 0.9816910624504089,
+ 0.9821954369544983,
+ 0.9700067639350891,
+ 0.9723506569862366,
+ 0.9768332839012146,
+ 0.9787658452987671,
+ 0.9797025918960571,
+ 0.9802789688110352,
+ 0.9781936407089233,
+ 0.9750186800956726,
+ 0.9768596291542053,
+ 0.9784777164459229,
+ 0.9793055057525635,
+ 0.9798623919487,
+ 0.9803575277328491,
+ 0.9806535840034485,
+ 0.981094479560852,
+ 0.9813281893730164,
+ 0.9809693098068237,
+ 0.9815014600753784,
+ 0.9817039966583252,
+ 0.9812902212142944,
+ 0.9819175601005554,
+ 0.9820971488952637
+ ],
+ "precision": [
+ 0.8992838859558105,
+ 0.8935195803642273,
+ 0.9231774806976318,
+ 0.9065283536911011,
+ 0.9022323489189148,
+ 0.9007483124732971,
+ 0.9058020114898682,
+ 0.9094322323799133,
+ 0.9126219749450684,
+ 0.9185943007469177,
+ 0.9219158291816711,
+ 0.9222309589385986,
+ 0.9274000525474548,
+ 0.9257815480232239,
+ 0.9265021085739136,
+ 0.9240222573280334,
+ 0.9306662678718567,
+ 0.9135439395904541,
+ 0.9231010675430298,
+ 0.9295929670333862,
+ 0.9293330907821655,
+ 0.9312353134155273,
+ 0.9329625964164734,
+ 0.9316965341567993,
+ 0.9333266615867615,
+ 0.9345052242279053,
+ 0.9335157871246338,
+ 0.9346805214881897,
+ 0.9368788003921509,
+ 0.9368543028831482,
+ 0.9517249464988708,
+ 0.95307457447052,
+ 0.9403975009918213,
+ 0.9448124766349792,
+ 0.9445593953132629,
+ 0.9379482865333557,
+ 0.9392950534820557,
+ 0.9407076239585876,
+ 0.9420888423919678,
+ 0.9438381791114807,
+ 0.9501326680183411,
+ 0.9543547630310059,
+ 0.949072003364563,
+ 0.9473968148231506,
+ 0.9440956711769104,
+ 0.9430425763130188,
+ 0.9459739327430725,
+ 0.9455015659332275,
+ 0.9466095566749573,
+ 0.9462694525718689,
+ 0.9438596367835999,
+ 0.9619185924530029,
+ 0.9508318901062012,
+ 0.9483709335327148,
+ 0.9470229744911194,
+ 0.946711003780365,
+ 0.948207437992096,
+ 0.9481415748596191,
+ 0.9484500288963318,
+ 0.9477959275245667,
+ 0.9477677345275879,
+ 0.9469801187515259,
+ 0.9493097066879272,
+ 0.9443789720535278,
+ 0.9479587078094482,
+ 0.9503016471862793,
+ 0.9498686194419861,
+ 0.9504828453063965,
+ 0.9518103003501892,
+ 0.9503151774406433,
+ 0.9508251547813416,
+ 0.9512123465538025,
+ 0.9516437649726868,
+ 0.9551058411598206,
+ 0.9545681476593018,
+ 0.9550942778587341,
+ 0.9535008668899536,
+ 0.9537196159362793,
+ 0.9563971161842346,
+ 0.9570322632789612,
+ 0.9570552110671997,
+ 0.9566587209701538,
+ 0.9578317999839783,
+ 0.959659993648529,
+ 0.9594024419784546,
+ 0.9616206884384155,
+ 0.960929274559021,
+ 0.9598709344863892,
+ 0.9609968066215515,
+ 0.96136873960495,
+ 0.9600357413291931,
+ 0.9573716521263123,
+ 0.9603555202484131,
+ 0.9617785215377808,
+ 0.9621409773826599,
+ 0.9566047191619873,
+ 0.9613975882530212,
+ 0.9627871513366699,
+ 0.9628319144248962,
+ 0.9638941287994385,
+ 0.9425920248031616,
+ 0.9338929653167725,
+ 0.9540751576423645,
+ 0.9525136351585388,
+ 0.9575161337852478,
+ 0.9593470692634583,
+ 0.9632865786552429,
+ 0.9643735289573669,
+ 0.9636847972869873,
+ 0.9605234861373901,
+ 0.9628630876541138,
+ 0.9579678177833557,
+ 0.9624496102333069,
+ 0.9615869522094727,
+ 0.9628499150276184,
+ 0.9642650485038757,
+ 0.9617869853973389,
+ 0.9646812081336975,
+ 0.9648491740226746,
+ 0.962059736251831,
+ 0.9631932973861694,
+ 0.9624291062355042,
+ 0.9625061750411987,
+ 0.9661656618118286,
+ 0.9649047255516052,
+ 0.9643188714981079,
+ 0.9657419919967651,
+ 0.9640787839889526,
+ 0.9652308225631714,
+ 0.9639834761619568,
+ 0.9667556285858154,
+ 0.96665358543396,
+ 0.9679520130157471,
+ 0.9679667353630066,
+ 0.9676694273948669,
+ 0.9669484496116638,
+ 0.9660540223121643,
+ 0.9672194719314575,
+ 0.9675654172897339,
+ 0.9686472415924072,
+ 0.9677742719650269,
+ 0.9648299217224121,
+ 0.9687601327896118,
+ 0.9699457883834839,
+ 0.9697255492210388,
+ 0.9699035286903381,
+ 0.9702296853065491,
+ 0.971717894077301,
+ 0.9693904519081116,
+ 0.9700163006782532,
+ 0.9713046550750732,
+ 0.9716519713401794,
+ 0.9721776247024536,
+ 0.9729169011116028,
+ 0.9725967049598694,
+ 0.9724439382553101,
+ 0.9721809029579163,
+ 0.968743622303009,
+ 0.9726381897926331,
+ 0.9733038544654846,
+ 0.975723385810852,
+ 0.9754957556724548,
+ 0.9738695621490479,
+ 0.973743200302124,
+ 0.9743053317070007,
+ 0.9736654162406921,
+ 0.9750998616218567,
+ 0.9715243577957153,
+ 0.9721106290817261,
+ 0.972126305103302,
+ 0.9748666286468506,
+ 0.9704681634902954,
+ 0.9725722074508667,
+ 0.9727873802185059,
+ 0.9750034213066101,
+ 0.9773173332214355,
+ 0.9747646450996399,
+ 0.9758877158164978,
+ 0.9750096797943115,
+ 0.9615429043769836,
+ 0.9669661521911621,
+ 0.9760112166404724,
+ 0.9738235473632812,
+ 0.9678298830986023,
+ 0.9739218950271606,
+ 0.9718801379203796,
+ 0.9750884175300598,
+ 0.9742626547813416,
+ 0.9749855399131775,
+ 0.9757564067840576,
+ 0.9747680425643921,
+ 0.9766591191291809,
+ 0.9760065674781799,
+ 0.9765374064445496,
+ 0.978253960609436,
+ 0.9755191206932068,
+ 0.9758433103561401,
+ 0.9768528938293457,
+ 0.9819539189338684,
+ 0.9725168347358704,
+ 0.9777063131332397,
+ 0.9763947129249573,
+ 0.9767307043075562,
+ 0.9769740700721741,
+ 0.9780659675598145,
+ 0.9795251488685608,
+ 0.9774302244186401,
+ 0.9785761833190918,
+ 0.9788815379142761,
+ 0.9764156937599182,
+ 0.9729133248329163,
+ 0.9722380042076111,
+ 0.9785850644111633,
+ 0.9786012768745422,
+ 0.9787552356719971,
+ 0.9730265736579895,
+ 0.9583626985549927,
+ 0.9634031653404236,
+ 0.9660843014717102,
+ 0.9704118371009827,
+ 0.9707321524620056,
+ 0.9709881544113159,
+ 0.9760051369667053,
+ 0.946617066860199,
+ 0.9593662619590759,
+ 0.9677178859710693,
+ 0.9701368808746338,
+ 0.9732000231742859,
+ 0.9716952443122864,
+ 0.9724082946777344,
+ 0.9722424149513245,
+ 0.9704092741012573,
+ 0.9729465842247009,
+ 0.9730241298675537,
+ 0.9735832214355469,
+ 0.9737854599952698,
+ 0.9745733737945557,
+ 0.9750858545303345,
+ 0.9747254252433777,
+ 0.9785636067390442,
+ 0.9759612083435059,
+ 0.9760167598724365,
+ 0.9758936762809753,
+ 0.9770054221153259,
+ 0.9762231707572937,
+ 0.9768044948577881,
+ 0.9787408113479614,
+ 0.9720087647438049,
+ 0.9759089946746826,
+ 0.976744532585144,
+ 0.9771106243133545,
+ 0.975960373878479,
+ 0.9786209464073181,
+ 0.9778503179550171,
+ 0.978689968585968,
+ 0.9750562310218811,
+ 0.9783962965011597,
+ 0.9785009026527405,
+ 0.9785699248313904,
+ 0.9773379564285278,
+ 0.9788073897361755,
+ 0.9803115129470825,
+ 0.9762491583824158,
+ 0.9689900875091553,
+ 0.9725742936134338,
+ 0.9754883050918579,
+ 0.9774740934371948,
+ 0.9774919748306274,
+ 0.9788755774497986,
+ 0.9765322804450989,
+ 0.9788328409194946,
+ 0.9757962822914124,
+ 0.9774616956710815,
+ 0.9787426590919495,
+ 0.9791474938392639,
+ 0.9792768359184265,
+ 0.9798372983932495,
+ 0.9480541944503784,
+ 0.9694900512695312,
+ 0.9659545421600342,
+ 0.9701856970787048,
+ 0.9682232737541199,
+ 0.9694339036941528,
+ 0.9717342853546143,
+ 0.9711283445358276,
+ 0.9724385142326355,
+ 0.9730356931686401,
+ 0.9737069010734558,
+ 0.9744369387626648,
+ 0.9692203998565674,
+ 0.9734945893287659,
+ 0.974539577960968,
+ 0.9749677181243896,
+ 0.9785441756248474,
+ 0.9759414792060852,
+ 0.9767823219299316,
+ 0.9748192429542542,
+ 0.9773659110069275,
+ 0.9795023798942566,
+ 0.9782940149307251,
+ 0.9800353646278381,
+ 0.978857696056366,
+ 0.9792131185531616,
+ 0.9789946675300598,
+ 0.9792091846466064,
+ 0.9785779118537903,
+ 0.9779298305511475,
+ 0.9727939367294312,
+ 0.9799074530601501,
+ 0.9778013825416565,
+ 0.9810697436332703,
+ 0.9760221242904663,
+ 0.9775710701942444,
+ 0.9781808257102966,
+ 0.9784017205238342,
+ 0.9790849685668945,
+ 0.975799024105072,
+ 0.9790262579917908,
+ 0.9796527028083801,
+ 0.9795625805854797,
+ 0.9797856211662292,
+ 0.9802121520042419,
+ 0.9808707237243652,
+ 0.9800721406936646,
+ 0.9813472628593445,
+ 0.9816632866859436,
+ 0.979587972164154,
+ 0.9817426204681396,
+ 0.9810745716094971,
+ 0.9811837673187256,
+ 0.9822655916213989,
+ 0.9821571707725525,
+ 0.981706440448761,
+ 0.9849334359169006,
+ 0.9618869423866272,
+ 0.972809910774231,
+ 0.9725103378295898,
+ 0.9719271659851074,
+ 0.9765286445617676,
+ 0.9767768979072571,
+ 0.9780049324035645,
+ 0.9781366586685181,
+ 0.9787758588790894,
+ 0.979574978351593,
+ 0.9788941144943237,
+ 0.9799464344978333,
+ 0.9802626371383667,
+ 0.9812312722206116,
+ 0.9805123805999756,
+ 0.9803711175918579,
+ 0.9846120476722717,
+ 0.9806753396987915,
+ 0.9805695414543152,
+ 0.9805625677108765,
+ 0.9703459739685059,
+ 0.9786406755447388,
+ 0.9758301377296448,
+ 0.9795531034469604,
+ 0.9787594676017761,
+ 0.9793882966041565,
+ 0.9798562526702881,
+ 0.9799196124076843,
+ 0.9809191823005676,
+ 0.9783082604408264,
+ 0.9760771989822388,
+ 0.9804887771606445,
+ 0.9809110760688782,
+ 0.9809779524803162,
+ 0.9790912866592407,
+ 0.979170024394989,
+ 0.9827489852905273,
+ 0.9835047125816345,
+ 0.9825785160064697,
+ 0.9837331771850586,
+ 0.9735047221183777,
+ 0.972905695438385,
+ 0.977183997631073,
+ 0.978089451789856,
+ 0.9775056838989258,
+ 0.9682573080062866,
+ 0.976538360118866,
+ 0.9770686030387878,
+ 0.9791834354400635,
+ 0.9798524975776672,
+ 0.9826685786247253,
+ 0.9796350002288818,
+ 0.9802651405334473,
+ 0.9800410270690918,
+ 0.9790262579917908,
+ 0.9812719821929932,
+ 0.9796724319458008,
+ 0.9810829162597656,
+ 0.9772170186042786,
+ 0.9808533191680908,
+ 0.9819177389144897,
+ 0.981721818447113,
+ 0.9786110520362854,
+ 0.9818451404571533,
+ 0.9833516478538513,
+ 0.978785514831543,
+ 0.9799115657806396,
+ 0.9822924733161926,
+ 0.9824753403663635,
+ 0.9776352047920227,
+ 0.9810264706611633,
+ 0.9814620614051819,
+ 0.9740288257598877,
+ 0.9787504076957703,
+ 0.9795571565628052,
+ 0.9814469814300537,
+ 0.9769492149353027,
+ 0.9803171157836914,
+ 0.9813071489334106,
+ 0.9816375970840454,
+ 0.9820874929428101,
+ 0.9740938544273376,
+ 0.9657480716705322,
+ 0.9762830138206482,
+ 0.9772912263870239,
+ 0.9772639274597168,
+ 0.9782102704048157,
+ 0.978047251701355,
+ 0.9793410897254944,
+ 0.9829915165901184,
+ 0.977618396282196,
+ 0.9794162511825562,
+ 0.9780856370925903,
+ 0.9457417726516724,
+ 0.9645279049873352,
+ 0.9745407104492188,
+ 0.9753058552742004,
+ 0.9752920866012573,
+ 0.975785493850708,
+ 0.9767656922340393,
+ 0.9786217212677002,
+ 0.9787272214889526,
+ 0.9782143831253052,
+ 0.978548526763916,
+ 0.9797430634498596,
+ 0.9783796668052673,
+ 0.9788793921470642,
+ 0.9774158000946045,
+ 0.977510929107666,
+ 0.9794684052467346,
+ 0.9799519181251526,
+ 0.9808859825134277,
+ 0.980858564376831,
+ 0.9806353449821472,
+ 0.9798208475112915,
+ 0.9801185131072998,
+ 0.9811465740203857,
+ 0.980260968208313,
+ 0.9816108345985413,
+ 0.9817655682563782,
+ 0.9815741777420044,
+ 0.9835969805717468,
+ 0.9819722175598145,
+ 0.982279360294342,
+ 0.9823079109191895,
+ 0.9823514819145203,
+ 0.9826764464378357,
+ 0.9827845096588135,
+ 0.9827767610549927,
+ 0.9799264669418335,
+ 0.9798104763031006,
+ 0.9804279208183289,
+ 0.9802714586257935,
+ 0.9820721745491028,
+ 0.9825060963630676,
+ 0.9860026836395264,
+ 0.9734170436859131,
+ 0.9770139455795288,
+ 0.9802051782608032,
+ 0.9802567958831787,
+ 0.9813519716262817,
+ 0.9743576645851135,
+ 0.9799386858940125,
+ 0.9786694049835205,
+ 0.9799676537513733,
+ 0.9810048937797546,
+ 0.980069637298584,
+ 0.9816738963127136,
+ 0.9833455681800842,
+ 0.981890082359314,
+ 0.9814645051956177,
+ 0.9829424619674683,
+ 0.9820736050605774,
+ 0.9824276566505432,
+ 0.9807500243186951,
+ 0.9814831018447876,
+ 0.9816710948944092
+ ],
+ "recall": [
+ 0.791861355304718,
+ 0.8451449275016785,
+ 0.8360539078712463,
+ 0.8644647002220154,
+ 0.8764495253562927,
+ 0.8856896758079529,
+ 0.886893630027771,
+ 0.8890292644500732,
+ 0.8897004127502441,
+ 0.8884459137916565,
+ 0.8893483281135559,
+ 0.8935720324516296,
+ 0.891123354434967,
+ 0.897046685218811,
+ 0.8985677361488342,
+ 0.904116690158844,
+ 0.9012035131454468,
+ 0.9133350253105164,
+ 0.9016810655593872,
+ 0.9030669331550598,
+ 0.9066722393035889,
+ 0.9079222679138184,
+ 0.9079533815383911,
+ 0.9120744466781616,
+ 0.9129979014396667,
+ 0.9152318835258484,
+ 0.9182695746421814,
+ 0.9195649027824402,
+ 0.9198611378669739,
+ 0.9220082759857178,
+ 0.8910340070724487,
+ 0.885751485824585,
+ 0.905456006526947,
+ 0.9070963263511658,
+ 0.9114834666252136,
+ 0.9191790819168091,
+ 0.9199527502059937,
+ 0.921348512172699,
+ 0.9224293231964111,
+ 0.9226288795471191,
+ 0.9153871536254883,
+ 0.9096200466156006,
+ 0.9198392629623413,
+ 0.9242357611656189,
+ 0.9300641417503357,
+ 0.9324789643287659,
+ 0.9314627051353455,
+ 0.9332801699638367,
+ 0.9335008263587952,
+ 0.9355972409248352,
+ 0.9391127824783325,
+ 0.8976917862892151,
+ 0.9181881546974182,
+ 0.9246246218681335,
+ 0.9293662905693054,
+ 0.9313563108444214,
+ 0.9316253662109375,
+ 0.9334909319877625,
+ 0.9344278573989868,
+ 0.9367935061454773,
+ 0.9377925395965576,
+ 0.9398112297058105,
+ 0.939002513885498,
+ 0.9445978999137878,
+ 0.9417274594306946,
+ 0.9408044219017029,
+ 0.9428611397743225,
+ 0.9436261057853699,
+ 0.9435579776763916,
+ 0.9460351467132568,
+ 0.9459415674209595,
+ 0.9465610384941101,
+ 0.947054386138916,
+ 0.944445788860321,
+ 0.9465948343276978,
+ 0.9477043747901917,
+ 0.9503015875816345,
+ 0.9510202407836914,
+ 0.9485911726951599,
+ 0.9491636753082275,
+ 0.9510773420333862,
+ 0.9523375630378723,
+ 0.9523199200630188,
+ 0.9516726732254028,
+ 0.9527785778045654,
+ 0.9506001472473145,
+ 0.9531168341636658,
+ 0.9553090929985046,
+ 0.9551175832748413,
+ 0.9558740854263306,
+ 0.954255223274231,
+ 0.9570053219795227,
+ 0.9550737738609314,
+ 0.9562565684318542,
+ 0.9573289752006531,
+ 0.9596896767616272,
+ 0.9582734704017639,
+ 0.9588923454284668,
+ 0.9603589773178101,
+ 0.9603280425071716,
+ 0.9656360745429993,
+ 0.9568237662315369,
+ 0.9465411305427551,
+ 0.9552591443061829,
+ 0.9520728588104248,
+ 0.9527888894081116,
+ 0.9479497075080872,
+ 0.9485549330711365,
+ 0.950512170791626,
+ 0.9550092220306396,
+ 0.9510374665260315,
+ 0.9581052660942078,
+ 0.9552298188209534,
+ 0.9569394588470459,
+ 0.9567636847496033,
+ 0.9556072354316711,
+ 0.9587507247924805,
+ 0.9555803537368774,
+ 0.9565052390098572,
+ 0.9600802659988403,
+ 0.9599339365959167,
+ 0.961341381072998,
+ 0.961948812007904,
+ 0.9583158493041992,
+ 0.9603102803230286,
+ 0.961682140827179,
+ 0.9610201716423035,
+ 0.9632683396339417,
+ 0.9627686738967896,
+ 0.9642438292503357,
+ 0.9615676999092102,
+ 0.9624653458595276,
+ 0.9617192149162292,
+ 0.9623971581459045,
+ 0.9632579684257507,
+ 0.9644113183021545,
+ 0.9640793204307556,
+ 0.9643028378486633,
+ 0.965067982673645,
+ 0.9640759229660034,
+ 0.9656127095222473,
+ 0.9689431190490723,
+ 0.9650319218635559,
+ 0.9649828672409058,
+ 0.9662241339683533,
+ 0.9663980603218079,
+ 0.9671016931533813,
+ 0.9658815860748291,
+ 0.9681800603866577,
+ 0.967562198638916,
+ 0.9672921895980835,
+ 0.9681373238563538,
+ 0.9681686758995056,
+ 0.967867910861969,
+ 0.9687545299530029,
+ 0.9694454669952393,
+ 0.9693561792373657,
+ 0.9714493751525879,
+ 0.9680829048156738,
+ 0.9692250490188599,
+ 0.9661319851875305,
+ 0.9665122628211975,
+ 0.9688233137130737,
+ 0.9710770845413208,
+ 0.9714246392250061,
+ 0.9727378487586975,
+ 0.9704562425613403,
+ 0.9707130789756775,
+ 0.9669188857078552,
+ 0.9709378480911255,
+ 0.9702185988426208,
+ 0.9736852645874023,
+ 0.9710268974304199,
+ 0.9726718664169312,
+ 0.9724496603012085,
+ 0.9668371081352234,
+ 0.9712908267974854,
+ 0.9726231694221497,
+ 0.9642851948738098,
+ 0.9676161408424377,
+ 0.9709209203720093,
+ 0.9643153548240662,
+ 0.9681193232536316,
+ 0.9739947319030762,
+ 0.9700707793235779,
+ 0.9738551378250122,
+ 0.971240222454071,
+ 0.9730367660522461,
+ 0.9735251665115356,
+ 0.973804771900177,
+ 0.9751293659210205,
+ 0.9741354584693909,
+ 0.9756639003753662,
+ 0.9757745265960693,
+ 0.974200963973999,
+ 0.9770787954330444,
+ 0.9764552116394043,
+ 0.9756947755813599,
+ 0.9650297164916992,
+ 0.9684521555900574,
+ 0.9673844575881958,
+ 0.972891092300415,
+ 0.9744333624839783,
+ 0.9750074148178101,
+ 0.9752647280693054,
+ 0.9729223251342773,
+ 0.9758098721504211,
+ 0.9761912822723389,
+ 0.9763118028640747,
+ 0.9769499897956848,
+ 0.9794024229049683,
+ 0.9793176054954529,
+ 0.9737824201583862,
+ 0.9756699800491333,
+ 0.9742213487625122,
+ 0.9722992181777954,
+ 0.9717274904251099,
+ 0.9682989716529846,
+ 0.9701378345489502,
+ 0.9694706201553345,
+ 0.9723451733589172,
+ 0.9731424450874329,
+ 0.9645086526870728,
+ 0.9728704690933228,
+ 0.9672322869300842,
+ 0.9669365882873535,
+ 0.9674783945083618,
+ 0.9659968018531799,
+ 0.9690768122673035,
+ 0.9695566892623901,
+ 0.9709073305130005,
+ 0.9731854796409607,
+ 0.9721854329109192,
+ 0.9730652570724487,
+ 0.9733920693397522,
+ 0.9740149974822998,
+ 0.9736261367797852,
+ 0.9743325114250183,
+ 0.9750107526779175,
+ 0.9701613783836365,
+ 0.9741973876953125,
+ 0.974717915058136,
+ 0.9756714701652527,
+ 0.9747783541679382,
+ 0.9748955368995667,
+ 0.9760963916778564,
+ 0.9734497666358948,
+ 0.975731372833252,
+ 0.97313392162323,
+ 0.9752202033996582,
+ 0.9759657979011536,
+ 0.9775606393814087,
+ 0.975492537021637,
+ 0.9772050976753235,
+ 0.9768574237823486,
+ 0.9791225790977478,
+ 0.9763096570968628,
+ 0.9775936603546143,
+ 0.9781891107559204,
+ 0.9793063402175903,
+ 0.9774335622787476,
+ 0.9769541621208191,
+ 0.9790760278701782,
+ 0.9690549373626709,
+ 0.9751819968223572,
+ 0.9764674305915833,
+ 0.9760426878929138,
+ 0.9774177670478821,
+ 0.9766298532485962,
+ 0.9790302515029907,
+ 0.9758355021476746,
+ 0.9779002666473389,
+ 0.9785527586936951,
+ 0.9777949452400208,
+ 0.9788627028465271,
+ 0.9708422422409058,
+ 0.9752458930015564,
+ 0.9690341353416443,
+ 0.9516607522964478,
+ 0.9637950658798218,
+ 0.9636418223381042,
+ 0.9688300490379333,
+ 0.9696590900421143,
+ 0.9694294929504395,
+ 0.971558690071106,
+ 0.9723370671272278,
+ 0.9729266166687012,
+ 0.9740192890167236,
+ 0.9741011261940002,
+ 0.9779934287071228,
+ 0.9755014181137085,
+ 0.9758363366127014,
+ 0.9757565259933472,
+ 0.9730213284492493,
+ 0.976201057434082,
+ 0.9763962030410767,
+ 0.9774160981178284,
+ 0.9763848185539246,
+ 0.9747228026390076,
+ 0.9764521718025208,
+ 0.9752063155174255,
+ 0.9774546027183533,
+ 0.9776479005813599,
+ 0.9784721732139587,
+ 0.9786161780357361,
+ 0.978244960308075,
+ 0.9769057035446167,
+ 0.9791543483734131,
+ 0.9713692665100098,
+ 0.9771057367324829,
+ 0.9656568765640259,
+ 0.976421594619751,
+ 0.9769189953804016,
+ 0.9776700139045715,
+ 0.9783620238304138,
+ 0.9783968925476074,
+ 0.9808138608932495,
+ 0.9787189960479736,
+ 0.9789877533912659,
+ 0.9796712398529053,
+ 0.9797583818435669,
+ 0.9797834753990173,
+ 0.9798057675361633,
+ 0.9808721542358398,
+ 0.9801729321479797,
+ 0.9800459146499634,
+ 0.9818342328071594,
+ 0.9800011515617371,
+ 0.9811092019081116,
+ 0.9812566637992859,
+ 0.9801689386367798,
+ 0.9807889461517334,
+ 0.9815011620521545,
+ 0.9724181890487671,
+ 0.9728742241859436,
+ 0.9720596671104431,
+ 0.9741119146347046,
+ 0.9755633473396301,
+ 0.9736440181732178,
+ 0.9756592512130737,
+ 0.9758150577545166,
+ 0.9767378568649292,
+ 0.9772686958312988,
+ 0.9770384430885315,
+ 0.9786340594291687,
+ 0.9781789183616638,
+ 0.9786956310272217,
+ 0.9774749279022217,
+ 0.9791203737258911,
+ 0.979654848575592,
+ 0.9737077355384827,
+ 0.979299783706665,
+ 0.979436993598938,
+ 0.9604448676109314,
+ 0.9702677726745605,
+ 0.9712108969688416,
+ 0.9761053323745728,
+ 0.9737813472747803,
+ 0.9762065410614014,
+ 0.9765533208847046,
+ 0.9772399663925171,
+ 0.9777761101722717,
+ 0.9773626327514648,
+ 0.9799520969390869,
+ 0.9813448190689087,
+ 0.9787777066230774,
+ 0.9791331887245178,
+ 0.9794602394104004,
+ 0.9813597798347473,
+ 0.9810588359832764,
+ 0.9773743152618408,
+ 0.976746678352356,
+ 0.9784935116767883,
+ 0.9773215055465698,
+ 0.9610670208930969,
+ 0.973602294921875,
+ 0.973552942276001,
+ 0.9748286008834839,
+ 0.9766324162483215,
+ 0.9799429178237915,
+ 0.9767052531242371,
+ 0.9771280884742737,
+ 0.9768232703208923,
+ 0.977131187915802,
+ 0.9742915034294128,
+ 0.9781007766723633,
+ 0.9782840609550476,
+ 0.9787495732307434,
+ 0.9791613817214966,
+ 0.9780554175376892,
+ 0.9801104068756104,
+ 0.979373574256897,
+ 0.9821428060531616,
+ 0.9800913333892822,
+ 0.9793018698692322,
+ 0.9797273874282837,
+ 0.9826022982597351,
+ 0.9796127080917358,
+ 0.9782849550247192,
+ 0.9827302694320679,
+ 0.9816903471946716,
+ 0.9800970554351807,
+ 0.9793035387992859,
+ 0.9830003380775452,
+ 0.9796827435493469,
+ 0.9808118939399719,
+ 0.9785438776016235,
+ 0.9761611223220825,
+ 0.9786279201507568,
+ 0.9751942157745361,
+ 0.9798658490180969,
+ 0.9793126583099365,
+ 0.9795118570327759,
+ 0.9801921248435974,
+ 0.9600467085838318,
+ 0.9629004001617432,
+ 0.9759656190872192,
+ 0.9701969623565674,
+ 0.972531259059906,
+ 0.9746580123901367,
+ 0.975263237953186,
+ 0.9770999550819397,
+ 0.9760912656784058,
+ 0.9698919057846069,
+ 0.9785098433494568,
+ 0.9758208394050598,
+ 0.9791128635406494,
+ 0.9835783243179321,
+ 0.9748319387435913,
+ 0.9716911315917969,
+ 0.9732897877693176,
+ 0.9752651453018188,
+ 0.9762004613876343,
+ 0.9760403037071228,
+ 0.9754682183265686,
+ 0.9757618308067322,
+ 0.9773507118225098,
+ 0.977741003036499,
+ 0.9767034649848938,
+ 0.9789731502532959,
+ 0.9788464903831482,
+ 0.9802247285842896,
+ 0.9789708852767944,
+ 0.9785082936286926,
+ 0.9784370064735413,
+ 0.9782934784889221,
+ 0.9785864949226379,
+ 0.9796760082244873,
+ 0.9803999066352844,
+ 0.9799817204475403,
+ 0.9791824221611023,
+ 0.9809419512748718,
+ 0.9801437258720398,
+ 0.9804273247718811,
+ 0.9808862805366516,
+ 0.9786785244941711,
+ 0.9803816080093384,
+ 0.9807310104370117,
+ 0.9811263680458069,
+ 0.9809145331382751,
+ 0.9808802008628845,
+ 0.9804469347000122,
+ 0.981417179107666,
+ 0.9807339310646057,
+ 0.9795430898666382,
+ 0.9809128642082214,
+ 0.982026219367981,
+ 0.9813324809074402,
+ 0.9819055795669556,
+ 0.9553102850914001,
+ 0.9714292883872986,
+ 0.9767655730247498,
+ 0.9774090051651001,
+ 0.9792118668556213,
+ 0.9792583584785461,
+ 0.9822831153869629,
+ 0.9703272581100464,
+ 0.9751678705215454,
+ 0.9770734310150146,
+ 0.9776707887649536,
+ 0.9797150492668152,
+ 0.9790826439857483,
+ 0.9780145287513733,
+ 0.9803314805030823,
+ 0.9812318682670593,
+ 0.9790366888046265,
+ 0.9809558987617493,
+ 0.9810067415237427,
+ 0.981852650642395,
+ 0.9823798537254333,
+ 0.9825533628463745
+ ],
+ "val_loss": [
+ 0.28723815083503723,
+ 0.2713486850261688,
+ 0.23316213488578796,
+ 0.21659114956855774,
+ 0.20152230560779572,
+ 0.19967979192733765,
+ 0.19528080523014069,
+ 0.1881239116191864,
+ 0.1821192353963852,
+ 0.17346297204494476,
+ 0.1740366667509079,
+ 0.17089757323265076,
+ 0.16789646446704865,
+ 0.1660037338733673,
+ 0.16604383289813995,
+ 0.17595873773097992,
+ 0.1708299219608307,
+ 0.18093723058700562,
+ 0.16202183067798615,
+ 0.16549760103225708,
+ 0.16674426198005676,
+ 0.1612197458744049,
+ 0.16146968305110931,
+ 0.1589117795228958,
+ 0.16076944768428802,
+ 0.1649247258901596,
+ 0.16818247735500336,
+ 0.17020373046398163,
+ 0.17034289240837097,
+ 0.172881081700325,
+ 0.17018266022205353,
+ 0.17107126116752625,
+ 0.17065449059009552,
+ 0.17375759780406952,
+ 0.18439385294914246,
+ 0.28123483061790466,
+ 0.18936650454998016,
+ 0.18584774434566498,
+ 0.191576287150383,
+ 0.19643668830394745,
+ 0.6989830136299133,
+ 0.7043436169624329,
+ 0.39692726731300354,
+ 0.21512266993522644,
+ 0.18475396931171417,
+ 0.18470022082328796,
+ 0.1896016001701355,
+ 0.1995842456817627,
+ 0.20438623428344727,
+ 0.2045293003320694,
+ 0.20099927484989166,
+ 0.16808313131332397,
+ 0.1720547080039978,
+ 0.17455346882343292,
+ 0.17807002365589142,
+ 0.1794194132089615,
+ 0.1817426085472107,
+ 0.18469466269016266,
+ 0.18639053404331207,
+ 0.18814800679683685,
+ 0.1905999630689621,
+ 0.19293445348739624,
+ 0.19491058588027954,
+ 0.19337712228298187,
+ 0.19332274794578552,
+ 0.1962611973285675,
+ 0.19995097815990448,
+ 0.20085245370864868,
+ 0.20314878225326538,
+ 0.19838808476924896,
+ 0.19687509536743164,
+ 0.204046830534935,
+ 0.20114660263061523,
+ 0.19768579304218292,
+ 0.20237891376018524,
+ 0.2069508284330368,
+ 0.19951966404914856,
+ 0.19902941584587097,
+ 0.20256203413009644,
+ 0.20888705551624298,
+ 0.21395546197891235,
+ 0.2150646299123764,
+ 0.21527113020420074,
+ 0.21789665520191193,
+ 0.21774950623512268,
+ 0.21902601420879364,
+ 0.22202743589878082,
+ 0.22601786255836487,
+ 0.22464175522327423,
+ 0.22582165896892548,
+ 0.21727927029132843,
+ 0.20705226063728333,
+ 0.21868649125099182,
+ 0.22494469583034515,
+ 0.23110167682170868,
+ 0.21682409942150116,
+ 0.22849085927009583,
+ 0.2313670963048935,
+ 0.23593372106552124,
+ 0.2373199164867401,
+ 0.4533359110355377,
+ 0.20986557006835938,
+ 0.19883373379707336,
+ 0.1990339756011963,
+ 0.20048026740550995,
+ 0.2039482146501541,
+ 0.21253350377082825,
+ 0.21681621670722961,
+ 0.22275234758853912,
+ 0.20872315764427185,
+ 0.2050584852695465,
+ 0.21987168490886688,
+ 0.23898328840732574,
+ 0.29479342699050903,
+ 0.3554593622684479,
+ 0.2502932548522949,
+ 0.2700035274028778,
+ 0.33491018414497375,
+ 0.2824503183364868,
+ 0.29658064246177673,
+ 0.31978747248649597,
+ 0.3138459026813507,
+ 0.349152535200119,
+ 0.29236090183258057,
+ 0.31074059009552,
+ 0.267335444688797,
+ 0.281875878572464,
+ 0.26835721731185913,
+ 0.2765218913555145,
+ 0.2490549087524414,
+ 0.270651638507843,
+ 0.2674177289009094,
+ 0.2566213607788086,
+ 0.25731098651885986,
+ 0.2570015788078308,
+ 0.25581344962120056,
+ 0.2553556263446808,
+ 0.2602108418941498,
+ 0.256712406873703,
+ 0.26190245151519775,
+ 0.26316896080970764,
+ 0.2578590214252472,
+ 0.26275983452796936,
+ 0.2636365592479706,
+ 0.26617926359176636,
+ 0.2679619789123535,
+ 0.2680908441543579,
+ 0.26747429370880127,
+ 0.2679744064807892,
+ 0.269931823015213,
+ 0.2670978009700775,
+ 0.2718101739883423,
+ 0.27347344160079956,
+ 0.27368682622909546,
+ 0.27443116903305054,
+ 0.2800472378730774,
+ 0.2824190855026245,
+ 0.2883802056312561,
+ 0.2995620369911194,
+ 0.3005017042160034,
+ 0.3022603988647461,
+ 0.30273953080177307,
+ 0.2985783517360687,
+ 0.2965547442436218,
+ 0.2972632646560669,
+ 0.2994842529296875,
+ 0.2960124611854553,
+ 0.2643485963344574,
+ 0.27351850271224976,
+ 0.25582224130630493,
+ 0.2798594832420349,
+ 0.27532312273979187,
+ 0.26906818151474,
+ 0.27401596307754517,
+ 0.2865924537181854,
+ 0.26777151226997375,
+ 0.25301873683929443,
+ 0.2687220871448517,
+ 0.33844470977783203,
+ 0.2465686947107315,
+ 0.23854368925094604,
+ 0.23741523921489716,
+ 0.2444889098405838,
+ 0.2561025321483612,
+ 0.2590726613998413,
+ 0.2630075216293335,
+ 0.269187867641449,
+ 0.2795838713645935,
+ 0.2845974564552307,
+ 0.28783342242240906,
+ 0.28345707058906555,
+ 0.2835446298122406,
+ 0.28521379828453064,
+ 0.2847643494606018,
+ 0.28264346718788147,
+ 0.2821420133113861,
+ 0.2802213728427887,
+ 0.28383877873420715,
+ 0.3033846616744995,
+ 0.27275991439819336,
+ 0.27391302585601807,
+ 0.2827366888523102,
+ 0.2799161374568939,
+ 0.28558021783828735,
+ 0.2886905074119568,
+ 0.2805515229701996,
+ 0.2876611351966858,
+ 0.2848869860172272,
+ 0.2906288206577301,
+ 0.27950969338417053,
+ 0.254498690366745,
+ 0.24704022705554962,
+ 0.2561323940753937,
+ 0.26700064539909363,
+ 0.24114003777503967,
+ 0.2473776936531067,
+ 0.2354305535554886,
+ 0.23682336509227753,
+ 0.21703208982944489,
+ 0.21673893928527832,
+ 0.2144777625799179,
+ 0.19983965158462524,
+ 0.18079981207847595,
+ 0.21088142693042755,
+ 0.19281452894210815,
+ 0.20026083290576935,
+ 0.21254859864711761,
+ 0.22024133801460266,
+ 0.2252091020345688,
+ 0.2187010645866394,
+ 0.2269616574048996,
+ 0.22800034284591675,
+ 0.2363997995853424,
+ 0.23668085038661957,
+ 0.23937562108039856,
+ 0.23971028625965118,
+ 0.24257995188236237,
+ 0.24286648631095886,
+ 0.24594220519065857,
+ 0.24318280816078186,
+ 0.23669882118701935,
+ 0.24397698044776917,
+ 0.24627374112606049,
+ 0.2619628608226776,
+ 0.25442594289779663,
+ 0.2580946385860443,
+ 0.28162869811058044,
+ 0.2419835925102234,
+ 0.24677637219429016,
+ 0.24248021841049194,
+ 0.24594305455684662,
+ 0.25479814410209656,
+ 0.25116685032844543,
+ 0.2514200508594513,
+ 0.2554737627506256,
+ 0.24422712624073029,
+ 0.24993255734443665,
+ 0.24904866516590118,
+ 0.25754067301750183,
+ 0.2438945472240448,
+ 0.2556094229221344,
+ 0.256924569606781,
+ 0.24302861094474792,
+ 0.19861286878585815,
+ 0.19839613139629364,
+ 0.2059561163187027,
+ 0.2107318490743637,
+ 0.2172657996416092,
+ 0.22223681211471558,
+ 0.22834354639053345,
+ 0.24637369811534882,
+ 0.22940300405025482,
+ 0.23817867040634155,
+ 0.2395147681236267,
+ 0.24258115887641907,
+ 0.2501608729362488,
+ 0.25339969992637634,
+ 0.21451711654663086,
+ 0.17311611771583557,
+ 0.16781897842884064,
+ 0.1619681566953659,
+ 0.1635843962430954,
+ 0.16744934022426605,
+ 0.17401142418384552,
+ 0.18347276747226715,
+ 0.19007667899131775,
+ 0.1960650086402893,
+ 0.2002505362033844,
+ 0.20441405475139618,
+ 0.20215874910354614,
+ 0.20224016904830933,
+ 0.20779606699943542,
+ 0.21010547876358032,
+ 0.21457569301128387,
+ 0.2174983024597168,
+ 0.2231486290693283,
+ 0.22089047729969025,
+ 0.22513629496097565,
+ 0.22668702900409698,
+ 0.22853055596351624,
+ 0.23308581113815308,
+ 0.23511795699596405,
+ 0.23914246261119843,
+ 0.24176160991191864,
+ 0.2509233355522156,
+ 0.244533509016037,
+ 0.2494301050901413,
+ 0.2527102530002594,
+ 0.2509970963001251,
+ 0.23927585780620575,
+ 0.2555169463157654,
+ 0.24370647966861725,
+ 0.24824310839176178,
+ 0.24955306947231293,
+ 0.25017765164375305,
+ 0.24821235239505768,
+ 0.2448243349790573,
+ 0.2462622970342636,
+ 0.24949535727500916,
+ 0.25326254963874817,
+ 0.25667038559913635,
+ 0.25821465253829956,
+ 0.26118817925453186,
+ 0.2644621431827545,
+ 0.2672300338745117,
+ 0.2668937146663666,
+ 0.2716364860534668,
+ 0.2697041928768158,
+ 0.27032187581062317,
+ 0.27397218346595764,
+ 0.2758488059043884,
+ 0.28225117921829224,
+ 0.2826656997203827,
+ 0.3523707389831543,
+ 0.35616227984428406,
+ 0.3610374331474304,
+ 0.34478408098220825,
+ 0.3169664740562439,
+ 0.325067400932312,
+ 0.32318148016929626,
+ 0.3198719620704651,
+ 0.31806430220603943,
+ 0.3153974711894989,
+ 0.3156879246234894,
+ 0.31431904435157776,
+ 0.3098563551902771,
+ 0.3104737102985382,
+ 0.31270310282707214,
+ 0.3181253969669342,
+ 0.3163810670375824,
+ 0.3176153600215912,
+ 0.3160388171672821,
+ 0.30148062109947205,
+ 0.39275458455085754,
+ 0.3649538457393646,
+ 0.3534722626209259,
+ 0.3403719663619995,
+ 0.33192750811576843,
+ 0.326800674200058,
+ 0.3222004771232605,
+ 0.3207232356071472,
+ 0.32102692127227783,
+ 0.3194082975387573,
+ 0.33628493547439575,
+ 0.3176187574863434,
+ 0.3192712068557739,
+ 0.31775015592575073,
+ 0.31960445642471313,
+ 0.3243445158004761,
+ 0.31762775778770447,
+ 0.3086387515068054,
+ 0.30278441309928894,
+ 0.3129657804965973,
+ 0.3011714220046997,
+ 0.32860052585601807,
+ 0.3342500627040863,
+ 0.3335627615451813,
+ 0.32909467816352844,
+ 0.3580184280872345,
+ 0.3430590331554413,
+ 0.32331663370132446,
+ 0.3198379874229431,
+ 0.31436392664909363,
+ 0.3113187551498413,
+ 0.3089807331562042,
+ 0.30715039372444153,
+ 0.3089626133441925,
+ 0.3062823414802551,
+ 0.30327120423316956,
+ 0.3016895353794098,
+ 0.3001919388771057,
+ 0.30333200097084045,
+ 0.2917383909225464,
+ 0.2956298589706421,
+ 0.29643189907073975,
+ 0.30130383372306824,
+ 0.28969016671180725,
+ 0.299941748380661,
+ 0.29527848958969116,
+ 0.30097830295562744,
+ 0.3061491549015045,
+ 0.29706084728240967,
+ 0.31153690814971924,
+ 0.2997742295265198,
+ 0.2966189980506897,
+ 0.29177623987197876,
+ 0.3229901194572449,
+ 0.3082697093486786,
+ 0.30902624130249023,
+ 0.3128173053264618,
+ 0.28913018107414246,
+ 0.29392433166503906,
+ 0.29433149099349976,
+ 0.2970201075077057,
+ 0.24357981979846954,
+ 0.2757555842399597,
+ 0.25999417901039124,
+ 0.2673036456108093,
+ 0.27445095777511597,
+ 0.2781122028827667,
+ 0.2836487889289856,
+ 0.28633299469947815,
+ 0.2873220145702362,
+ 0.2814802825450897,
+ 0.25818365812301636,
+ 0.28027209639549255,
+ 0.2846464514732361,
+ 0.22045518457889557,
+ 0.2228807657957077,
+ 0.22831334173679352,
+ 0.23165997862815857,
+ 0.23505640029907227,
+ 0.23972679674625397,
+ 0.24111683666706085,
+ 0.2465180903673172,
+ 0.25158917903900146,
+ 0.2559452950954437,
+ 0.25975653529167175,
+ 0.2616535425186157,
+ 0.26229920983314514,
+ 0.26474112272262573,
+ 0.2715058922767639,
+ 0.2710859477519989,
+ 0.276827335357666,
+ 0.28341811895370483,
+ 0.28683561086654663,
+ 0.2913525402545929,
+ 0.2918616831302643,
+ 0.29444780945777893,
+ 0.29850977659225464,
+ 0.3032839596271515,
+ 0.3009886145591736,
+ 0.30314281582832336,
+ 0.3034873604774475,
+ 0.3089466094970703,
+ 0.3075098693370819,
+ 0.31456753611564636,
+ 0.3177722096443176,
+ 0.31999513506889343,
+ 0.3137664198875427,
+ 0.32393699884414673,
+ 0.31554022431373596,
+ 0.31903356313705444,
+ 0.35382118821144104,
+ 0.3430420756340027,
+ 0.32786762714385986,
+ 0.32548537850379944,
+ 0.32288435101509094,
+ 0.3248937427997589,
+ 0.3364331126213074,
+ 0.32254448533058167,
+ 0.3199025094509125,
+ 0.31521788239479065,
+ 0.3154824376106262,
+ 0.31081122159957886,
+ 0.3064108192920685,
+ 0.29099735617637634,
+ 0.29672113060951233,
+ 0.2918401062488556,
+ 0.29029548168182373,
+ 0.2933703660964966,
+ 0.2857241630554199,
+ 0.28566399216651917,
+ 0.2848254442214966,
+ 0.28084006905555725,
+ 0.27759119868278503,
+ 0.28184911608695984,
+ 0.28439953923225403,
+ 0.2827478349208832,
+ 0.2888249158859253,
+ 0.28556081652641296
+ ],
+ "val_dice_coef": [
+ 0.7766007781028748,
+ 0.7362868189811707,
+ 0.7365289330482483,
+ 0.736765444278717,
+ 0.7391619086265564,
+ 0.7354344725608826,
+ 0.7399694323539734,
+ 0.7447896599769592,
+ 0.7554518580436707,
+ 0.7736577391624451,
+ 0.7764828205108643,
+ 0.7869715094566345,
+ 0.7908278703689575,
+ 0.8058087825775146,
+ 0.8143413662910461,
+ 0.8216170072555542,
+ 0.8258571028709412,
+ 0.8172253966331482,
+ 0.8246889114379883,
+ 0.8262502551078796,
+ 0.8275225162506104,
+ 0.8288997411727905,
+ 0.8298549056053162,
+ 0.8314856886863708,
+ 0.8321593999862671,
+ 0.8323621153831482,
+ 0.833370566368103,
+ 0.834065854549408,
+ 0.8344595432281494,
+ 0.8343856334686279,
+ 0.8009603023529053,
+ 0.8194039463996887,
+ 0.8238095641136169,
+ 0.8264465928077698,
+ 0.8254995346069336,
+ 0.6870414614677429,
+ 0.8023579716682434,
+ 0.81545090675354,
+ 0.8230668306350708,
+ 0.8244038820266724,
+ 0.2739965617656708,
+ 0.2708752453327179,
+ 0.5725690722465515,
+ 0.7653201818466187,
+ 0.8189557194709778,
+ 0.820512592792511,
+ 0.821704089641571,
+ 0.8296846151351929,
+ 0.8308324217796326,
+ 0.8331705927848816,
+ 0.8332914710044861,
+ 0.8199408054351807,
+ 0.8255106210708618,
+ 0.8291324377059937,
+ 0.8306165933609009,
+ 0.8324189782142639,
+ 0.8333595395088196,
+ 0.8329381346702576,
+ 0.8340929746627808,
+ 0.8348809480667114,
+ 0.8351688385009766,
+ 0.8362818360328674,
+ 0.8363726735115051,
+ 0.8385686278343201,
+ 0.839104413986206,
+ 0.8386164903640747,
+ 0.8377727270126343,
+ 0.8374584913253784,
+ 0.8369852304458618,
+ 0.8386901617050171,
+ 0.8384276032447815,
+ 0.8365789651870728,
+ 0.8369478583335876,
+ 0.8361538648605347,
+ 0.8366899490356445,
+ 0.836506187915802,
+ 0.8378056287765503,
+ 0.8368849158287048,
+ 0.8375480771064758,
+ 0.8367794156074524,
+ 0.8366787433624268,
+ 0.8369461894035339,
+ 0.8369993567466736,
+ 0.8365828394889832,
+ 0.836992621421814,
+ 0.8358801007270813,
+ 0.8362458944320679,
+ 0.8364722728729248,
+ 0.8367875218391418,
+ 0.8361663222312927,
+ 0.837709367275238,
+ 0.8383733034133911,
+ 0.837929904460907,
+ 0.8377741575241089,
+ 0.8372476100921631,
+ 0.8400737047195435,
+ 0.8390684127807617,
+ 0.83942711353302,
+ 0.8388166427612305,
+ 0.8388614654541016,
+ 0.5899927616119385,
+ 0.8362930417060852,
+ 0.8417453169822693,
+ 0.8424978852272034,
+ 0.8421810865402222,
+ 0.8427466154098511,
+ 0.8386762142181396,
+ 0.8371509909629822,
+ 0.833957314491272,
+ 0.8364364504814148,
+ 0.8421381115913391,
+ 0.8345608115196228,
+ 0.8007285594940186,
+ 0.7260335087776184,
+ 0.6692654490470886,
+ 0.7947330474853516,
+ 0.7512447237968445,
+ 0.6925077438354492,
+ 0.745521605014801,
+ 0.7343494296073914,
+ 0.7150691151618958,
+ 0.7198730111122131,
+ 0.6902240514755249,
+ 0.7437922954559326,
+ 0.7274830937385559,
+ 0.7680156826972961,
+ 0.7545132040977478,
+ 0.768580436706543,
+ 0.7604827284812927,
+ 0.7898576855659485,
+ 0.7694604992866516,
+ 0.7757206559181213,
+ 0.7929953932762146,
+ 0.7918580174446106,
+ 0.7945917248725891,
+ 0.803205132484436,
+ 0.8119467496871948,
+ 0.8164100646972656,
+ 0.8216281533241272,
+ 0.8251609802246094,
+ 0.8262266516685486,
+ 0.8401311039924622,
+ 0.8315710425376892,
+ 0.8277207016944885,
+ 0.8256680369377136,
+ 0.8304272890090942,
+ 0.8297643065452576,
+ 0.8305029273033142,
+ 0.833788275718689,
+ 0.8272045850753784,
+ 0.8259381651878357,
+ 0.8239844441413879,
+ 0.8249673843383789,
+ 0.823042094707489,
+ 0.8268076777458191,
+ 0.8339473605155945,
+ 0.8356831073760986,
+ 0.8316064476966858,
+ 0.8334366083145142,
+ 0.8356269598007202,
+ 0.8349707722663879,
+ 0.8357625603675842,
+ 0.836354672908783,
+ 0.8350753784179688,
+ 0.8341637253761292,
+ 0.835790753364563,
+ 0.8393711447715759,
+ 0.8449829816818237,
+ 0.8419803977012634,
+ 0.8453216552734375,
+ 0.8424800634384155,
+ 0.8430929780006409,
+ 0.8414286375045776,
+ 0.825545072555542,
+ 0.8366214036941528,
+ 0.8449114561080933,
+ 0.8470876216888428,
+ 0.8456646203994751,
+ 0.8375294208526611,
+ 0.8552544116973877,
+ 0.8545020222663879,
+ 0.8554497361183167,
+ 0.8545476794242859,
+ 0.8512993454933167,
+ 0.8510252237319946,
+ 0.8507089018821716,
+ 0.8489977121353149,
+ 0.8457096219062805,
+ 0.8413915634155273,
+ 0.8370864987373352,
+ 0.8341518044471741,
+ 0.8336361646652222,
+ 0.8339350819587708,
+ 0.8373669385910034,
+ 0.8392994403839111,
+ 0.839483380317688,
+ 0.8332099914550781,
+ 0.8350664973258972,
+ 0.8205387592315674,
+ 0.8416454195976257,
+ 0.8428033590316772,
+ 0.8411937952041626,
+ 0.8427980542182922,
+ 0.8404627442359924,
+ 0.842330276966095,
+ 0.8395564556121826,
+ 0.8414638638496399,
+ 0.8424361944198608,
+ 0.8446895480155945,
+ 0.846848726272583,
+ 0.8489895462989807,
+ 0.8482503294944763,
+ 0.8459113240242004,
+ 0.844142735004425,
+ 0.8484106659889221,
+ 0.8471189737319946,
+ 0.845054566860199,
+ 0.8444048166275024,
+ 0.8466100096702576,
+ 0.8477064967155457,
+ 0.8483533263206482,
+ 0.8503214716911316,
+ 0.8509876728057861,
+ 0.8316774964332581,
+ 0.8370324373245239,
+ 0.839336097240448,
+ 0.8408005237579346,
+ 0.8428125977516174,
+ 0.8437373638153076,
+ 0.8455855846405029,
+ 0.8456382155418396,
+ 0.8462637662887573,
+ 0.8460706472396851,
+ 0.8462703227996826,
+ 0.8461295366287231,
+ 0.8465862274169922,
+ 0.8466053605079651,
+ 0.846900999546051,
+ 0.8466494083404541,
+ 0.845362663269043,
+ 0.8469746112823486,
+ 0.846789538860321,
+ 0.8473795056343079,
+ 0.8461935520172119,
+ 0.8466159105300903,
+ 0.8463333249092102,
+ 0.8458000421524048,
+ 0.8505790829658508,
+ 0.8503738641738892,
+ 0.8476802110671997,
+ 0.845830500125885,
+ 0.8487781882286072,
+ 0.8485366106033325,
+ 0.8473630547523499,
+ 0.8498547077178955,
+ 0.8499366641044617,
+ 0.8500738739967346,
+ 0.8502128720283508,
+ 0.8495468497276306,
+ 0.8469513654708862,
+ 0.8481175899505615,
+ 0.849429726600647,
+ 0.849066972732544,
+ 0.8548311591148376,
+ 0.8550606369972229,
+ 0.8563703894615173,
+ 0.8567954897880554,
+ 0.856242835521698,
+ 0.8552171587944031,
+ 0.8527343273162842,
+ 0.8532533049583435,
+ 0.8556985259056091,
+ 0.8543787598609924,
+ 0.8551021218299866,
+ 0.8552600741386414,
+ 0.8515934348106384,
+ 0.8519724011421204,
+ 0.8135335445404053,
+ 0.8389540910720825,
+ 0.8461066484451294,
+ 0.8526746034622192,
+ 0.8554375171661377,
+ 0.8559949994087219,
+ 0.8570294976234436,
+ 0.857523500919342,
+ 0.8580809831619263,
+ 0.8586397171020508,
+ 0.8586057424545288,
+ 0.858701765537262,
+ 0.8587692975997925,
+ 0.859142541885376,
+ 0.8580973148345947,
+ 0.8573670983314514,
+ 0.8562242388725281,
+ 0.8564314246177673,
+ 0.8561164736747742,
+ 0.8571146130561829,
+ 0.8570607304573059,
+ 0.8561781048774719,
+ 0.8559383749961853,
+ 0.8557175397872925,
+ 0.8561850786209106,
+ 0.8556377291679382,
+ 0.8554741740226746,
+ 0.8555220365524292,
+ 0.8559049963951111,
+ 0.8547510504722595,
+ 0.8522539734840393,
+ 0.8539879322052002,
+ 0.852405846118927,
+ 0.849246621131897,
+ 0.8530959486961365,
+ 0.8528076410293579,
+ 0.852859616279602,
+ 0.8532447814941406,
+ 0.8533909916877747,
+ 0.8527594804763794,
+ 0.8533093333244324,
+ 0.8533316850662231,
+ 0.853492796421051,
+ 0.8531711101531982,
+ 0.8533493280410767,
+ 0.8533123135566711,
+ 0.8528207540512085,
+ 0.852786660194397,
+ 0.8533142805099487,
+ 0.8515560626983643,
+ 0.8524121046066284,
+ 0.8523905277252197,
+ 0.8523167371749878,
+ 0.8519017100334167,
+ 0.8507035970687866,
+ 0.8511326909065247,
+ 0.839489221572876,
+ 0.8289598822593689,
+ 0.8272483348846436,
+ 0.8299774527549744,
+ 0.8335668444633484,
+ 0.8329368829727173,
+ 0.8338814377784729,
+ 0.8352221846580505,
+ 0.8364477753639221,
+ 0.8375290036201477,
+ 0.8377435803413391,
+ 0.8386745452880859,
+ 0.8395848870277405,
+ 0.8398758769035339,
+ 0.8401604890823364,
+ 0.8399484753608704,
+ 0.8398151397705078,
+ 0.8405882716178894,
+ 0.8409283757209778,
+ 0.8175678253173828,
+ 0.8111531138420105,
+ 0.822675347328186,
+ 0.8253350853919983,
+ 0.8251060843467712,
+ 0.8288128972053528,
+ 0.8320625424385071,
+ 0.8333511352539062,
+ 0.8346990346908569,
+ 0.8357964754104614,
+ 0.8368164300918579,
+ 0.8353255391120911,
+ 0.8381621241569519,
+ 0.838522732257843,
+ 0.8384892344474792,
+ 0.8390708565711975,
+ 0.838234543800354,
+ 0.8389724493026733,
+ 0.8399551510810852,
+ 0.8399912714958191,
+ 0.8401727080345154,
+ 0.8411318063735962,
+ 0.8280410766601562,
+ 0.8258520364761353,
+ 0.8292118906974792,
+ 0.8323183059692383,
+ 0.8278461694717407,
+ 0.8286098837852478,
+ 0.8372151255607605,
+ 0.8382918834686279,
+ 0.8395122289657593,
+ 0.840423047542572,
+ 0.8408432006835938,
+ 0.841245710849762,
+ 0.8415670990943909,
+ 0.8418585658073425,
+ 0.8411929607391357,
+ 0.8421667218208313,
+ 0.8424487709999084,
+ 0.8429855704307556,
+ 0.8435900211334229,
+ 0.8436681032180786,
+ 0.8439942002296448,
+ 0.8434712290763855,
+ 0.8450051546096802,
+ 0.8439319133758545,
+ 0.8445190787315369,
+ 0.8430773615837097,
+ 0.8428373336791992,
+ 0.8436524271965027,
+ 0.8429920673370361,
+ 0.8443013429641724,
+ 0.8443319797515869,
+ 0.8449892401695251,
+ 0.8376885056495667,
+ 0.841406524181366,
+ 0.8415302038192749,
+ 0.8408886194229126,
+ 0.8442723155021667,
+ 0.8440332412719727,
+ 0.8442908525466919,
+ 0.8443634510040283,
+ 0.8437660336494446,
+ 0.8392053842544556,
+ 0.8405532836914062,
+ 0.8407477736473083,
+ 0.8406219482421875,
+ 0.8408179879188538,
+ 0.8407355546951294,
+ 0.8411082625389099,
+ 0.8407227396965027,
+ 0.8417288661003113,
+ 0.8435289859771729,
+ 0.8406943082809448,
+ 0.8406220078468323,
+ 0.8430836200714111,
+ 0.8414878249168396,
+ 0.8407925367355347,
+ 0.8421801328659058,
+ 0.8422234654426575,
+ 0.8420581221580505,
+ 0.8420503735542297,
+ 0.8411982655525208,
+ 0.8417600989341736,
+ 0.8418623805046082,
+ 0.8419300317764282,
+ 0.8423025608062744,
+ 0.8427544236183167,
+ 0.8427135348320007,
+ 0.844169020652771,
+ 0.8436357378959656,
+ 0.8427347540855408,
+ 0.8422463536262512,
+ 0.8415619134902954,
+ 0.8414194583892822,
+ 0.8415847420692444,
+ 0.8405125737190247,
+ 0.8403295874595642,
+ 0.8391857147216797,
+ 0.839878261089325,
+ 0.8403366208076477,
+ 0.8406816124916077,
+ 0.8406360149383545,
+ 0.8400084376335144,
+ 0.8400857448577881,
+ 0.839959979057312,
+ 0.8401797413825989,
+ 0.8409728407859802,
+ 0.8402512669563293,
+ 0.8409667611122131,
+ 0.841234564781189,
+ 0.8381590247154236,
+ 0.8394162058830261,
+ 0.8405060172080994,
+ 0.8413972854614258,
+ 0.8416813015937805,
+ 0.8416779041290283,
+ 0.8404427170753479,
+ 0.8398978114128113,
+ 0.8398051857948303,
+ 0.8407351970672607,
+ 0.8412854671478271,
+ 0.8423576951026917,
+ 0.8464295268058777,
+ 0.8451360464096069,
+ 0.8451173305511475,
+ 0.8459922671318054,
+ 0.8468573093414307,
+ 0.846176266670227,
+ 0.8468344211578369,
+ 0.8463738560676575,
+ 0.8468202352523804,
+ 0.846664547920227,
+ 0.8477717638015747,
+ 0.8471964001655579,
+ 0.8456940054893494,
+ 0.8466700315475464,
+ 0.8462275266647339,
+ 0.8469456434249878
+ ],
+ "val_precision": [
+ 0.8711534738540649,
+ 0.9099778532981873,
+ 0.9012740850448608,
+ 0.8973432183265686,
+ 0.8958495259284973,
+ 0.8980125188827515,
+ 0.8952105045318604,
+ 0.8935657143592834,
+ 0.8896180391311646,
+ 0.8830705285072327,
+ 0.881388783454895,
+ 0.8785831332206726,
+ 0.8784195780754089,
+ 0.8662593960762024,
+ 0.8595232367515564,
+ 0.8465721011161804,
+ 0.8446353673934937,
+ 0.8311983942985535,
+ 0.8418576717376709,
+ 0.8404845595359802,
+ 0.8405564427375793,
+ 0.8373354077339172,
+ 0.8398115634918213,
+ 0.8381937742233276,
+ 0.8403121829032898,
+ 0.8403711915016174,
+ 0.8373908400535583,
+ 0.8363755941390991,
+ 0.8387280702590942,
+ 0.8359437584877014,
+ 0.8607511520385742,
+ 0.8508727550506592,
+ 0.8516408205032349,
+ 0.850869357585907,
+ 0.8463905453681946,
+ 0.9256880879402161,
+ 0.8811120986938477,
+ 0.8642243146896362,
+ 0.8511949777603149,
+ 0.8438611626625061,
+ 0.897867739200592,
+ 0.8976171612739563,
+ 0.9503747820854187,
+ 0.9013785123825073,
+ 0.8685796856880188,
+ 0.8691869378089905,
+ 0.8672569990158081,
+ 0.8460566997528076,
+ 0.8413045406341553,
+ 0.8369789123535156,
+ 0.8355416059494019,
+ 0.8512107133865356,
+ 0.846632182598114,
+ 0.8438904285430908,
+ 0.8412598967552185,
+ 0.8414592742919922,
+ 0.8418748378753662,
+ 0.8401152491569519,
+ 0.8390417695045471,
+ 0.8385841846466064,
+ 0.839232325553894,
+ 0.8377540707588196,
+ 0.834185779094696,
+ 0.8351815342903137,
+ 0.83882075548172,
+ 0.8391847014427185,
+ 0.8374903798103333,
+ 0.8391371369361877,
+ 0.8375192284584045,
+ 0.8376700282096863,
+ 0.8398131728172302,
+ 0.8305045366287231,
+ 0.8372330069541931,
+ 0.8399054408073425,
+ 0.8377699255943298,
+ 0.8375679850578308,
+ 0.835006594657898,
+ 0.8356190919876099,
+ 0.8395451903343201,
+ 0.8383533954620361,
+ 0.838898777961731,
+ 0.8382733464241028,
+ 0.8402997255325317,
+ 0.8408578038215637,
+ 0.8415656685829163,
+ 0.8414214253425598,
+ 0.840507447719574,
+ 0.8395826816558838,
+ 0.8433931469917297,
+ 0.8396790027618408,
+ 0.8400417566299438,
+ 0.8480983972549438,
+ 0.8443701267242432,
+ 0.8446289896965027,
+ 0.8388392925262451,
+ 0.8489418029785156,
+ 0.8449345827102661,
+ 0.8443923592567444,
+ 0.845302402973175,
+ 0.8450705409049988,
+ 0.9710460305213928,
+ 0.8663370609283447,
+ 0.8593899011611938,
+ 0.8566167950630188,
+ 0.859527051448822,
+ 0.8560407757759094,
+ 0.8620664477348328,
+ 0.8610326647758484,
+ 0.8568584322929382,
+ 0.8632041811943054,
+ 0.8479426503181458,
+ 0.8565797805786133,
+ 0.881535530090332,
+ 0.9172544479370117,
+ 0.9386345148086548,
+ 0.8827859163284302,
+ 0.9101639986038208,
+ 0.9348036646842957,
+ 0.9130666255950928,
+ 0.9183593392372131,
+ 0.9276469349861145,
+ 0.9261051416397095,
+ 0.9390121102333069,
+ 0.917735755443573,
+ 0.9261561036109924,
+ 0.9070261716842651,
+ 0.9143952131271362,
+ 0.9071436524391174,
+ 0.9126418232917786,
+ 0.8976956009864807,
+ 0.9085080623626709,
+ 0.905343770980835,
+ 0.8929307460784912,
+ 0.8945400714874268,
+ 0.8930835723876953,
+ 0.884345293045044,
+ 0.8809970617294312,
+ 0.8748594522476196,
+ 0.873525857925415,
+ 0.8670804500579834,
+ 0.8665019869804382,
+ 0.8446500301361084,
+ 0.8590936064720154,
+ 0.8657666444778442,
+ 0.8678518533706665,
+ 0.8619884848594666,
+ 0.8635320663452148,
+ 0.8665153384208679,
+ 0.8587769269943237,
+ 0.865226149559021,
+ 0.8698509931564331,
+ 0.871242344379425,
+ 0.8705009818077087,
+ 0.8737422823905945,
+ 0.8686668872833252,
+ 0.8597418069839478,
+ 0.8551826477050781,
+ 0.8492261171340942,
+ 0.8479729294776917,
+ 0.8436432480812073,
+ 0.8426222205162048,
+ 0.8409585952758789,
+ 0.8462572693824768,
+ 0.8505966067314148,
+ 0.8527435660362244,
+ 0.8455612063407898,
+ 0.8418802618980408,
+ 0.8454927802085876,
+ 0.8337236046791077,
+ 0.8402150869369507,
+ 0.8331714868545532,
+ 0.8320943713188171,
+ 0.8520482778549194,
+ 0.8709573745727539,
+ 0.8536549210548401,
+ 0.8313764333724976,
+ 0.8369816541671753,
+ 0.8376103043556213,
+ 0.8010546565055847,
+ 0.8396194577217102,
+ 0.8443663716316223,
+ 0.8486455082893372,
+ 0.8437392711639404,
+ 0.8403691053390503,
+ 0.8407142758369446,
+ 0.8399880528450012,
+ 0.839185893535614,
+ 0.8460862636566162,
+ 0.8545673489570618,
+ 0.860598623752594,
+ 0.8658726811408997,
+ 0.8679269552230835,
+ 0.866565465927124,
+ 0.8626867532730103,
+ 0.861561119556427,
+ 0.8615671396255493,
+ 0.8700774908065796,
+ 0.8644501566886902,
+ 0.8783872723579407,
+ 0.8526102900505066,
+ 0.8481480479240417,
+ 0.8520728349685669,
+ 0.850262463092804,
+ 0.8534269332885742,
+ 0.850530743598938,
+ 0.8601478934288025,
+ 0.8564969301223755,
+ 0.8558398485183716,
+ 0.8488183617591858,
+ 0.8448949456214905,
+ 0.8506718873977661,
+ 0.8446462154388428,
+ 0.8446815013885498,
+ 0.8393603563308716,
+ 0.8363129496574402,
+ 0.8319042921066284,
+ 0.8438207507133484,
+ 0.8403454422950745,
+ 0.8494036197662354,
+ 0.8482406139373779,
+ 0.8493703603744507,
+ 0.8521566390991211,
+ 0.8635560870170593,
+ 0.835563600063324,
+ 0.8510769009590149,
+ 0.8496324419975281,
+ 0.8478512167930603,
+ 0.847159206867218,
+ 0.8458375930786133,
+ 0.8478499054908752,
+ 0.8420311212539673,
+ 0.8471897840499878,
+ 0.8455916047096252,
+ 0.8447195887565613,
+ 0.8431907296180725,
+ 0.8468424081802368,
+ 0.8455039858818054,
+ 0.8472043871879578,
+ 0.8467671871185303,
+ 0.8406612873077393,
+ 0.8456061482429504,
+ 0.8399146199226379,
+ 0.8432310819625854,
+ 0.839999258518219,
+ 0.8430536985397339,
+ 0.841952919960022,
+ 0.8331800699234009,
+ 0.8475290536880493,
+ 0.8477332592010498,
+ 0.8592418432235718,
+ 0.8608124256134033,
+ 0.8471086025238037,
+ 0.8507931232452393,
+ 0.8570448160171509,
+ 0.8441699147224426,
+ 0.8461606502532959,
+ 0.8479183912277222,
+ 0.8481042981147766,
+ 0.8468204140663147,
+ 0.8429912328720093,
+ 0.8440841436386108,
+ 0.8464818596839905,
+ 0.8488946557044983,
+ 0.8782042264938354,
+ 0.8803180456161499,
+ 0.8736661672592163,
+ 0.8709073662757874,
+ 0.8666850328445435,
+ 0.8635985255241394,
+ 0.8619860410690308,
+ 0.8458244204521179,
+ 0.8539053797721863,
+ 0.8490270376205444,
+ 0.8503031134605408,
+ 0.8508244156837463,
+ 0.8424339890480042,
+ 0.8424623608589172,
+ 0.9407716393470764,
+ 0.9200730919837952,
+ 0.912413477897644,
+ 0.901079535484314,
+ 0.8914381861686707,
+ 0.8887494802474976,
+ 0.8808781504631042,
+ 0.8749341368675232,
+ 0.8733342289924622,
+ 0.8700673580169678,
+ 0.8675072193145752,
+ 0.8664731383323669,
+ 0.8562585115432739,
+ 0.8603118658065796,
+ 0.8580091595649719,
+ 0.8594563603401184,
+ 0.8582649230957031,
+ 0.8568836450576782,
+ 0.8514372706413269,
+ 0.8594683408737183,
+ 0.8582966923713684,
+ 0.8588802218437195,
+ 0.8594077825546265,
+ 0.856478750705719,
+ 0.8584680557250977,
+ 0.8564247488975525,
+ 0.8565948605537415,
+ 0.8511720895767212,
+ 0.8616523742675781,
+ 0.8520718216896057,
+ 0.8509661555290222,
+ 0.8458645939826965,
+ 0.8619470596313477,
+ 0.8367296457290649,
+ 0.8450754284858704,
+ 0.8439561724662781,
+ 0.8469855189323425,
+ 0.8472172617912292,
+ 0.8471127152442932,
+ 0.8471678495407104,
+ 0.8489171862602234,
+ 0.8484451770782471,
+ 0.8490914702415466,
+ 0.849514901638031,
+ 0.8497565984725952,
+ 0.848747193813324,
+ 0.847416341304779,
+ 0.8483163118362427,
+ 0.8474992513656616,
+ 0.8459259867668152,
+ 0.8463056683540344,
+ 0.8468024134635925,
+ 0.8468325138092041,
+ 0.8456897735595703,
+ 0.8450791239738464,
+ 0.8440936803817749,
+ 0.830335795879364,
+ 0.7993900179862976,
+ 0.793070912361145,
+ 0.8034380078315735,
+ 0.8194118738174438,
+ 0.8142642378807068,
+ 0.8157716393470764,
+ 0.8192340135574341,
+ 0.8195520043373108,
+ 0.8228065371513367,
+ 0.8221330642700195,
+ 0.8225405812263489,
+ 0.8252206444740295,
+ 0.8260633945465088,
+ 0.825140118598938,
+ 0.8256429433822632,
+ 0.8282857537269592,
+ 0.8261280655860901,
+ 0.8279370665550232,
+ 0.8756867051124573,
+ 0.8302189111709595,
+ 0.8211290836334229,
+ 0.8243001103401184,
+ 0.8352564573287964,
+ 0.8340696096420288,
+ 0.8324414491653442,
+ 0.8342196345329285,
+ 0.8347622752189636,
+ 0.8345479965209961,
+ 0.8326995968818665,
+ 0.8201993107795715,
+ 0.8334702849388123,
+ 0.8323670625686646,
+ 0.8328129649162292,
+ 0.8324663639068604,
+ 0.8275765776634216,
+ 0.8328525424003601,
+ 0.8338319063186646,
+ 0.8406358957290649,
+ 0.8342118859291077,
+ 0.8341279625892639,
+ 0.844491720199585,
+ 0.8484975099563599,
+ 0.8435754179954529,
+ 0.8398926854133606,
+ 0.8345503211021423,
+ 0.839335560798645,
+ 0.8277166485786438,
+ 0.8274762630462646,
+ 0.8282976746559143,
+ 0.8294353485107422,
+ 0.8292824625968933,
+ 0.8328487873077393,
+ 0.8292620778083801,
+ 0.8316678404808044,
+ 0.836163341999054,
+ 0.8351377844810486,
+ 0.8346888422966003,
+ 0.8311458230018616,
+ 0.838955283164978,
+ 0.8381076455116272,
+ 0.8383089900016785,
+ 0.8354285359382629,
+ 0.838118851184845,
+ 0.8365051746368408,
+ 0.8356513977050781,
+ 0.8323842287063599,
+ 0.8316612243652344,
+ 0.8370093107223511,
+ 0.8296104073524475,
+ 0.8354123830795288,
+ 0.8300120234489441,
+ 0.8340489864349365,
+ 0.8298304080963135,
+ 0.8299878239631653,
+ 0.8266295790672302,
+ 0.8280067443847656,
+ 0.8353079557418823,
+ 0.8343426585197449,
+ 0.8352524638175964,
+ 0.835371196269989,
+ 0.8479501605033875,
+ 0.8313549160957336,
+ 0.8351329565048218,
+ 0.836721658706665,
+ 0.8332098126411438,
+ 0.8351552486419678,
+ 0.83438640832901,
+ 0.8342813849449158,
+ 0.8330809473991394,
+ 0.8357232809066772,
+ 0.8361101746559143,
+ 0.8338901996612549,
+ 0.8371801376342773,
+ 0.8381467461585999,
+ 0.8457373380661011,
+ 0.8443560004234314,
+ 0.8427362442016602,
+ 0.843136191368103,
+ 0.8403492569923401,
+ 0.8423163890838623,
+ 0.8420248627662659,
+ 0.840101957321167,
+ 0.8397964835166931,
+ 0.8396996259689331,
+ 0.8402324914932251,
+ 0.8376874923706055,
+ 0.8397992849349976,
+ 0.8312510848045349,
+ 0.8357895016670227,
+ 0.8348694443702698,
+ 0.8337839841842651,
+ 0.8347054123878479,
+ 0.832122802734375,
+ 0.8326663374900818,
+ 0.8303632140159607,
+ 0.8294076919555664,
+ 0.8289780020713806,
+ 0.8320909142494202,
+ 0.8328109383583069,
+ 0.8335252404212952,
+ 0.832923173904419,
+ 0.8333304524421692,
+ 0.8312778472900391,
+ 0.8306232690811157,
+ 0.8294838666915894,
+ 0.8314119577407837,
+ 0.8328531980514526,
+ 0.836654543876648,
+ 0.8344799876213074,
+ 0.8263624906539917,
+ 0.8262391686439514,
+ 0.8288564085960388,
+ 0.8312889337539673,
+ 0.832801103591919,
+ 0.8319774270057678,
+ 0.8286105990409851,
+ 0.8272514939308167,
+ 0.8288088440895081,
+ 0.8281652927398682,
+ 0.828852653503418,
+ 0.830596387386322,
+ 0.8328118920326233,
+ 0.8277364373207092,
+ 0.8315495848655701,
+ 0.8332411050796509,
+ 0.8352165222167969,
+ 0.8314898014068604,
+ 0.8340672850608826,
+ 0.8345930576324463,
+ 0.835718035697937,
+ 0.837015688419342,
+ 0.8390412926673889,
+ 0.8377916216850281,
+ 0.8331015110015869,
+ 0.8361600041389465,
+ 0.8335365653038025,
+ 0.8368273377418518
+ ],
+ "val_recall": [
+ 0.7138949036598206,
+ 0.6286648511886597,
+ 0.6385065317153931,
+ 0.6445983648300171,
+ 0.6512171030044556,
+ 0.6463722586631775,
+ 0.6573383212089539,
+ 0.666047215461731,
+ 0.6832208037376404,
+ 0.7122219800949097,
+ 0.7173653841018677,
+ 0.7347973585128784,
+ 0.7411508560180664,
+ 0.7763452529907227,
+ 0.7962082624435425,
+ 0.8206493258476257,
+ 0.8289700150489807,
+ 0.8276956081390381,
+ 0.8290759325027466,
+ 0.8334759473800659,
+ 0.8357791900634766,
+ 0.8409918546676636,
+ 0.8399480581283569,
+ 0.844188928604126,
+ 0.8431087136268616,
+ 0.843708872795105,
+ 0.848979651927948,
+ 0.851351797580719,
+ 0.8493645191192627,
+ 0.8524963855743408,
+ 0.7727329730987549,
+ 0.8133203983306885,
+ 0.8196693062782288,
+ 0.8249475955963135,
+ 0.8286963701248169,
+ 0.5666696429252625,
+ 0.7584494352340698,
+ 0.7947556972503662,
+ 0.8203129172325134,
+ 0.8304421901702881,
+ 0.1745501309633255,
+ 0.17224669456481934,
+ 0.42887505888938904,
+ 0.6848285794258118,
+ 0.7948252558708191,
+ 0.7959211468696594,
+ 0.799809992313385,
+ 0.8348971009254456,
+ 0.842459499835968,
+ 0.8518373966217041,
+ 0.8531920313835144,
+ 0.8106057047843933,
+ 0.8251631259918213,
+ 0.8344990015029907,
+ 0.8401424884796143,
+ 0.8430490493774414,
+ 0.844461977481842,
+ 0.8462569117546082,
+ 0.8492566347122192,
+ 0.8510982394218445,
+ 0.851049542427063,
+ 0.8556559681892395,
+ 0.8597413897514343,
+ 0.8624208569526672,
+ 0.8590958118438721,
+ 0.857639491558075,
+ 0.8579022884368896,
+ 0.8555306196212769,
+ 0.8565317988395691,
+ 0.8590701818466187,
+ 0.8552781939506531,
+ 0.863028347492218,
+ 0.8566315174102783,
+ 0.8506564497947693,
+ 0.8543311953544617,
+ 0.8545897006988525,
+ 0.8585560321807861,
+ 0.8554390072822571,
+ 0.853297233581543,
+ 0.853861927986145,
+ 0.8533856272697449,
+ 0.8548802733421326,
+ 0.8526723384857178,
+ 0.8513028621673584,
+ 0.8510808348655701,
+ 0.8495150804519653,
+ 0.8507934808731079,
+ 0.8524266481399536,
+ 0.8483285307884216,
+ 0.8525064587593079,
+ 0.8531935811042786,
+ 0.8452684879302979,
+ 0.8489734530448914,
+ 0.8487175703048706,
+ 0.8545060753822327,
+ 0.8466358780860901,
+ 0.8503495454788208,
+ 0.8516244292259216,
+ 0.8498554229736328,
+ 0.850494384765625,
+ 0.44720861315727234,
+ 0.8235508799552917,
+ 0.8393977284431458,
+ 0.8446410894393921,
+ 0.8418731093406677,
+ 0.8467342257499695,
+ 0.8340064883232117,
+ 0.832314670085907,
+ 0.8308895826339722,
+ 0.8281630873680115,
+ 0.854369044303894,
+ 0.8318706750869751,
+ 0.7502104043960571,
+ 0.6147958040237427,
+ 0.5343802571296692,
+ 0.7403198480606079,
+ 0.6522589325904846,
+ 0.5624599456787109,
+ 0.6424561738967896,
+ 0.6245296597480774,
+ 0.5936372876167297,
+ 0.60028475522995,
+ 0.557341992855072,
+ 0.6365642547607422,
+ 0.6096900701522827,
+ 0.6775719523429871,
+ 0.6531724333763123,
+ 0.6783187985420227,
+ 0.6626167297363281,
+ 0.7168574929237366,
+ 0.6787083745002747,
+ 0.6903501749038696,
+ 0.7266693711280823,
+ 0.7233291864395142,
+ 0.7288182973861694,
+ 0.7511191964149475,
+ 0.7684029340744019,
+ 0.7821173071861267,
+ 0.7917649745941162,
+ 0.8049054741859436,
+ 0.8070986270904541,
+ 0.8550939559936523,
+ 0.8245589733123779,
+ 0.8108031749725342,
+ 0.805087149143219,
+ 0.8195995688438416,
+ 0.8167506456375122,
+ 0.8159922957420349,
+ 0.8290852904319763,
+ 0.8105812072753906,
+ 0.8037106394767761,
+ 0.7992058992385864,
+ 0.8014134168624878,
+ 0.7948223352432251,
+ 0.8062489628791809,
+ 0.828700065612793,
+ 0.8370940089225769,
+ 0.8341236114501953,
+ 0.8411271572113037,
+ 0.8495692610740662,
+ 0.8506975769996643,
+ 0.8540267944335938,
+ 0.848400354385376,
+ 0.8412774205207825,
+ 0.8369691371917725,
+ 0.8473140597343445,
+ 0.8577393889427185,
+ 0.8619421124458313,
+ 0.8698633313179016,
+ 0.8683803677558899,
+ 0.8725437521934509,
+ 0.8747128844261169,
+ 0.8498877882957458,
+ 0.8027667999267578,
+ 0.8402888774871826,
+ 0.8779536485671997,
+ 0.8747881054878235,
+ 0.8732026219367981,
+ 0.9028149843215942,
+ 0.888365626335144,
+ 0.882239043712616,
+ 0.8798195719718933,
+ 0.8843492269515991,
+ 0.8823739886283875,
+ 0.8820407390594482,
+ 0.8825317621231079,
+ 0.8807196021080017,
+ 0.8673355579376221,
+ 0.8505632877349854,
+ 0.8365170955657959,
+ 0.8258256316184998,
+ 0.8226954340934753,
+ 0.8244332075119019,
+ 0.8335296511650085,
+ 0.8384708166122437,
+ 0.837757408618927,
+ 0.8183615803718567,
+ 0.8272135257720947,
+ 0.7957266569137573,
+ 0.8523391485214233,
+ 0.8591209053993225,
+ 0.8521117568016052,
+ 0.8563898205757141,
+ 0.8488904237747192,
+ 0.8548893332481384,
+ 0.8391090631484985,
+ 0.8465986847877502,
+ 0.8486119508743286,
+ 0.860098659992218,
+ 0.8688785433769226,
+ 0.8623787760734558,
+ 0.8664694428443909,
+ 0.864578366279602,
+ 0.8683936595916748,
+ 0.8762837648391724,
+ 0.8789510130882263,
+ 0.8645862340927124,
+ 0.8669252991676331,
+ 0.8612207174301147,
+ 0.8644996285438538,
+ 0.8641279339790344,
+ 0.8632484674453735,
+ 0.850382387638092,
+ 0.84476238489151,
+ 0.838936984539032,
+ 0.8458986878395081,
+ 0.8510852456092834,
+ 0.8555907011032104,
+ 0.858672022819519,
+ 0.8590560555458069,
+ 0.8654761910438538,
+ 0.8610984683036804,
+ 0.8630610108375549,
+ 0.8639771938323975,
+ 0.8652169108390808,
+ 0.8618441224098206,
+ 0.8632146716117859,
+ 0.861068069934845,
+ 0.8609952330589294,
+ 0.8652485609054565,
+ 0.862664520740509,
+ 0.86861652135849,
+ 0.8662468194961548,
+ 0.8682829737663269,
+ 0.8652265667915344,
+ 0.8663906455039978,
+ 0.8764907121658325,
+ 0.8683980703353882,
+ 0.8681570887565613,
+ 0.8505733609199524,
+ 0.8458050489425659,
+ 0.8658043742179871,
+ 0.8610346913337708,
+ 0.8521830439567566,
+ 0.8704118728637695,
+ 0.8671225905418396,
+ 0.866253674030304,
+ 0.8651615381240845,
+ 0.8661726713180542,
+ 0.8627281188964844,
+ 0.8668228387832642,
+ 0.8665737509727478,
+ 0.8614661693572998,
+ 0.8398611545562744,
+ 0.8382641077041626,
+ 0.8480514287948608,
+ 0.8521321415901184,
+ 0.8562690019607544,
+ 0.8573816418647766,
+ 0.8554061055183411,
+ 0.8768815398216248,
+ 0.8713105320930481,
+ 0.8744873404502869,
+ 0.8742911219596863,
+ 0.8740506768226624,
+ 0.8786142468452454,
+ 0.8788845539093018,
+ 0.7281445264816284,
+ 0.7792198657989502,
+ 0.7952057123184204,
+ 0.8151026964187622,
+ 0.8281133770942688,
+ 0.8318414092063904,
+ 0.8414968848228455,
+ 0.8494554758071899,
+ 0.8523086905479431,
+ 0.8567559123039246,
+ 0.8594571948051453,
+ 0.8608883619308472,
+ 0.8712697625160217,
+ 0.8679503202438354,
+ 0.8685165047645569,
+ 0.8658415675163269,
+ 0.8650411367416382,
+ 0.8667349219322205,
+ 0.8720259070396423,
+ 0.8650246858596802,
+ 0.8663941025733948,
+ 0.8645951151847839,
+ 0.8638908863067627,
+ 0.8668110966682434,
+ 0.865586519241333,
+ 0.866898775100708,
+ 0.8665620684623718,
+ 0.8728160262107849,
+ 0.8631929755210876,
+ 0.8721675872802734,
+ 0.8678584694862366,
+ 0.8771296143531799,
+ 0.8559105396270752,
+ 0.878553569316864,
+ 0.8760054707527161,
+ 0.8771207928657532,
+ 0.8741464614868164,
+ 0.8744163513183594,
+ 0.8740748763084412,
+ 0.8720303177833557,
+ 0.8714028596878052,
+ 0.8722689747810364,
+ 0.872020959854126,
+ 0.8710580468177795,
+ 0.8712208271026611,
+ 0.8726319670677185,
+ 0.8732669353485107,
+ 0.872332751750946,
+ 0.8741260170936584,
+ 0.8732951879501343,
+ 0.8740338683128357,
+ 0.8736048340797424,
+ 0.8737559914588928,
+ 0.8744267821311951,
+ 0.8729435205459595,
+ 0.8749172687530518,
+ 0.8717195987701416,
+ 0.8865574598312378,
+ 0.89252108335495,
+ 0.8859567642211914,
+ 0.8748524785041809,
+ 0.8796765804290771,
+ 0.8795775175094604,
+ 0.8781686425209045,
+ 0.8796904683113098,
+ 0.8776493072509766,
+ 0.8786180019378662,
+ 0.8797085285186768,
+ 0.8780279159545898,
+ 0.8777487874031067,
+ 0.8793208003044128,
+ 0.8785813450813293,
+ 0.8751891851425171,
+ 0.8794978260993958,
+ 0.8776888251304626,
+ 0.7911058068275452,
+ 0.8311010003089905,
+ 0.8563839793205261,
+ 0.8567081689834595,
+ 0.8443841934204102,
+ 0.8506578803062439,
+ 0.8573545217514038,
+ 0.8574348092079163,
+ 0.8590723276138306,
+ 0.8612195253372192,
+ 0.86473548412323,
+ 0.8760099411010742,
+ 0.8660531640052795,
+ 0.8679808378219604,
+ 0.8673582673072815,
+ 0.8687211871147156,
+ 0.871544599533081,
+ 0.8671865463256836,
+ 0.8673295378684998,
+ 0.8607675433158875,
+ 0.8682910799980164,
+ 0.8698164820671082,
+ 0.8410481810569763,
+ 0.8346937298774719,
+ 0.845123291015625,
+ 0.8533225655555725,
+ 0.8528366088867188,
+ 0.8489379286766052,
+ 0.8733007311820984,
+ 0.8751024007797241,
+ 0.8758694529533386,
+ 0.8759905695915222,
+ 0.8767417669296265,
+ 0.8735772371292114,
+ 0.8776817917823792,
+ 0.8754622340202332,
+ 0.8693927526473999,
+ 0.8716742396354675,
+ 0.8722932934761047,
+ 0.8772057890892029,
+ 0.8697801232337952,
+ 0.8707911968231201,
+ 0.8711106777191162,
+ 0.8733454942703247,
+ 0.872536838054657,
+ 0.872806191444397,
+ 0.8746081590652466,
+ 0.8756340742111206,
+ 0.8766515254974365,
+ 0.8718135356903076,
+ 0.8793970346450806,
+ 0.8753090500831604,
+ 0.8807417750358582,
+ 0.8770509958267212,
+ 0.8700686693191528,
+ 0.8759216666221619,
+ 0.879711925983429,
+ 0.8776720762252808,
+ 0.8750447034835815,
+ 0.8758551478385925,
+ 0.8751562833786011,
+ 0.8752334713935852,
+ 0.8557966351509094,
+ 0.8664177060127258,
+ 0.8659477829933167,
+ 0.8655065298080444,
+ 0.8695750832557678,
+ 0.8680756688117981,
+ 0.8689611554145813,
+ 0.8700358867645264,
+ 0.8705236315727234,
+ 0.8693567514419556,
+ 0.8715224862098694,
+ 0.8695716857910156,
+ 0.8662028908729553,
+ 0.8606469631195068,
+ 0.8523414134979248,
+ 0.8541622757911682,
+ 0.8591538667678833,
+ 0.8598201870918274,
+ 0.8630911707878113,
+ 0.8613525629043579,
+ 0.8603708148002625,
+ 0.8637220859527588,
+ 0.8644904494285583,
+ 0.8649963736534119,
+ 0.8652148842811584,
+ 0.868453860282898,
+ 0.8662175536155701,
+ 0.8789054751396179,
+ 0.8731442093849182,
+ 0.8728610277175903,
+ 0.8733997344970703,
+ 0.871367335319519,
+ 0.8739941716194153,
+ 0.873519241809845,
+ 0.8734442591667175,
+ 0.8747082948684692,
+ 0.8735887408256531,
+ 0.8713658452033997,
+ 0.8713515400886536,
+ 0.8709768652915955,
+ 0.8718175888061523,
+ 0.8706104159355164,
+ 0.8728729486465454,
+ 0.8735764622688293,
+ 0.8750870227813721,
+ 0.8737728595733643,
+ 0.8718305230140686,
+ 0.8687755465507507,
+ 0.8713995218276978,
+ 0.8761162161827087,
+ 0.8786718249320984,
+ 0.8769229650497437,
+ 0.8760003447532654,
+ 0.8746944665908813,
+ 0.8755800127983093,
+ 0.8767211437225342,
+ 0.8771851062774658,
+ 0.8749100565910339,
+ 0.8772239685058594,
+ 0.8775946497917175,
+ 0.8773806095123291,
+ 0.8823702931404114,
+ 0.8854650855064392,
+ 0.8822610378265381,
+ 0.8820188045501709,
+ 0.881903886795044,
+ 0.8842942118644714,
+ 0.8824272751808167,
+ 0.8809127807617188,
+ 0.8803654313087463,
+ 0.87862229347229,
+ 0.8781221508979797,
+ 0.8786019682884216,
+ 0.8810625076293945,
+ 0.8792002201080322,
+ 0.8812903761863708,
+ 0.8788249492645264
+ ]
+}
\ No newline at end of file
diff --git a/training_history/2025-08-07_16-25-27.png b/training_history/2025-08-07_16-25-27.png
new file mode 100644
index 0000000000000000000000000000000000000000..74137d3865c80baa78c417b5b7928bba8e5f5728
Binary files /dev/null and b/training_history/2025-08-07_16-25-27.png differ
diff --git a/utils/BilinearUpSampling.py b/utils/BilinearUpSampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..e99327c7e97bb980280a123217097ddb2b9d99e1
--- /dev/null
+++ b/utils/BilinearUpSampling.py
@@ -0,0 +1,92 @@
+import keras.backend as K
+import tensorflow as tf
+from keras.layers import *
+
+def resize_images_bilinear(X, height_factor=1, width_factor=1, target_height=None, target_width=None, data_format='default'):
+    '''Resizes the images contained in a 4D tensor of shape
+    - [batch, channels, height, width] (for 'channels_first' data_format)
+    - [batch, height, width, channels] (for 'channels_last' data_format)
+    by a factor of (height_factor, width_factor). Both factors should be
+    positive integers.
+
+    If both target_height and target_width are given, they take precedence
+    over the scale factors and the output is resized to exactly that size.
+    'default' data_format resolves to the backend's configured image format.
+    Returns the resized tensor with its static shape set where known.
+
+    NOTE(review): `np` is used below but numpy is not imported in this
+    module — presumably it leaks in via `from keras.layers import *`;
+    confirm, otherwise this raises NameError at call time.
+    NOTE(review): tf.image.resize_bilinear is a TF1-era API (moved to
+    tf.compat.v1 in TF2) — confirm the TensorFlow version in use.
+    '''
+    if data_format == 'default':
+        data_format = K.image_data_format()
+    if data_format == 'channels_first':
+        original_shape = K.int_shape(X)
+        if target_height and target_width:
+            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
+        else:
+            # Dynamic spatial dims (H, W) scaled by the integer factors.
+            new_shape = tf.shape(X)[2:]
+            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
+        # resize_bilinear expects channels_last, so permute NCHW -> NHWC,
+        # resize, then permute back to NCHW.
+        X = K.permute_dimensions(X, [0, 2, 3, 1])
+        X = tf.image.resize_bilinear(X, new_shape)
+        X = K.permute_dimensions(X, [0, 3, 1, 2])
+        if target_height and target_width:
+            X.set_shape((None, None, target_height, target_width))
+        else:
+            X.set_shape((None, None, original_shape[2] * height_factor, original_shape[3] * width_factor))
+        return X
+    elif data_format == 'channels_last':
+        original_shape = K.int_shape(X)
+        if target_height and target_width:
+            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
+        else:
+            # Dynamic spatial dims (H, W) scaled by the integer factors.
+            new_shape = tf.shape(X)[1:3]
+            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
+        # Already channels_last: resize directly, no permutation needed.
+        X = tf.image.resize_bilinear(X, new_shape)
+        if target_height and target_width:
+            X.set_shape((None, target_height, target_width, None))
+        else:
+            X.set_shape((None, original_shape[1] * height_factor, original_shape[2] * width_factor, None))
+        return X
+    else:
+        raise Exception('Invalid data_format: ' + data_format)
+
+class BilinearUpSampling2D(Layer):
+    '''Keras layer that upsamples a 4D tensor with bilinear interpolation.
+
+    Either scales the spatial dimensions by the integer factors in `size`,
+    or — when `target_size` is given — resizes to that exact (height, width).
+    Works for both 'channels_first' and 'channels_last' data formats;
+    'default' resolves to the backend's configured image format.
+    Delegates the actual resizing to resize_images_bilinear above.
+    '''
+    def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs):
+        if data_format == 'default':
+            data_format = K.image_data_format()
+        self.size = tuple(size)
+        # target_size, when set, overrides the scale factors in call().
+        if target_size is not None:
+            self.target_size = tuple(target_size)
+        else:
+            self.target_size = None
+        assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {tf, th}'
+        self.data_format = data_format
+        self.input_spec = [InputSpec(ndim=4)]
+        super(BilinearUpSampling2D, self).__init__(**kwargs)
+
+    def compute_output_shape(self, input_shape):
+        # NOTE(review): int(expr if dim is not None else None) evaluates
+        # int(None) when a spatial dim is unknown, which would raise
+        # TypeError — presumably inputs always have static spatial dims here;
+        # confirm against callers.
+        if self.data_format == 'channels_first':
+            width = int(self.size[0] * input_shape[2] if input_shape[2] is not None else None)
+            height = int(self.size[1] * input_shape[3] if input_shape[3] is not None else None)
+            if self.target_size is not None:
+                width = self.target_size[0]
+                height = self.target_size[1]
+            return (input_shape[0],
+                    input_shape[1],
+                    width,
+                    height)
+        elif self.data_format == 'channels_last':
+            width = int(self.size[0] * input_shape[1] if input_shape[1] is not None else None)
+            height = int(self.size[1] * input_shape[2] if input_shape[2] is not None else None)
+            if self.target_size is not None:
+                width = self.target_size[0]
+                height = self.target_size[1]
+            return (input_shape[0],
+                    width,
+                    height,
+                    input_shape[3])
+        else:
+            raise Exception('Invalid data_format: ' + self.data_format)
+
+    def call(self, x, mask=None):
+        # Exact-size resize takes precedence over factor-based scaling.
+        if self.target_size is not None:
+            return resize_images_bilinear(x, target_height=self.target_size[0], target_width=self.target_size[1], data_format=self.data_format)
+        else:
+            return resize_images_bilinear(x, height_factor=self.size[0], width_factor=self.size[1], data_format=self.data_format)
+
+    def get_config(self):
+        # Include size/target_size so the layer can be re-instantiated when
+        # a saved model is loaded with this class as a custom object.
+        config = {'size': self.size, 'target_size': self.target_size}
+        base_config = super(BilinearUpSampling2D, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/__pycache__/BilinearUpSampling.cpython-37.pyc b/utils/__pycache__/BilinearUpSampling.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67538a217261825cb683a40f07abbdfd92bad3d2
Binary files /dev/null and b/utils/__pycache__/BilinearUpSampling.cpython-37.pyc differ
diff --git a/utils/__pycache__/BilinearUpSampling.cpython-39.pyc b/utils/__pycache__/BilinearUpSampling.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6570fc233365760d6e8c58b09d20c1b1050042f0
Binary files /dev/null and b/utils/__pycache__/BilinearUpSampling.cpython-39.pyc differ
diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57309b319954cdf798275fc2f67cd4538c76f09e
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/utils/__pycache__/__init__.cpython-313.pyc b/utils/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3badae41d37fc05b2ec4a04a442126edaa19f19
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-313.pyc differ
diff --git a/utils/__pycache__/__init__.cpython-37.pyc b/utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28c3bd61a2b0b755370d18ec9f20e548ab48a549
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-37.pyc differ
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c7559cffb728d4799b33fad81d64f1ecb8d2d02
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/augment.py b/utils/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b65ad8131f20387b7510e43fa4467573860860f
--- /dev/null
+++ b/utils/augment.py
@@ -0,0 +1,10 @@
+import Augmentor
+
+# Offline augmentation script: builds an Augmentor pipeline over the training
+# images, pairing each image with its ground-truth label so every transform
+# is applied identically to both.
+p = Augmentor.Pipeline("./data/train/images")
+p.ground_truth("./data/train/labels")
+# Random rotation up to +/-25 degrees, applied to 70% of samples.
+p.rotate(probability=0.7, max_left_rotation=25, max_right_rotation=25)
+p.flip_left_right(probability=0.5)
+# Random zoom keeping 80% of the image area, applied to 50% of samples.
+p.zoom_random(probability=0.5, percentage_area=0.8)
+p.flip_top_bottom(probability=0.5)
+p.set_save_format(save_format='auto')
+# Generate 10,000 augmented image/label pairs (written to the output dir).
+p.sample(10000, multi_threaded=False)
\ No newline at end of file
diff --git a/utils/config/memory.py b/utils/config/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcfcab3fe6e62bd6c5d353fbfa603ae00036edcc
--- /dev/null
+++ b/utils/config/memory.py
@@ -0,0 +1,33 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/config/memory.py
+# author : ZFTurbo
+# Calculate the memory needed to run the model
+#
+# ------------------------------------------------------------ #
+
+def get_model_memory_usage(batch_size, model):
+    '''Estimate the memory (in gigabytes) needed to run a Keras model.
+
+    Sums the element counts of every layer's output shape (scaled by
+    batch_size) plus the trainable and non-trainable parameter counts,
+    then multiplies by the byte width of the backend float type.
+
+    batch_size -- the batch size the model will be run with
+    model      -- a built Keras model (layers must have output shapes)
+    Returns the estimate in GB, rounded to 3 decimals.
+    '''
+    import numpy as np
+    from keras import backend as K
+
+    shapes_mem_count = 0
+    for l in model.layers:
+        single_layer_mem = 1
+        # Multiply out all known dims; None (e.g. the batch dim) is skipped
+        # because batch_size is applied separately below.
+        for s in l.output_shape:
+            if s is None:
+                continue
+            single_layer_mem *= s
+        shapes_mem_count += single_layer_mem
+
+    trainable_count = np.sum([K.count_params(p) for p in set(model.trainable_weights)])
+    non_trainable_count = np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])
+
+    # Bytes per number, from the backend float precision (default float32).
+    number_size = 4.0
+    if K.floatx() == 'float16':
+        number_size = 2.0
+    if K.floatx() == 'float64':
+        number_size = 8.0
+
+    total_memory = number_size*(batch_size*shapes_mem_count + trainable_count + non_trainable_count)
+    gbytes = np.round(total_memory / (1024.0 ** 3), 3)
+    return gbytes
\ No newline at end of file
diff --git a/utils/config/read.py b/utils/config/read.py
new file mode 100644
index 0000000000000000000000000000000000000000..21166b12e430244516f2ed4e63cc9d76513c6603
--- /dev/null
+++ b/utils/config/read.py
@@ -0,0 +1,103 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/config/read.py
+# author : CM
+# Read the configuration
+#
+# ------------------------------------------------------------ #
+
+import configparser
+
+
+def readConfig(filename):
+    '''Read the training configuration file and return it as a dict.
+
+    Expects an INI-style file with a [dataset] section (in_path, gd_path,
+    train/valid/test sizes) and a [train] section (patch sizes, batch size,
+    steps per epoch, epochs, logs_path). All size/count values are cast
+    to int; paths are kept as strings.
+    '''
+
+    # ----- Read the configuration ----
+    config = configparser.RawConfigParser()
+    # NOTE(review): the file handle from open() is never closed — harmless
+    # for a one-shot reader but worth a `with` block if this gets reused.
+    config.read_file(open(filename))
+
+    dataset_in_path = config.get("dataset", "in_path")
+    dataset_gd_path = config.get("dataset", "gd_path")
+
+    dataset_train = int(config.get("dataset", "train"))
+    dataset_valid = int(config.get("dataset", "valid"))
+    dataset_test = int(config.get("dataset", "test"))
+
+
+    train_patch_size_x = int(config.get("train", "patch_size_x"))
+    train_patch_size_y = int(config.get("train", "patch_size_y"))
+    train_patch_size_z = int(config.get("train", "patch_size_z"))
+
+    train_batch_size = int(config.get("train", "batch_size"))
+    train_steps_per_epoch = int(config.get("train", "steps_per_epoch"))
+    train_epochs = int(config.get("train", "epochs"))
+
+    logs_path = config.get("train", "logs_path")
+
+
+    return {"dataset_in_path": dataset_in_path,
+            "dataset_gd_path": dataset_gd_path,
+            "dataset_train": dataset_train,
+            "dataset_valid": dataset_valid,
+            "dataset_test": dataset_test,
+            "train_patch_size_x": train_patch_size_x,
+            "train_patch_size_y": train_patch_size_y,
+            "train_patch_size_z": train_patch_size_z,
+            "train_batch_size": train_batch_size,
+            "train_steps_per_epoch": train_steps_per_epoch,
+            "train_epochs": train_epochs,
+            "logs_path": logs_path
+            }
+
+# Old version will be deleted soon
+def readConfig_OLD(filename):
+
+ # ----- Read the configuration ----
+ config = configparser.RawConfigParser()
+ config.read_file(open(filename))
+
+ dataset_train_size = int(config.get("dataset","train_size"))
+ dataset_train_gd_path = config.get("dataset","train_gd_path")
+ dataset_train_mra_path = config.get("dataset","train_mra_path")
+
+ dataset_valid_size = int(config.get("dataset","valid_size"))
+ dataset_valid_gd_path = config.get("dataset","valid_gd_path")
+ dataset_valid_mra_path = config.get("dataset","valid_mra_path")
+
+ dataset_test_size = int(config.get("dataset","test_size"))
+ dataset_test_gd_path = config.get("dataset","test_gd_path")
+ dataset_test_mra_path = config.get("dataset","test_mra_path")
+
+ image_size_x = int(config.get("data","image_size_x"))
+ image_size_y = int(config.get("data","image_size_y"))
+ image_size_z = int(config.get("data","image_size_z"))
+
+ patch_size_x = int(config.get("patchs","patch_size_x"))
+ patch_size_y = int(config.get("patchs","patch_size_y"))
+ patch_size_z = int(config.get("patchs","patch_size_z"))
+
+ batch_size = int(config.get("train","batch_size"))
+ steps_per_epoch = int(config.get("train","steps_per_epoch"))
+ epochs = int(config.get("train","epochs"))
+
+ logs_folder = config.get("logs","folder")
+
+ return {"dataset_train_size" : dataset_train_size,
+ "dataset_train_gd_path" : dataset_train_gd_path,
+ "dataset_train_mra_path": dataset_train_mra_path,
+ "dataset_valid_size" : dataset_valid_size,
+ "dataset_valid_gd_path" : dataset_valid_gd_path,
+ "dataset_valid_mra_path": dataset_valid_mra_path,
+ "dataset_test_size" : dataset_test_size,
+ "dataset_test_gd_path" : dataset_test_gd_path,
+ "dataset_test_mra_path" : dataset_test_mra_path,
+ "image_size_x" : image_size_x,
+ "image_size_y" : image_size_y,
+ "image_size_z" : image_size_z,
+ "patch_size_x" : patch_size_x,
+ "patch_size_y" : patch_size_y,
+ "patch_size_z" : patch_size_z,
+ "batch_size" : batch_size,
+ "steps_per_epoch" : steps_per_epoch,
+ "epochs" : epochs,
+ "logs_folder" : logs_folder
+ }
\ No newline at end of file
diff --git a/utils/io/__init__.py b/utils/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/io/__pycache__/__init__.cpython-310.pyc b/utils/io/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fe2f003fe04e93022cb986c8caff71c37931a24
Binary files /dev/null and b/utils/io/__pycache__/__init__.cpython-310.pyc differ
diff --git a/utils/io/__pycache__/__init__.cpython-313.pyc b/utils/io/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06307faf7974df0d75ff15439458146f9daca008
Binary files /dev/null and b/utils/io/__pycache__/__init__.cpython-313.pyc differ
diff --git a/utils/io/__pycache__/__init__.cpython-37.pyc b/utils/io/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a30f6120b4a96d1fdf7d2affeedbcc6385031f3
Binary files /dev/null and b/utils/io/__pycache__/__init__.cpython-37.pyc differ
diff --git a/utils/io/__pycache__/__init__.cpython-39.pyc b/utils/io/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f00aefe73934dfde844e2616b61a3720014f41ea
Binary files /dev/null and b/utils/io/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/io/__pycache__/data.cpython-310.pyc b/utils/io/__pycache__/data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1741d3fd554a4ec3084e9f1e4a46fce5863dc3b1
Binary files /dev/null and b/utils/io/__pycache__/data.cpython-310.pyc differ
diff --git a/utils/io/__pycache__/data.cpython-313.pyc b/utils/io/__pycache__/data.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48748fb92c1c6039ce9522f8d761d5e3c031caca
Binary files /dev/null and b/utils/io/__pycache__/data.cpython-313.pyc differ
diff --git a/utils/io/__pycache__/data.cpython-37.pyc b/utils/io/__pycache__/data.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acc8ed5077d79a1646c9937a52e8ce60f1fe302b
Binary files /dev/null and b/utils/io/__pycache__/data.cpython-37.pyc differ
diff --git a/utils/io/__pycache__/data.cpython-39.pyc b/utils/io/__pycache__/data.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d6ec35cca0c9521ae45c74ad4683a0dcdd857e6
Binary files /dev/null and b/utils/io/__pycache__/data.cpython-39.pyc differ
diff --git a/utils/io/data.py b/utils/io/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..7845d12a8ceb0826dc01ac8c34b6460cc803cb3e
--- /dev/null
+++ b/utils/io/data.py
@@ -0,0 +1,238 @@
+import os
+import cv2
+import json
+import random
+import datetime
+import numpy as np
+import matplotlib.pyplot as plt
+
+
class DataGen:
    """Batch generator for paired image/label PNG data laid out as
    <path>/{train,test}/{images,labels}/ with matching filenames.

    Replaces Keras' native ImageDataGenerator for this dataset layout.
    """

    def __init__(self, path, split_ratio, x, y, color_space='rgb'):
        # x, y: expected height/width of every image and label; batches skip
        # files whose shape does not match.
        self.x = x
        self.y = y
        self.path = path
        self.color_space = color_space
        self.path_train_images = path + "train/images/"
        self.path_train_labels = path + "train/labels/"
        self.path_test_images = path + "test/images/"
        self.path_test_labels = path + "test/labels/"
        self.image_file_list = get_png_filename_list(self.path_train_images)
        self.label_file_list = get_png_filename_list(self.path_train_labels)
        # Shuffle images and labels with the same permutation so pairs stay aligned.
        self.image_file_list[:], self.label_file_list[:] = self.shuffle_image_label_lists_together()
        # The first split_ratio fraction becomes validation, the rest training.
        self.split_index = int(split_ratio * len(self.image_file_list))
        self.x_train_file_list = self.image_file_list[self.split_index:]
        self.y_train_file_list = self.label_file_list[self.split_index:]
        self.x_val_file_list = self.image_file_list[:self.split_index]
        self.y_val_file_list = self.label_file_list[:self.split_index]
        self.x_test_file_list = get_png_filename_list(self.path_test_images)
        self.y_test_file_list = get_png_filename_list(self.path_test_labels)

    def generate_data(self, batch_size, train=False, val=False, test=False):
        """Endlessly yield (image_batch, label_batch) normalized float32 arrays.

        Exactly one of train/val/test must be True; raises ValueError otherwise
        (the original printed a message and then crashed on an unbound name).
        """
        if train:
            image_file_list = self.x_train_file_list
            label_file_list = self.y_train_file_list
        elif val:
            image_file_list = self.x_val_file_list
            label_file_list = self.y_val_file_list
        elif test:
            image_file_list = self.x_test_file_list
            label_file_list = self.y_test_file_list
        else:
            raise ValueError('one of train or val or test need to be True')

        i = 0
        while True:
            image_batch = []
            label_batch = []
            for _ in range(batch_size):
                # BUGFIX: wrap the index on the *selected* list. The original
                # tested `len(self.x_train_file_list)`, so val/test generators
                # whose list is shorter than the training list stalled forever
                # once i walked past their own end.
                if i == len(image_file_list):
                    i = 0
                if i < len(image_file_list):
                    sample_image_filename = image_file_list[i]
                    sample_label_filename = label_file_list[i]
                    if train or val:
                        image = cv2.imread(self.path_train_images + sample_image_filename, 1)
                        label = cv2.imread(self.path_train_labels + sample_label_filename, 0)
                    else:  # test
                        image = cv2.imread(self.path_test_images + sample_image_filename, 1)
                        label = cv2.imread(self.path_test_labels + sample_label_filename, 0)
                    # Labels are read grayscale; add a channel axis -> (H, W, 1).
                    label = np.expand_dims(label, axis=2)
                    if image.shape[0] == self.x and image.shape[1] == self.y:
                        image_batch.append(image.astype("float32"))
                    else:
                        print('the input image shape is not {}x{}'.format(self.x, self.y))
                    if label.shape[0] == self.x and label.shape[1] == self.y:
                        label_batch.append(label.astype("float32"))
                    else:
                        print('the input label shape is not {}x{}'.format(self.x, self.y))
                    i += 1
            if image_batch and label_batch:
                image_batch = normalize(np.array(image_batch))
                label_batch = normalize(np.array(label_batch))
                yield (image_batch, label_batch)

    def get_num_data_points(self, train=False, val=False):
        """Return the training-split size when train=True (and val=False);
        otherwise the validation-split size, matching the original behavior."""
        if train and not val:
            return len(self.x_train_file_list)
        return len(self.x_val_file_list)

    def shuffle_image_label_lists_together(self):
        """Shuffle image and label filename lists with one shared permutation."""
        combined = list(zip(self.image_file_list, self.label_file_list))
        random.shuffle(combined)
        return zip(*combined)

    @staticmethod
    def change_color_space(image, label, color_space):
        """Convert image and label from BGR to the requested color space."""
        color_space = color_space.lower()
        if color_space in ('hsi', 'hsv'):
            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            label = cv2.cvtColor(label, cv2.COLOR_BGR2HSV)
        elif color_space == 'lab':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
            label = cv2.cvtColor(label, cv2.COLOR_BGR2LAB)
        return image, label
def normalize(arr):
    """Scale *arr* by 1 / (max - min); a constant array is divided by 255.

    Note: values are only divided by the range, not shifted by the minimum,
    so the result is not guaranteed to lie in [0, 1].
    """
    value_range = np.amax(arr) - np.amin(arr)
    if value_range == 0:
        value_range = 255
    return arr / np.absolute(value_range)
+
+
def get_png_filename_list(path):
    """Return the sorted .png filenames directly under *path*.

    Matches the original behavior exactly: only the top-level directory is
    scanned, a name qualifies when '.png' appears anywhere in its lowercased
    form and it is shorter than 500 characters, and a missing/unreadable
    *path* yields an empty list. The original re-walked the directory 500
    times (once per candidate name length); one pass is equivalent.
    """
    for _dirname, _subdirs, filenames in os.walk(path):
        # First os.walk entry == top-level directory; stop there like the
        # original (its `break` exited after the first entry every time).
        return sorted(f for f in filenames if ".png" in f.lower() and len(f) < 500)
    return []
+
+
def get_jpg_filename_list(path):
    """Return the sorted .jpg filenames directly under *path*.

    Same contract as get_png_filename_list: top-level directory only, '.jpg'
    anywhere in the lowercased name, names shorter than 500 characters, empty
    list for a missing path. Replaces the original 500-iteration re-walk with
    a single pass.
    """
    for _dirname, _subdirs, filenames in os.walk(path):
        return sorted(f for f in filenames if ".jpg" in f.lower() and len(f) < 500)
    return []
+
+
def load_jpg_images(path):
    """Load every .jpg directly under *path* as float32 BGR arrays.

    Returns (images, filenames) where images is a numpy array stacked over
    the file list (all images are assumed to share one shape).
    """
    filenames = get_jpg_filename_list(path)
    images = np.array(
        [cv2.imread(path + name, 1).astype("float32") for name in filenames]
    )
    return images, filenames
+
+
def load_png_images(path):
    """Load every .png directly under *path* as float32 BGR arrays.

    Returns (images, filenames) where images is a numpy array stacked over
    the file list (all images are assumed to share one shape).
    """
    filenames = get_png_filename_list(path)
    images = np.array(
        [cv2.imread(path + name, 1).astype("float32") for name in filenames]
    )
    return images, filenames
+
+
def load_data(path):
    """Load and normalize the train/test images and labels under *path*.

    Expects the layout <path>/{train,test}/{images,labels}/ and returns
    (x_train, y_train, x_test, y_test, test_label_filenames).
    """
    splits = []
    for subdir in ("train/images/", "train/labels/", "test/images/", "test/labels/"):
        data, names = load_png_images(path + subdir)
        splits.append((normalize(data), names))
    (x_train, _), (y_train, _), (x_test, _), (y_test, test_label_filenames_list) = splits
    return x_train, y_train, x_test, y_test, test_label_filenames_list
+
+
def load_test_images(path):
    """Load and normalize the test images under *path*.

    Returns (x_test, test_image_filenames).
    """
    images, filenames = load_png_images(path + "test/images/")
    return normalize(images), filenames
+
+
def save_results(np_array, color_space, outpath, test_label_filenames_list):
    """Write each prediction in *np_array* (values in [0, 1]) to
    outpath/<filename>, rescaled to 0-255.

    *color_space* is accepted for interface compatibility but currently unused
    (the color-space conversion code is disabled).
    """
    for index, filename in enumerate(test_label_filenames_list):
        prediction = np_array[index]
        cv2.imwrite(outpath + filename, prediction * 255.)
+
+
def save_rgb_results(np_array, outpath, test_label_filenames_list):
    """Write each RGB prediction in *np_array* (values in [0, 1]) to
    outpath/<filename>, rescaled to 0-255."""
    for index, filename in enumerate(test_label_filenames_list):
        cv2.imwrite(outpath + filename, np_array[index] * 255.)
+
+
def save_history(model, model_name, training_history, dataset, n_filters, epoch, learning_rate, loss,
                 color_space, path=None, temp_name=None):
    """Save model weights, the training history (JSON) and a loss/dice plot.

    Files are written to '<path><name>.{hdf5,json,png}', where <name> is
    *temp_name* when given, otherwise a timestamp. *dataset* and *loss* are
    accepted for interface compatibility but not used here.
    """
    save_weight_filename = temp_name if temp_name else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    model.save('{}{}.hdf5'.format(path, save_weight_filename))

    history_dict = training_history.history
    with open('{}{}.json'.format(path, save_weight_filename), 'w') as f:
        json.dump(history_dict, f, indent=2)

    # Plot straight from the in-memory history. The original re-opened and
    # re-parsed the JSON file it had just written — a redundant round trip
    # with an identical result (history values are plain lists of floats).
    for item in ('loss', 'val_loss', 'dice_coef', 'val_dice_coef'):
        # Missing series plot as an empty line, matching the original.
        plt.plot(list(history_dict.get(item, [])))
    plt.title('model:{} lr:{} epoch:{} #filtr:{} Colorspaces:{}'.format(model_name, learning_rate,
                                                                        epoch, n_filters, color_space))
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'test_loss', 'train_dice', 'test_dice'], loc='upper left')
    plt.savefig('{}{}.png'.format(path, save_weight_filename))
    plt.show()
    plt.clf()
+
+
+
diff --git a/utils/io/read.py b/utils/io/read.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f41724e6e8ed5c7e44dcdaeb5d37d081cd5fe8
--- /dev/null
+++ b/utils/io/read.py
@@ -0,0 +1,173 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/io/read.py
+# author : CM
+# Function to read dataset
+#
+# ------------------------------------------------------------ #
+
+import os
+import sys
+
+import nibabel as nib
+import numpy as np
+
+# read nii file and load it into a numpy 3d array
def niiToNp(filename):
    """Load a NIfTI file as a float16 array scaled to [0, 1] by its maximum."""
    volume = nib.load(filename).get_data().astype('float16')
    return volume / volume.max()
+
+# read a dataset and load it into a numpy 4d array
def readDataset(folder, size, size_x, size_y, size_z):
    """Load up to *size* NIfTI volumes from *folder* (sorted by name) into a
    (size, size_x, size_y, size_z) float16 array, each scaled to [0, 1]."""
    dataset = np.empty((size, size_x, size_y, size_z), dtype='float16')
    filenames = sorted(os.listdir(folder))
    for index, filename in enumerate(filenames[:size]):
        print(filename)
        dataset[index, :, :, :] = niiToNp(os.path.join(folder, filename))
    return dataset
+
+# return dataset affine
def getAffine_subdir(folder):
    """Return the affine of the first image in the first (sorted) subdirectory.

    NOTE(review): *folder* is concatenated with the subdirectory name without
    a separator, so callers must pass a trailing '/' — confirm against usage.
    The file list itself is not sorted, so "first file" is filesystem order.
    """
    subdirs = sorted(os.listdir(folder))
    first_dir = folder + subdirs[0]
    files = os.listdir(first_dir)
    return nib.load(os.path.join(first_dir, files[0])).affine
+
def getAffine(folder):
    """Return the affine matrix of the first (sorted) image in *folder*."""
    first_file = sorted(os.listdir(folder))[0]
    return nib.load(os.path.join(folder, first_file)).affine
+
+
+
+# reshape the dataset to match keras input shape (add channel dimension)
def reshapeDataset(d):
    """Append a trailing channel axis so a (n, x, y, z) array fits the Keras
    input shape (n, x, y, z, 1)."""
    return d.reshape(d.shape + (1,))
+
+# read a dataset and load it into a numpy 3d array as raw data (no normalisation)
def readRawDataset(folder, size, size_x, size_y, size_z, dtype):
    """Load *size* NIfTI volumes from *folder* (sorted by name) as raw values
    (no normalisation) into a (size, size_x, size_y, size_z) array of *dtype*.

    Exits with status 2 when the folder holds fewer than *size* files.
    """
    filenames = sorted(os.listdir(folder))

    if len(filenames) < size:
        sys.exit(2)

    # astype depends on your dataset type.
    dataset = np.empty((size, size_x, size_y, size_z)).astype(dtype)

    for count, filename in enumerate(filenames[:size]):
        full_path = os.path.join(folder, filename)
        dataset[count, :, :, :] = nib.load(full_path).get_data()
        print(count + 1, '/', size, full_path)

    return dataset
+
def readTrainValid(config):
    """Load the training and validation ground-truth/input datasets as uint16.

    Reads paths and sizes from *config* and returns
    (train_gd, train_in, valid_gd, valid_in).
    """
    def _load(path_key, size_key, label):
        # One raw-dataset load plus the shape/dtype report the original
        # repeated four times inline.
        data = readRawDataset(config[path_key],
                              config[size_key],
                              config["image_size_x"],
                              config["image_size_y"],
                              config["image_size_z"],
                              'uint16')
        print(label + " dataset shape", data.shape)
        print(label + " dataset dtype", data.dtype)
        return data

    print("Loading training dataset")
    train_gd_dataset = _load("dataset_train_gd_path", "dataset_train_size", "Training ground truth")
    train_in_dataset = _load("dataset_train_mra_path", "dataset_train_size", "Training input image")

    print("Loading validation dataset")
    valid_gd_dataset = _load("dataset_valid_gd_path", "dataset_valid_size", "Validation ground truth")
    valid_in_dataset = _load("dataset_valid_mra_path", "dataset_valid_size", "Validation input image")

    return train_gd_dataset, train_in_dataset, valid_gd_dataset, valid_in_dataset
+
+# read a dataset and load it into a numpy 3d without any preprocessing
def getDataset(folder, size, type=None):
    """Load *size* NIfTI volumes from *folder* (sorted) without preprocessing.

    The array dtype is *type* when given, otherwise the on-disk dtype of the
    first file. Exits with status 0x2001 when fewer than *size* files exist.
    (The parameter name `type` shadows the builtin but is kept for backward
    compatibility with keyword callers.)
    """
    files = sorted(os.listdir(folder))

    if len(files) < size:
        sys.exit(0x2001)

    image = nib.load(os.path.join(folder, files[0]))

    # Idiom fix: compare to None with `is`, not `==`.
    dtype = image.get_data_dtype() if type is None else type

    dataset = np.empty((size, image.shape[0], image.shape[1], image.shape[2])).astype(dtype)
    del image  # release the header/image proxy before the bulk loads

    for count, filename in enumerate(files[:size]):
        dataset[count, :, :, :] = nib.load(os.path.join(folder, filename)).get_data()

    return dataset
+
+# read a dataset and load it into a numpy 3d without any preprocessing with "start" index and number of files
def readDatasetPart(folder, start, size, type=None):
    """Load files[start:start+size] from *folder* (sorted) without preprocessing.

    The array dtype is *type* when given, otherwise the on-disk dtype of the
    first file. Exits with a message when the folder is too small.
    (The parameter name `type` shadows the builtin but is kept for backward
    compatibility with keyword callers.)
    """
    files = sorted(os.listdir(folder))

    if len(files) < start + size:
        sys.exit("readDatasetPart : len(files) < start + size")

    image = nib.load(os.path.join(folder, files[0]))

    # Idiom fix: compare to None with `is`, not `==`.
    dtype = image.get_data_dtype() if type is None else type

    dataset = np.empty((size, image.shape[0], image.shape[1], image.shape[2])).astype(dtype)
    del image  # release the header/image proxy before the bulk loads

    for count, file_index in enumerate(range(start, start + size)):
        dataset[count, :, :, :] = nib.load(os.path.join(folder, files[file_index])).get_data()

    return dataset
diff --git a/utils/io/write.py b/utils/io/write.py
new file mode 100644
index 0000000000000000000000000000000000000000..88cf8386a62618aba4c6b2dfdfb5fe7bb5f852ec
--- /dev/null
+++ b/utils/io/write.py
@@ -0,0 +1,23 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/io/write.py
+# author : CM
+# Function to write results
+#
+# ------------------------------------------------------------ #
+
+import nibabel as nib
+import numpy as np
+
+# write nii file from a numpy 3d array
def npToNii(data, filename):
    """Save a numpy 3-D array as NIfTI with the x and y axes negated
    (affine diag(-1, -1, 1, 1))."""
    affine = np.diag([-1.0, -1.0, 1.0, 1.0])
    nib.save(nib.Nifti1Image(data, affine), filename)
+
+# write nii file from a numpy 3d array with affine configuration
def npToNiiAffine(data, affine, filename):
    """Save a numpy 3-D array as NIfTI using the provided affine matrix."""
    nib.save(nib.Nifti1Image(data, affine), filename)
\ No newline at end of file
diff --git a/utils/learning/__pycache__/losses.cpython-37.pyc b/utils/learning/__pycache__/losses.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..383ef9c84b97d53194604da9da2701b7ed5bb941
Binary files /dev/null and b/utils/learning/__pycache__/losses.cpython-37.pyc differ
diff --git a/utils/learning/__pycache__/losses.cpython-39.pyc b/utils/learning/__pycache__/losses.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39ce00e24787467c070845a5c5d27a8188338831
Binary files /dev/null and b/utils/learning/__pycache__/losses.cpython-39.pyc differ
diff --git a/utils/learning/__pycache__/metrics.cpython-310.pyc b/utils/learning/__pycache__/metrics.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47e8e183fe484542b41ba863d6a4a91e6ea1361c
Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-310.pyc differ
diff --git a/utils/learning/__pycache__/metrics.cpython-313.pyc b/utils/learning/__pycache__/metrics.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbb302cc28e50af6d1ec38630dd3608290f0b30a
Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-313.pyc differ
diff --git a/utils/learning/__pycache__/metrics.cpython-37.pyc b/utils/learning/__pycache__/metrics.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5050e934bcd9313efbaa2e1f736307a43650e0f6
Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-37.pyc differ
diff --git a/utils/learning/__pycache__/metrics.cpython-39.pyc b/utils/learning/__pycache__/metrics.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a71265dbb00586a0cbdefa64882c085c0b81ae3
Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-39.pyc differ
diff --git a/utils/learning/callbacks.py b/utils/learning/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6726caf89120752c885b50d57f3c80a189efe1ab
--- /dev/null
+++ b/utils/learning/callbacks.py
@@ -0,0 +1,18 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/learning/callbacks.py
+# author : CM
+# Custom callbacks
+#
+# ------------------------------------------------------------ #
+
+import numpy as np
+from keras.callbacks import LearningRateScheduler
+
+# reduce learning rate on each epoch
def learningRateSchedule(initialLr=1e-4, decayFactor=0.99, stepSize=1):
    """Build a Keras callback applying exponential learning-rate decay:
    lr(epoch) = initialLr * decayFactor ** floor(epoch / stepSize)."""
    def schedule(epoch):
        exponent = np.floor(epoch / stepSize)
        lr = initialLr * (decayFactor ** exponent)
        print("Learning rate : ", lr)
        return lr

    return LearningRateScheduler(schedule)
\ No newline at end of file
diff --git a/utils/learning/losses.py b/utils/learning/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..d811986771010537ef591126fad399c2a4cfa29f
--- /dev/null
+++ b/utils/learning/losses.py
@@ -0,0 +1,136 @@
+# ------------------------------------------------------------ #
+#
+# file : losses.py
+# author : CM
+# Loss function
+#
+# ------------------------------------------------------------ #
+import keras.backend as K
+
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient: (2*|X∩Y| + smooth) / (|X| + |Y| + smooth)."""
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    return (2. * overlap + smooth) / (K.sum(truth) + K.sum(prediction) + smooth)
+
def dice_coef_loss(y_true, y_pred):
    """Negated Dice coefficient, suitable for minimisation."""
    return -dice_coef(y_true, y_pred)
+
+# Jaccard distance
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
            = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|))

    The Jaccard distance loss is useful for unbalanced datasets. It has been
    shifted so it converges on 0 and is smoothed to avoid exploding or
    disappearing gradients.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index

    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    total = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jaccard_index = (overlap + smooth) / (total - overlap + smooth)
    return (1 - jaccard_index) * smooth
+
+
def dice_coef_(y_true, y_pred, smooth=1):
    """
    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         = 2*sum(|A*B|)/(sum(A^2)+sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    denominator = K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1)
    return (2. * overlap + smooth) / (denominator + smooth)
+
def dice_coef_loss_(y_true, y_pred):
    """One minus the squared-denominator Dice coefficient (dice_coef_)."""
    return 1 - dice_coef_(y_true, y_pred)
+
+'''
+def dice_loss(y_true, y_pred, smooth=1e-6):
+ """ Loss function base on dice coefficient.
+
+ Parameters
+ ----------
+ y_true : keras tensor
+ tensor containing target mask.
+ y_pred : keras tensor
+ tensor containing predicted mask.
+ smooth : float
+ small real value used for avoiding division by zero error.
+
+ Returns
+ -------
+ keras tensor
+ tensor containing dice loss.
+ """
+ y_true_f = K.flatten(y_true)
+ y_pred_f = K.flatten(y_pred)
+ intersection = K.sum(y_true_f * y_pred_f)
+ answer = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+ return -answer
+'''
+# the deeplab version of dice_loss
def dice_loss(y_true, y_pred):
    """1 - soft Dice score with smooth=1 (the DeepLab variant)."""
    smooth = 1.
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    dice = (2. * overlap + smooth) / (K.sum(truth) + K.sum(prediction) + smooth)
    return 1. - dice
+
def tversky_loss(y_true, y_pred, alpha=0.3, beta=0.7, smooth=1e-10):
    """Tversky loss: the negated Tversky index, to be minimised.

    Parameters
    ----------
    y_true : keras tensor
        tensor containing target mask.
    y_pred : keras tensor
        tensor containing predicted mask.
    alpha : float
        weight of false positives (the '0' class).
    beta : float
        weight of false negatives (the '1' class).
    smooth : float
        small real value used for avoiding division by zero error.

    Returns
    -------
    keras tensor
        tensor containing tversky loss.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    true_positives = K.sum(truth * prediction)
    # alpha weighs false positives, beta weighs false negatives.
    weighted_errors = alpha * K.sum(prediction * (1 - truth)) + beta * K.sum((1 - prediction) * truth)
    return -(true_positives + smooth) / ((true_positives + smooth) + weighted_errors)
+
def jaccard_coef_logloss(y_true, y_pred, smooth=1e-10):
    """Negative logarithm of the (smoothed) Jaccard coefficient.

    Parameters
    ----------
    y_true : keras tensor
        tensor containing target mask.
    y_pred : keras tensor
        tensor containing predicted mask.
    smooth : float
        small real value used for avoiding division by zero error.

    Returns
    -------
    keras tensor
        tensor containing negative logarithm of jaccard coefficient.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    true_positives = K.sum(truth * prediction)
    false_positives = K.sum(prediction) - true_positives
    false_negatives = K.sum(truth) - true_positives
    jaccard = (true_positives + smooth) / (smooth + true_positives + false_negatives + false_positives)
    return -K.log(jaccard + smooth)
diff --git a/utils/learning/metrics.py b/utils/learning/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e5789bf29f501a13550d7da136669239e7cc317
--- /dev/null
+++ b/utils/learning/metrics.py
@@ -0,0 +1,81 @@
+# ------------------------------------------------------------ #
+#
+# file : metrics.py
+# author : CM
+# Metrics for evaluation
+#
+# ------------------------------------------------------------ #
+from keras import backend as K
+
+
+# dice coefficient
+'''
+def dice_coef(y_true, y_pred, smooth=1.):
+ y_true_f = K.flatten(y_true)
+ y_pred_f = K.flatten(y_pred)
+ intersection = K.sum(y_true_f * y_pred_f)
+ return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+'''
+# the deeplab version of dice coefficient
def dice_coef(y_true, y_pred):
    """Hard Dice score: predictions are thresholded at 0.5 before comparison."""
    smooth = 0.00001
    truth = K.flatten(y_true)
    prediction = K.cast(y_pred, 'float32')
    binarized = K.cast(K.greater(K.flatten(prediction), 0.5), 'float32')
    overlap = K.sum(truth * binarized)
    return (2. * overlap + smooth) / ((K.sum(truth) + K.sum(binarized)) + smooth)
+
+
+# Recall (true positive rate)
def recall(truth, prediction):
    """TP / (P + eps): fraction of actual positives that were predicted."""
    true_positives = K.sum(K.round(K.clip(truth * prediction, 0, 1)))
    actual_positives = K.sum(K.round(K.clip(truth, 0, 1)))
    return true_positives / (actual_positives + K.epsilon())
+
+
+# Specificity (true negative rate)
def specificity(truth, prediction):
    """TN / (N + eps): fraction of actual negatives predicted negative."""
    true_negatives = K.sum(K.round(K.clip((1 - truth) * (1 - prediction), 0, 1)))
    actual_negatives = K.sum(K.round(K.clip(1 - truth, 0, 1)))
    return true_negatives / (actual_negatives + K.epsilon())
+
+
+# Precision (positive prediction value)
def precision(truth, prediction):
    """TP / (TP + FP + eps): fraction of positive predictions that are correct."""
    true_positives = K.sum(K.round(K.clip(truth * prediction, 0, 1)))
    false_positives = K.sum(K.round(K.clip((1 - truth) * prediction, 0, 1)))
    return true_positives / (true_positives + false_positives + K.epsilon())
+
+
def f1(y_true, y_pred):
    """Batch-wise F1 score: harmonic mean of batch-wise precision and recall."""

    def _recall(truth, prediction):
        # Fraction of relevant items that were selected (batch-wise average).
        true_positives = K.sum(K.round(K.clip(truth * prediction, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(truth, 0, 1)))
        return true_positives / (possible_positives + K.epsilon())

    def _precision(truth, prediction):
        # Fraction of selected items that are relevant (batch-wise average).
        true_positives = K.sum(K.round(K.clip(truth * prediction, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(prediction, 0, 1)))
        return true_positives / (predicted_positives + K.epsilon())

    p = _precision(y_true, y_pred)
    r = _recall(y_true, y_pred)
    return 2 * ((p * r) / (p + r + K.epsilon()))
diff --git a/utils/learning/patch/extraction.py b/utils/learning/patch/extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..d260a558a0a85f12b564100a12645295c24902b7
--- /dev/null
+++ b/utils/learning/patch/extraction.py
@@ -0,0 +1,279 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/learning/patch/extraction.py
+# author : CM
+# Function to extract patch from input dataset
+#
+# ------------------------------------------------------------ #
+import sys
+from random import randint
+import numpy as np
+from keras.utils import to_categorical
+
+# ----- Patch Extraction -----
+# -- Single Patch
+# extract a patch from an image
def extractPatch(d, patch_size_x, patch_size_y, patch_size_z, x, y, z):
    """Return the sub-volume of *d* starting at (x, y, z) with the given size.

    Pure slicing: the result is a view, and it is silently smaller than
    requested when the window reaches past the edge of *d*.
    """
    return d[x:x + patch_size_x, y:y + patch_size_y, z:z + patch_size_z]
+
+# extract a patch from an image. The patch can be out of the image (0 padding)
def extractPatchOut(d, patch_size_x, patch_size_y, patch_size_z, x_, y_, z_):
    """Extract a patch that may extend outside *d*; out-of-bounds voxels are 0.

    Same result as the original per-voxel triple loop, but computed as one
    clipped slice copy: O(patch volume) numpy work instead of an O(n^3)
    Python loop.
    """
    patch = np.zeros((patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    # Clip the requested window [x_, x_+size) to the bounds of d.
    x0, y0, z0 = max(0, x_), max(0, y_), max(0, z_)
    x1 = min(d.shape[0], x_ + patch_size_x)
    y1 = min(d.shape[1], y_ + patch_size_y)
    z1 = min(d.shape[2], z_ + patch_size_z)
    if x0 < x1 and y0 < y1 and z0 < z1:
        # Copy the overlapping region into the corresponding patch offsets.
        patch[x0 - x_:x1 - x_, y0 - y_:y1 - y_, z0 - z_:z1 - z_] = d[x0:x1, y0:y1, z0:z1]
    return patch
+
+# create random patch for an image
def generateRandomPatch(d, patch_size_x, patch_size_y, patch_size_z):
    """Extract one patch at a uniformly random, fully-inside position of *d*."""
    corner_x = randint(0, d.shape[0] - patch_size_x)
    corner_y = randint(0, d.shape[1] - patch_size_y)
    corner_z = randint(0, d.shape[2] - patch_size_z)
    return extractPatch(d, patch_size_x, patch_size_y, patch_size_z,
                        corner_x, corner_y, corner_z)
+
+# -- Multiple Patchs
+# create random patchs for an image
def generateRandomPatchs(d, patch_size_x, patch_size_y, patch_size_z, patch_number):
    """Stack *patch_number* random patches of *d* into a float16 array."""
    patches = np.empty((patch_number, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    for index in range(patch_number):
        patches[index] = generateRandomPatch(d, patch_size_x, patch_size_y, patch_size_z)
    return patches
+
+# divide the full image into patchs
+# todo : missing data if shape%patch_size is not 0
def generateFullPatchs(d, patch_size_x, patch_size_y, patch_size_z):
    """Tile *d* into consecutive patches starting at multiples of the patch
    size (assumes each dimension is a multiple of the patch size; see the
    todo note above about remainders)."""
    patch_nb = int((d.shape[0] / patch_size_x) * (d.shape[1] / patch_size_y) * (d.shape[2] / patch_size_z))
    patches = np.empty((patch_nb, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    index = 0
    for corner_x in range(0, d.shape[0], patch_size_x):
        for corner_y in range(0, d.shape[1], patch_size_y):
            for corner_z in range(0, d.shape[2], patch_size_z):
                patches[index] = extractPatch(d, patch_size_x, patch_size_y, patch_size_z,
                                              corner_x, corner_y, corner_z)
                index += 1
    return patches
+
def generateFullPatchsPlus(d, patch_size_x, patch_size_y, patch_size_z, dx, dy, dz):
    """Tile *d* with stride (dx, dy, dz); patches overlap when the stride is
    smaller than the patch size.

    NOTE(review): slots past the last filled index keep np.empty garbage when
    the stride count over-estimates the loop iterations — confirm callers
    only read the filled prefix.
    """
    patch_nb = int((d.shape[0] / dx) * (d.shape[1] / dy) * (d.shape[2] / dz))
    patches = np.empty((patch_nb, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    index = 0
    for corner_x in range(0, d.shape[0] - dx, dx):
        for corner_y in range(0, d.shape[1] - dy, dy):
            for corner_z in range(0, d.shape[2] - dz, dz):
                patches[index] = extractPatch(d, patch_size_x, patch_size_y, patch_size_z,
                                              corner_x, corner_y, corner_z)
                index += 1
    return patches
+
def noNeg(x):
    """Clamp *x* to be non-negative."""
    return x if x > 0 else 0
+
def generateFullPatchsCentered(d, patch_size_x, patch_size_y, patch_size_z):
    """Tile *d* with half-patch strides, starting a quarter patch before the
    origin, so patches overlap; out-of-bounds voxels are zero-padded by
    extractPatchOut."""
    # Capacity estimate: twice the per-axis tiling count in each dimension.
    patch_nb = int(2*(d.shape[0]/patch_size_x)*2*(d.shape[1]/patch_size_y)*2*(d.shape[2]/patch_size_z))
    # zeros (not empty) so any unused trailing slots stay all-zero.
    data = np.zeros((patch_nb, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    i = 0
    # Stride: half a patch per axis.
    psx = int(patch_size_x/2)
    psy = int(patch_size_y/2)
    psz = int(patch_size_z/2)
    # Start a quarter patch outside the volume; the end bound keeps the last
    # start within three quarters of a patch of the far edge (inclusive).
    for x in range(-int(patch_size_x/4),d.shape[0]-3*int(patch_size_x/4)+1, psx):
        for y in range(-int(patch_size_y/4), d.shape[1]-3*int(patch_size_y/4)+1, psy):
            for z in range(-int(patch_size_z/4),d.shape[2]-3*int(patch_size_z/4)+1, psz):
                # patch = np.zeros((psx,psy,psz), dtype='float16')
                # patch = d[noNeg(x):noNeg(x)+patch_size_x,noNeg(y):noNeg(y)+patch_size_y,noNeg(z):noNeg(z)+patch_size_z]
                patch = extractPatchOut(d,patch_size_x, patch_size_y, patch_size_z, x, y, z)
                data[i] = patch
                i = i+1
    return data
+
+# ----- Patch Extraction Generator -----
+# Generator of random patchs of size 32*32*32
def generatorRandomPatchs(features, labels, batch_size, patch_size_x, patch_size_y, patch_size_z):
    """Endless Keras-style generator of aligned random (features, labels)
    patch batches drawn from random volumes at random positions."""
    batch_features = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z, features.shape[4]), dtype='float16')
    batch_labels = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z, labels.shape[4]), dtype='float16')

    while True:
        for slot in range(batch_size):
            volume = randint(0, features.shape[0] - 1)
            corner_x = randint(0, features.shape[1] - patch_size_x)
            corner_y = randint(0, features.shape[2] - patch_size_y)
            corner_z = randint(0, features.shape[3] - patch_size_z)

            # Same volume and corner for features and labels keeps them aligned.
            batch_features[slot] = extractPatch(features[volume], patch_size_x, patch_size_y, patch_size_z,
                                                corner_x, corner_y, corner_z)
            batch_labels[slot] = extractPatch(labels[volume], patch_size_x, patch_size_y, patch_size_z,
                                              corner_x, corner_y, corner_z)

        yield batch_features, batch_labels
+
+# Generator of random patchs of size 32*32*32 and 16*16*16
def generatorRandomPatchs3216(features, labels, batch_size):
    """Endless generator of random 32^3 feature patches paired with 16^3
    label patches.

    NOTE(review): the label patch is taken at (x+16, y+16, z+16), which is the
    far corner of the 32^3 feature window, not its center (a centered patch
    would use +8, as generatorRandomPatchsLabelCentered does) — confirm this
    offset is intentional.
    """
    batch_features = np.zeros((batch_size, 32, 32, 32, features.shape[4]), dtype='float16')
    batch_labels = np.zeros((batch_size, 16, 16, 16, labels.shape[4]), dtype='float16')

    while True:
        for i in range(batch_size):
            # Random volume and random fully-inside 32^3 corner.
            id = randint(0,features.shape[0]-1)
            x = randint(0, features.shape[1]-32)
            y = randint(0, features.shape[2]-32)
            z = randint(0, features.shape[3]-32)

            batch_features[i] = extractPatch(features[id], 32, 32, 32, x, y, z)
            batch_labels[i] = extractPatch(labels[id], 16, 16, 16, x+16, y+16, z+16)

        yield batch_features, batch_labels
+
def generatorRandomPatchsLabelCentered(features, labels, batch_size, patch_size_x, patch_size_y, patch_size_z):
    """Endless generator of random feature patches paired with half-size label
    patches centered inside them (label offset = quarter patch per axis)."""
    half_x = int(patch_size_x/2)
    half_y = int(patch_size_y/2)
    half_z = int(patch_size_z/2)

    batch_features = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z, features.shape[4]), dtype=features.dtype)
    batch_labels = np.zeros((batch_size, half_x, half_y, half_z,
                             labels.shape[4]), dtype=labels.dtype)

    while True:
        for slot in range(batch_size):
            volume = randint(0, features.shape[0] - 1)
            corner_x = randint(0, features.shape[1] - patch_size_x)
            corner_y = randint(0, features.shape[2] - patch_size_y)
            corner_z = randint(0, features.shape[3] - patch_size_z)

            batch_features[slot] = extractPatch(features[volume], patch_size_x, patch_size_y, patch_size_z,
                                                corner_x, corner_y, corner_z)
            # Label window sits a quarter patch inside the feature window -> centered.
            batch_labels[slot] = extractPatch(labels[volume], half_x, half_y, half_z,
                                              int(corner_x + patch_size_x / 4),
                                              int(corner_y + patch_size_y / 4),
                                              int(corner_z + patch_size_z / 4))

        yield batch_features, batch_labels
+
def generatorRandomPatchsDolz(features, labels, batch_size, patch_size_x, patch_size_y, patch_size_z):
    """Endless generator for Dolz-style networks: random feature patches plus
    centered half-size label patches flattened to one-hot (voxels, 2) targets."""
    half_x = int(patch_size_x / 2)
    half_y = int(patch_size_y / 2)
    half_z = int(patch_size_z / 2)

    batch_features = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z, features.shape[4]), dtype=features.dtype)
    batch_labels = np.zeros((batch_size, half_x * half_y * half_z, 2), dtype=labels.dtype)

    while True:
        for slot in range(batch_size):
            volume = randint(0, features.shape[0] - 1)
            corner_x = randint(0, features.shape[1] - patch_size_x)
            corner_y = randint(0, features.shape[2] - patch_size_y)
            corner_z = randint(0, features.shape[3] - patch_size_z)

            batch_features[slot] = extractPatch(features[volume], patch_size_x, patch_size_y, patch_size_z,
                                                corner_x, corner_y, corner_z)
            centered = extractPatch(labels[volume], half_x, half_y, half_z,
                                    int(corner_x + patch_size_x / 4),
                                    int(corner_y + patch_size_y / 4),
                                    int(corner_z + patch_size_z / 4))
            # One-hot encode the flattened centered label patch (background/foreground).
            batch_labels[slot] = to_categorical(centered.flatten(), 2)

        yield batch_features, batch_labels
+
+from scipy.ndimage import zoom, rotate
# Generate random patches with a random linear transformation:
# translation (random position) and rotation; scaling is still TODO.
# Preconditions : patch_features_* % patch_labels_* == 0
#                 patch_features_* >= patch_labels_*
# todo : scale
def generatorRandomPatchsLinear(features, labels, patch_features_x, patch_features_y, patch_features_z,
                                patch_labels_x, patch_labels_y, patch_labels_z):
    """Yield single (feature, label) patch pairs augmented with a random 3D rotation.

    features / labels are 5D arrays (samples, x, y, z, channels).  A randomly
    selected volume is rotated by three random angles (one per axis pair),
    then a feature patch is cut at a random position; the label patch is the
    centred crop of the corresponding rotated label volume.

    Preconditions (process exits when violated):
      * every feature patch dimension is a multiple of the label patch dimension
      * the feature patch is at least as large as the label patch in every axis

    Bug fix: the rotations were applied to features[0] / labels[0], so the
    randomly drawn volume index was ignored and only the first volume was
    ever used; the drawn index is now honoured.
    """
    patch_features = np.zeros((1, patch_features_x, patch_features_y, patch_features_z, features.shape[4]),
                              dtype=features.dtype)
    patch_labels = np.zeros((1, patch_labels_x, patch_labels_y, patch_labels_z, labels.shape[4]),
                            dtype=labels.dtype)

    if (patch_features_x % patch_labels_x != 0 or patch_features_y % patch_labels_y != 0
            or patch_features_z % patch_labels_z != 0):
        sys.exit(0x00F0)

    if (patch_features_x < patch_labels_x or patch_features_y < patch_labels_y
            or patch_features_z < patch_labels_z):
        sys.exit(0x00F1)

    # middle of the feature patch
    mx = int(patch_features_x / 2)
    my = int(patch_features_y / 2)
    mz = int(patch_features_z / 2)
    # half of the label patch size
    sx = int(patch_labels_x / 2)
    sy = int(patch_labels_y / 2)
    sz = int(patch_labels_z / 2)

    while True:
        sample = randint(0, features.shape[0] - 1)
        x = randint(0, features.shape[1] - patch_features_x)
        y = randint(0, features.shape[2] - patch_features_y)
        z = randint(0, features.shape[3] - patch_features_z)

        # todo : check time consumption; consider rotating the complete image once
        r0 = randint(0, 360) - 180
        r1 = randint(0, 360) - 180
        r2 = randint(0, 360) - 180
        # BUG FIX: rotate the randomly selected volume, not always volume 0
        rot_features = rotate(input=features[sample], angle=r0, axes=(0, 1), reshape=False)
        rot_features = rotate(input=rot_features, angle=r1, axes=(1, 2), reshape=False)
        rot_features = rotate(input=rot_features, angle=r2, axes=(2, 0), reshape=False)
        rot_labels = rotate(input=labels[sample], angle=r0, axes=(0, 1), reshape=False)
        rot_labels = rotate(input=rot_labels, angle=r1, axes=(1, 2), reshape=False)
        rot_labels = rotate(input=rot_labels, angle=r2, axes=(2, 0), reshape=False)

        patch_features[0] = extractPatch(rot_features, patch_features_x, patch_features_y, patch_features_z, x, y, z)

        # label patch is the centred crop of the feature patch position
        patch_labels[0] = extractPatch(rot_labels, patch_labels_x, patch_labels_y, patch_labels_z,
                                       x + mx - sx, y + my - sy, z + mz - sz)

        yield patch_features, patch_labels
+
def randomPatchsAugmented(in_dataset, gd_dataset, patch_number, patch_in_size, patch_gd_size):
    """Cut *patch_number* random patch pairs and augment each with random
    90-degree rotations.

    in_dataset / gd_dataset are 4D stacks (samples, x, y, z).  Each ground
    truth patch is the centred crop of its input patch; both receive the same
    multiples of 90 degrees around each axis pair.  Returns both patch stacks
    with a trailing singleton channel axis.  Exits on inconsistent patch sizes.
    """
    patchs_in = np.zeros((patch_number, patch_in_size[0], patch_in_size[1], patch_in_size[2]),
                         dtype=in_dataset.dtype)
    patchs_gd = np.zeros((patch_number, patch_gd_size[0], patch_gd_size[1], patch_gd_size[2]),
                         dtype=gd_dataset.dtype)

    if (patch_in_size[0] % patch_gd_size[0] != 0 or patch_in_size[1] % patch_gd_size[1] != 0
            or patch_in_size[2] % patch_gd_size[2] != 0):
        sys.exit("ERROR : randomPatchsAugmented patchs size error 1")

    if (patch_in_size[0] < patch_gd_size[0] or patch_in_size[1] < patch_gd_size[1]
            or patch_in_size[2] < patch_gd_size[2]):
        sys.exit("ERROR : randomPatchsAugmented patchs size error 2")

    # offset that places the gd patch at the centre of the in patch
    off_x = int(patch_in_size[0] / 2) - int(patch_gd_size[0] / 2)
    off_y = int(patch_in_size[1] / 2) - int(patch_gd_size[1] / 2)
    off_z = int(patch_in_size[2] / 2) - int(patch_gd_size[2] / 2)

    for count in range(patch_number):
        sample = randint(0, in_dataset.shape[0] - 1)
        x = randint(0, in_dataset.shape[1] - patch_in_size[0])
        y = randint(0, in_dataset.shape[2] - patch_in_size[1])
        z = randint(0, in_dataset.shape[3] - patch_in_size[2])

        r0 = randint(0, 3)
        r1 = randint(0, 3)
        r2 = randint(0, 3)

        patchs_in[count] = extractPatch(in_dataset[sample],
                                        patch_in_size[0], patch_in_size[1], patch_in_size[2],
                                        x, y, z)
        patchs_gd[count] = extractPatch(gd_dataset[sample],
                                        patch_gd_size[0], patch_gd_size[1], patch_gd_size[2],
                                        x + off_x, y + off_y, z + off_z)

        # same quarter-turn sequence applied to both patches
        for quarter_turns, axes in ((r0, (0, 1)), (r1, (1, 2)), (r2, (2, 0))):
            patchs_in[count] = np.rot90(patchs_in[count], quarter_turns, axes)
            patchs_gd[count] = np.rot90(patchs_gd[count], quarter_turns, axes)

    return (patchs_in.reshape(patchs_in.shape + (1,)),
            patchs_gd.reshape(patchs_gd.shape + (1,)))
+
def generatorRandomPatchsAugmented(in_dataset, gd_dataset, patch_number, patch_in_size, patch_gd_size):
    """Endlessly yield fresh augmented patch batches (see randomPatchsAugmented)."""
    while True:
        batch = randomPatchsAugmented(in_dataset, gd_dataset, patch_number, patch_in_size, patch_gd_size)
        yield batch
diff --git a/utils/learning/patch/reconstruction.py b/utils/learning/patch/reconstruction.py
new file mode 100644
index 0000000000000000000000000000000000000000..d92cb504068403c27e079018b0b680d0f76dd252
--- /dev/null
+++ b/utils/learning/patch/reconstruction.py
@@ -0,0 +1,56 @@
+# ------------------------------------------------------------ #
+#
+# file : utils/learning/patch/reconstruction.py
+# author : CM
+# Function to reconstruct image from patch
+#
+# ------------------------------------------------------------
+
+import numpy as np
+
# ----- Image Reconstruction -----
# Rebuild the full image from its patches
def fullPatchsToImage(image, patchs):
    """Stitch non-overlapping patches back into *image* in x-major scan order.

    patchs has shape (n, px, py, pz, 1); image is filled in place on a
    (px, py, pz) grid and also returned.
    """
    px, py, pz = patchs.shape[1], patchs.shape[2], patchs.shape[3]
    index = 0
    for x0 in range(0, image.shape[0], px):
        for y0 in range(0, image.shape[1], py):
            for z0 in range(0, image.shape[2], pz):
                image[x0:x0 + px, y0:y0 + py, z0:z0 + pz] = patchs[index, :, :, :, 0]
                index += 1
    return image
+
def fullPatchsPlusToImage(image, patchs, dx, dy, dz):
    """Reconstruct an image from (possibly overlapping) patches laid out on a
    (dx, dy, dz) stride grid, averaging every voxel over the patches that
    cover it.

    patchs has shape (n, px, py, pz, 1), in x-major grid order.

    Bug fix: the original overwrote each region with the last patch written
    there and then divided by the coverage count, corrupting every voxel
    covered by more than one patch; patches are now accumulated before the
    division.  Voxels covered by no patch keep their original value
    (previously they were divided by zero, yielding inf/nan).
    """
    acc = np.zeros(image.shape, dtype=np.float64)
    div = np.zeros(image.shape, dtype=np.float64)
    one = np.ones((patchs.shape[1], patchs.shape[2], patchs.shape[3]))

    i = 0
    for x in range(0, image.shape[0] - dx, dx):
        for y in range(0, image.shape[1] - dy, dy):
            for z in range(0, image.shape[2] - dz, dz):
                div[x:x + patchs.shape[1], y:y + patchs.shape[2], z:z + patchs.shape[3]] += one
                acc[x:x + patchs.shape[1], y:y + patchs.shape[2], z:z + patchs.shape[3]] += patchs[i, :, :, :, 0]
                i = i + 1

    # average where covered, keep the original image value elsewhere
    result = np.asarray(image, dtype=np.float64).copy()
    np.divide(acc, div, out=result, where=div > 0)
    return result
+
def dolzReconstruction(image, patchs):
    """Rebuild a label volume from Dolz-style patch predictions.

    patchs has shape (n_patches, s**3, n_classes): for each patch a flat list
    of per-voxel class scores.  Each patch is argmax-decoded and written back
    on a non-overlapping s-stride grid in x-major scan order.

    Improvements: the cubic patch side s is recovered from patchs.shape[1]
    instead of being hard-coded to 16 (resolves the inline todo; 4096 -> 16
    as before), non-cubic inputs are rejected explicitly, and the debug
    prints were removed.
    """
    output = np.copy(image)

    # recover the cubic patch side from the flattened voxel count
    side = int(round(patchs.shape[1] ** (1.0 / 3.0)))
    if side ** 3 != patchs.shape[1]:
        raise ValueError("dolzReconstruction: patchs.shape[1] is not a perfect cube")

    count = 0
    for x in range(0, image.shape[0], side):
        for y in range(0, image.shape[1], side):
            for z in range(0, image.shape[2], side):
                patch = np.argmax(patchs[count], axis=1).reshape(side, side, side)
                output[x:x + side, y:y + side, z:z + side] = patch
                count += 1

    return output
\ No newline at end of file
diff --git a/utils/padding.py b/utils/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..567a94de5478f58641d138609bd656bcd5fc7f3b
--- /dev/null
+++ b/utils/padding.py
@@ -0,0 +1,50 @@
+import numpy as np
+from PIL import Image
+import os
+
+
def paddingjpg(path, size=(560, 560)):
    """Pad every .jpg in *path* onto a black RGB canvas of *size* (image pasted
    at the top-left corner) and save it under path/padded/ with the same name.

    Improvements over the original: the redundant filename-length scan and the
    unused locals (xmax, ymax, temp_list) are gone, the extension test is a
    real suffix check, paths are joined portably, the output directory is
    created when missing, and the canvas size is a parameter (the default
    keeps the previous hard-coded 560x560).
    """
    file_list = sorted(
        filename for filename in os.listdir(path)
        if filename.lower().endswith(".jpg"))
    print(file_list)

    out_dir = os.path.join(path, 'padded')
    os.makedirs(out_dir, exist_ok=True)

    for filename in file_list:
        image = Image.open(os.path.join(path, filename))
        padded_image = Image.new("RGB", size)
        padded_image.paste(image, (0, 0))
        padded_image.save(os.path.join(out_dir, filename))
+
def paddingpng(path, size=(560, 560)):
    """Pad every .png in *path* onto a black grayscale ("L") canvas of *size*
    (image pasted at the top-left corner) and save it under path/padded/.

    Improvements over the original: the redundant filename-length scan and the
    unused locals (xmax, ymax, temp_list) are gone, the extension test is a
    real suffix check, paths are joined portably, the output directory is
    created when missing, and the canvas size is a parameter (the default
    keeps the previous hard-coded 560x560).
    """
    file_list = sorted(
        filename for filename in os.listdir(path)
        if filename.lower().endswith(".png"))
    print(file_list)

    out_dir = os.path.join(path, 'padded')
    os.makedirs(out_dir, exist_ok=True)

    for filename in file_list:
        image = Image.open(os.path.join(path, filename))
        padded_image = Image.new("L", size)
        padded_image.paste(image, (0, 0))
        padded_image.save(os.path.join(out_dir, filename))
+
+
+paddingjpg('../data/train/images/')
+paddingpng('../data/train/labels/')
+paddingjpg('../data/test/images/')
+paddingpng('../data/test/labels/')
\ No newline at end of file
diff --git a/utils/postprocessing/hole_filling.py b/utils/postprocessing/hole_filling.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea694dda9f74fa63f8eb85ab144262ee3fe3a46d
--- /dev/null
+++ b/utils/postprocessing/hole_filling.py
@@ -0,0 +1,32 @@
+import cv2
+import numpy as np
+from scipy.ndimage.measurements import label
+
+
def fill_holes(img, threshold, rate):
    """Drop small background components ("holes") of a thresholded image.

    img is a 3D array (H, W, C).  Voxels with img <= threshold are treated as
    background and connected-component labelled (26-connectivity).  Every
    component whose share of the background — measured on the first channel,
    as in the original — is below *rate* is discarded.  Returns an int array
    that is 0 on the surviving background and 255 everywhere else.

    Fixes: ``np.int`` (deprecated in NumPy 1.20, removed in 1.24 — this
    function crashed on current NumPy) replaced with ``int``; the
    O(voxels * components) Python counting loops replaced with np.bincount.
    """
    binary_img = np.where(img > threshold, 0, 1)  # reversed image
    structure = np.ones((3, 3, 3), dtype=int)
    labeled, ncomponents = label(binary_img, structure)

    # per-component voxel count on the first channel (mirrors the original)
    first = labeled[:, :, 0]
    counts = np.bincount(first.ravel(), minlength=ncomponents + 1)[1:]
    total = counts.sum()

    if total != 0:
        for i in range(ncomponents):
            if counts[i] / total < rate:
                # zero the whole channel axis wherever the first channel matched
                labeled[first == i + 1] = 0

    labeled = np.where(labeled < 1, 1, 0)
    labeled *= 255
    return labeled
\ No newline at end of file
diff --git a/utils/postprocessing/remove_small_noise.py b/utils/postprocessing/remove_small_noise.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d32a8b6dc61960b92937a1d014f099082acc4b0
--- /dev/null
+++ b/utils/postprocessing/remove_small_noise.py
@@ -0,0 +1,31 @@
+import cv2
+import numpy as np
+from scipy.ndimage.measurements import label
+
+
def remove_small_areas(img, threshold, rate):
    """Remove small foreground components from a binary image.

    img is a 3D array (H, W, C) whose nonzero voxels are foreground; they are
    connected-component labelled (26-connectivity) and every component whose
    share of the foreground — measured on the first channel, as in the
    original — is below *rate* is discarded.  Returns an int array that is
    255 on the surviving foreground and 0 elsewhere.

    NOTE(review): *threshold* was unused in the original as well; kept for
    interface compatibility.

    Fixes: ``np.int`` (deprecated in NumPy 1.20, removed in 1.24 — this
    function crashed on current NumPy) replaced with ``int``; the
    O(voxels * components) Python counting loops replaced with np.bincount.
    """
    structure = np.ones((3, 3, 3), dtype=int)
    labeled, ncomponents = label(img, structure)

    # per-component voxel count on the first channel (mirrors the original)
    first = labeled[:, :, 0]
    counts = np.bincount(first.ravel(), minlength=ncomponents + 1)[1:]
    total = counts.sum()

    if total != 0:
        for i in range(ncomponents):
            if counts[i] / total < rate:
                # zero the whole channel axis wherever the first channel matched
                labeled[first == i + 1] = 0

    labeled = np.where(labeled < 1, 0, 1)
    labeled *= 255
    return labeled
\ No newline at end of file
diff --git a/utils/postprocessing/threshold.py b/utils/postprocessing/threshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c8adb7d1c8b8370ba73b639424919d77492537b
--- /dev/null
+++ b/utils/postprocessing/threshold.py
@@ -0,0 +1,7 @@
+# ------------------------------------------------------------ #
+#
+# file : postprocessing/threshold.py
+# author : CM
+# Segment image with threshold
+#
+# ------------------------------------------------------------ #
diff --git a/utils/preprocessing/normalisation.py b/utils/preprocessing/normalisation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6849fec43c870a920ea4882a86f968d1afa8f714
--- /dev/null
+++ b/utils/preprocessing/normalisation.py
@@ -0,0 +1,34 @@
+# ------------------------------------------------------------ #
+#
+# file : preprocessing/normalisation.py
+# author : CM
+#
+# ------------------------------------------------------------ #
+import numpy as np
+
+# Rescaling (min-max normalization)
def linear_intensity_normalization(loaded_dataset):
    """Rescale intensities by the global maximum (min-max style rescaling)."""
    peak = loaded_dataset.max()
    return loaded_dataset / peak
+
+# Preprocess dataset with intensity normalisation
+# (zero mean and unit variance)
def standardization_intensity_normalization(dataset, dtype):
    """Zero-mean / unit-variance intensity normalisation, cast to *dtype*."""
    centred = dataset - dataset.mean()
    return (centred / dataset.std()).astype(dtype)
+
+# Intensities normalized to the range [0, 1]
def intensityNormalisationFeatureScaling(dataset, dtype):
    """Min-max feature scaling: map intensities onto [0, 1], cast to *dtype*."""
    lo = dataset.min()
    hi = dataset.max()
    return ((dataset - lo) / (hi - lo)).astype(dtype)
+
+# Intensity max clipping with c "max value"
def intensityMaxClipping(dataset, c, dtype):
    """Clip intensities into [0, c] and cast to *dtype*."""
    clipped = np.clip(dataset, 0, c)
    return clipped.astype(dtype)
+
+# Intensity projection
def intensityProjection(dataset, p, dtype):
    """Raise every intensity to the power *p* (gamma-style projection), cast to *dtype*."""
    projected = np.power(dataset, p)
    return projected.astype(dtype)
\ No newline at end of file
diff --git a/utils/preprocessing/threshold.py b/utils/preprocessing/threshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2fcb0e22a7d7af0b88c69f29df6cfe640b762c1
--- /dev/null
+++ b/utils/preprocessing/threshold.py
@@ -0,0 +1,109 @@
+# ------------------------------------------------------------ #
+#
+# file : preprocessing/threshold.py
+# author : CM
+# Preprocess function for Bullitt dataset
+#
+# ------------------------------------------------------------ #
+import os
+import sys
+import numpy as np
+import nibabel as nib
+from utils.io.write import npToNii
+from utils.config.read import readConfig
+
+# Get the threshold for preprocessing
def getThreshold(dataset_mra, dataset_gd):
    """Return the smallest MRA intensity found on any ground-truth voxel.

    dataset_mra / dataset_gd are parallel stacks of 3D volumes; gd voxels
    equal to 1 mark the structures of interest.  The result is the lowest mra
    value over all marked voxels, or dataset_mra.max() when nothing is marked
    — usable as a lower bound for thresholding.

    The original per-voxel triple Python loop was replaced by a vectorised
    masked minimum (identical result, orders of magnitude faster).
    """
    threshold = dataset_mra.max()

    for i in range(0, len(dataset_mra)):
        mask = dataset_gd[i] == 1
        if mask.any():
            threshold = min(threshold, dataset_mra[i][mask].min())

    return threshold
+
+# Apply threshold to an image
def thresholding(data, threshold):
    """Return a copy of *data* where every voxel <= threshold is zeroed.

    Vectorised replacement for the original per-voxel triple Python loop
    (identical output and dtype, no interpreter overhead; input untouched).
    """
    return np.where(data > threshold, data, 0).astype(data.dtype)
+
def _load_first_volumes(dir_path, shape, count=30):
    """Load up to *count* NIfTI volumes from dir_path (sorted by name) into one array."""
    stack = np.empty((count,) + shape)
    files = os.listdir(dir_path)
    files.sort()
    i = 0
    for filename in files:
        if i >= count:
            break
        print(filename)
        stack[i, :, :, :] = nib.load(os.path.join(dir_path, filename)).get_data()
        i = i + 1
    return stack


def _preprocess_folder(src_dir, dst_dir, threshold):
    """Threshold every volume in src_dir and write it to dst_dir as pre_<name>."""
    files = os.listdir(src_dir)
    files.sort()
    for filename in files:
        print(filename)
        data = nib.load(os.path.join(src_dir, filename)).get_data()
        print(np.average(data))
        preprocessed = thresholding(data, threshold)
        print(np.average(preprocessed))
        npToNii(preprocessed, os.path.join(dst_dir, 'pre_' + filename))


# Script: compute the minimal ground-truth intensity on the training set and
# apply it as a threshold to both test and train MRA volumes.
# usage: <config file> <output folder>
config_filename = sys.argv[1]
if not os.path.isfile(config_filename):
    sys.exit(1)

config = readConfig(config_filename)

output_folder = sys.argv[2]
if not os.path.isdir(output_folder):
    sys.exit(1)

image_shape = (config["image_size_x"], config["image_size_y"], config["image_size_z"])

print("Loading training dataset")
# the two near-identical loading loops were factored into _load_first_volumes
train_mra_dataset = _load_first_volumes(config["dataset_train_mra_path"], image_shape)
train_gd_dataset = _load_first_volumes(config["dataset_train_gd_path"], image_shape)

print("Compute threshold")
threshold = getThreshold(train_mra_dataset, train_gd_dataset)

# release the training volumes before the next large loads
train_mra_dataset = None
train_gd_dataset = None

print("Apply preprocessing to test image")
_preprocess_folder(config["dataset_test_mra_path"], output_folder + '/test_Images', threshold)

print("Apply threshold to train image : ", threshold)
_preprocess_folder(config["dataset_train_mra_path"], output_folder + '/train_Images', threshold)