diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c2f49f56d3cbbe977a841c2245820260cbc557f7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +training_history/2019-12-19[[:space:]]01%3A53%3A15.480800.hdf5 filter=lfs diff=lfs merge=lfs -text +training_history/2025-08-07_16-25-27.hdf5 filter=lfs diff=lfs merge=lfs -text diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..92b8ee37a84afe0071a8549ba8a90acc31884494 --- /dev/null +++ b/app.py @@ -0,0 +1,1786 @@ +import glob +import gradio as gr +import matplotlib +import numpy as np +from PIL import Image +import torch +import tempfile +from gradio_imageslider import ImageSlider +import plotly.graph_objects as go +import plotly.express as px +import open3d as o3d +from depth_anything_v2.dpt import DepthAnythingV2 +import os +import tensorflow as tf +from tensorflow.keras.models import load_model + +# Classification imports +from transformers import AutoImageProcessor, AutoModelForImageClassification +import google.generativeai as genai + +import gdown +import spaces +import cv2 + + +# Import actual segmentation model components +from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling +from utils.learning.metrics import dice_coef, precision, recall +from utils.io.data import normalize + +# --- Classification Model Setup --- +# Load classification model and processor +classification_processor = AutoImageProcessor.from_pretrained("Hemg/Wound-classification") +classification_model = AutoModelForImageClassification.from_pretrained("Hemg/Wound-classification") + +# Configure Gemini AI +try: + # Try to get API key from Hugging Face secrets + gemini_api_key = os.getenv("GOOGLE_API_KEY") + if not gemini_api_key: + raise 
ValueError("GEMINI_API_KEY not found in environment variables") + + genai.configure(api_key=gemini_api_key) + gemini_model = genai.GenerativeModel("gemini-2.5-pro") + print("✅ Gemini AI configured successfully with API key from secrets") +except Exception as e: + print(f"❌ Error configuring Gemini AI: {e}") + print("Please make sure GEMINI_API_KEY is set in your Hugging Face Space secrets") + gemini_model = None + +# --- Classification Functions --- +def analyze_wound_with_gemini(image, predicted_label): + """ + Analyze wound image using Gemini AI with classification context + + Args: + image: PIL Image + predicted_label: The predicted wound type from classification model + + Returns: + str: Gemini AI analysis + """ + if image is None: + return "No image provided for analysis." + + if gemini_model is None: + return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets." + + try: + # Ensure image is in RGB format + if image.mode != 'RGB': + image = image.convert('RGB') + + # Create prompt that includes the classification result + prompt = f"""You are assisting in a medical education and research task. + +Based on the wound classification model, this image has been identified as: {predicted_label} + +Please provide an educational analysis of this wound image focusing on: +1. Visible characteristics of the wound (size, color, texture, edges, surrounding tissue) +2. Educational explanation about this type of wound based on the classification: {predicted_label} +3. General wound healing stages if applicable +4. 
Key features that are typically associated with this wound type + +Important guidelines: +- This is for educational and research purposes only +- Do not provide medical advice or diagnosis +- Keep the analysis objective and educational +- Focus on visible features and general wound characteristics +- Do not recommend treatments or medical interventions + +Please provide a comprehensive educational analysis.""" + + response = gemini_model.generate_content([prompt, image]) + return response.text + + except Exception as e: + return f"Error analyzing image with Gemini: {str(e)}" + +def analyze_wound_depth_with_gemini(image, depth_map, depth_stats): + """ + Analyze wound depth and severity using Gemini AI with depth analysis context + + Args: + image: Original wound image (PIL Image or numpy array) + depth_map: Depth map (numpy array) + depth_stats: Dictionary containing depth analysis statistics + + Returns: + str: Gemini AI medical assessment based on depth analysis + """ + if image is None or depth_map is None: + return "No image or depth map provided for analysis." + + if gemini_model is None: + return "Gemini AI is not available. Please check that GEMINI_API_KEY is properly configured in your Hugging Face Space secrets." + + try: + # Convert numpy array to PIL Image if needed + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + + # Ensure image is in RGB format + if image.mode != 'RGB': + image = image.convert('RGB') + + # Convert depth map to PIL Image for Gemini + if isinstance(depth_map, np.ndarray): + # Normalize depth map for visualization + norm_depth = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0 + depth_image = Image.fromarray(norm_depth.astype(np.uint8)) + else: + depth_image = depth_map + + # Create detailed prompt with depth statistics + prompt = f"""You are a medical AI assistant specializing in wound assessment. Analyze this wound using both the original image and depth map data. 
+ +DEPTH ANALYSIS DATA PROVIDED: +- Total Wound Area: {depth_stats['total_area_cm2']:.2f} cm² +- Mean Depth: {depth_stats['mean_depth_mm']:.1f} mm +- Maximum Depth: {depth_stats['max_depth_mm']:.1f} mm +- Depth Standard Deviation: {depth_stats['depth_std_mm']:.1f} mm +- Wound Volume: {depth_stats['wound_volume_cm3']:.2f} cm³ +- Deep Tissue Involvement: {depth_stats['deep_ratio']*100:.1f}% +- Analysis Quality: {depth_stats['analysis_quality']} +- Depth Consistency: {depth_stats['depth_consistency']} + +TISSUE DEPTH DISTRIBUTION: +- Superficial Areas (0-2mm): {depth_stats['superficial_area_cm2']:.2f} cm² +- Partial Thickness (2-4mm): {depth_stats['partial_thickness_area_cm2']:.2f} cm² +- Full Thickness (4-6mm): {depth_stats['full_thickness_area_cm2']:.2f} cm² +- Deep Areas (>6mm): {depth_stats['deep_area_cm2']:.2f} cm² + +STATISTICAL DEPTH ANALYSIS: +- 25th Percentile Depth: {depth_stats['depth_percentiles']['25']:.1f} mm +- Median Depth: {depth_stats['depth_percentiles']['50']:.1f} mm +- 75th Percentile Depth: {depth_stats['depth_percentiles']['75']:.1f} mm + +Please provide a comprehensive medical assessment focusing on: + +1. **WOUND CHARACTERISTICS ANALYSIS** + - Visible wound features from the original image + - Correlation between visual appearance and depth measurements + - Tissue quality assessment based on color, texture, and depth data + +2. **DEPTH-BASED SEVERITY ASSESSMENT** + - Clinical significance of the measured depths + - Tissue layer involvement based on depth measurements + - Risk assessment based on deep tissue involvement percentage + +3. **HEALING PROGNOSIS** + - Expected healing timeline based on depth and area measurements + - Factors that may affect healing based on depth distribution + - Complexity assessment based on wound volume and depth variation + +4. 
def classify_wound(image):
    """Run the Hugging Face wound-type classifier on an image.

    Args:
        image: PIL Image or numpy array; converted to RGB as needed.

    Returns:
        dict mapping class label -> softmax confidence, or an error
        string when no image is given or inference fails.
    """
    if image is None:
        return "Please upload an image"

    # Normalise the input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    try:
        model_inputs = classification_processor(images=image, return_tensors="pt")

        # Inference without gradient tracking, then softmax over logits.
        with torch.no_grad():
            logits = classification_model(**model_inputs).logits[0]
            scores = torch.nn.functional.softmax(logits, dim=-1)

        config = classification_model.config
        has_labels = hasattr(config, 'id2label')
        # Map each class index to its configured label (or a fallback name).
        return {
            (config.id2label[idx] if has_labels else f"Class {idx}"): float(conf)
            for idx, conf in enumerate(scores.numpy())
        }

    except Exception as e:
        return f"Error processing image: {str(e)}"
+

Analysis Error

+

{analysis}

+
+ """ + + # Parse the markdown-style response and convert to HTML + formatted_analysis = parse_markdown_to_html(analysis) + + return f""" +
+

+ Initial Wound Analysis +

+
+ {formatted_analysis} +
+
+ """ + +def format_gemini_depth_analysis(analysis): + """Format Gemini depth analysis as properly structured HTML for medical assessment""" + if not analysis or "Error" in analysis: + return f""" +
+
+ ❌ AI Analysis Error +
+
+ {analysis} +
+
+ """ + + # Parse the markdown-style response and convert to HTML + formatted_analysis = parse_markdown_to_html(analysis) + + return f""" +
+
+ 🤖 AI-Powered Medical Assessment +
+
+ {formatted_analysis} +
+
+ """ + +def parse_markdown_to_html(text): + """Convert markdown-style text to HTML""" + import re + + # Replace markdown headers + text = re.sub(r'^### \*\*(.*?)\*\*$', r'

\1

', text, flags=re.MULTILINE) + text = re.sub(r'^#### \*\*(.*?)\*\*$', r'
\1
', text, flags=re.MULTILINE) + text = re.sub(r'^### (.*?)$', r'

\1

', text, flags=re.MULTILINE) + text = re.sub(r'^#### (.*?)$', r'
\1
', text, flags=re.MULTILINE) + + # Replace bold text + text = re.sub(r'\*\*(.*?)\*\*', r'\1', text) + + # Replace italic text + text = re.sub(r'\*(.*?)\*', r'\1', text) + + # Replace bullet points + text = re.sub(r'^\* (.*?)$', r'
  • \1
  • ', text, flags=re.MULTILINE) + text = re.sub(r'^ \* (.*?)$', r'
  • \1
  • ', text, flags=re.MULTILINE) + + # Wrap consecutive list items in ul tags + text = re.sub(r'((?:\s*)*)', r'', text, flags=re.DOTALL) + + # Replace numbered lists + text = re.sub(r'^(\d+)\.\s+(.*?)$', r'
    \1. \2
    ', text, flags=re.MULTILINE) + + # Convert paragraphs (double newlines) + paragraphs = text.split('\n\n') + formatted_paragraphs = [] + + for para in paragraphs: + para = para.strip() + if para: + # Skip if it's already wrapped in HTML tags + if not (para.startswith('<') or para.endswith('>')): + para = f'

    {para}

    ' + formatted_paragraphs.append(para) + + return '\n'.join(formatted_paragraphs) + +def combined_analysis(image): + """Combined function for UI that returns both outputs""" + classification, gemini_analysis = classify_and_analyze_wound(image) + formatted_analysis = format_gemini_analysis(gemini_analysis) + return classification, formatted_analysis + + + + + +# Define path and file ID +checkpoint_dir = "checkpoints" +os.makedirs(checkpoint_dir, exist_ok=True) + +model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth") +gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5" + +# Download if not already present +if not os.path.exists(model_file): + print("Downloading model from Google Drive...") + gdown.download(gdrive_url, model_file, quiet=False) + +# --- TensorFlow: Check GPU Availability --- +gpus = tf.config.list_physical_devices('GPU') +if gpus: + print("TensorFlow is using GPU") +else: + print("TensorFlow is using CPU") + + + +# --- Load Actual Wound Segmentation Model --- +class WoundSegmentationModel: + def __init__(self): + self.input_dim_x = 224 + self.input_dim_y = 224 + self.model = None + self.load_model() + + def load_model(self): + """Load the trained wound segmentation model""" + try: + # Try to load the most recent model + weight_file_name = '2025-08-07_16-25-27.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Segmentation model loaded successfully from {model_path}") + except Exception as e: + print(f"Error loading segmentation model: {e}") + # Fallback to the older model + try: + weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 
'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Segmentation model loaded successfully from {model_path}") + except Exception as e2: + print(f"Error loading fallback segmentation model: {e2}") + self.model = None + + def preprocess_image(self, image): + """Preprocess the uploaded image for model input""" + if image is None: + return None + + # Convert to RGB if needed + if len(image.shape) == 3 and image.shape[2] == 3: + # Convert BGR to RGB if needed + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Resize to model input size + image = cv2.resize(image, (self.input_dim_x, self.input_dim_y)) + + # Normalize the image + image = image.astype(np.float32) / 255.0 + + # Add batch dimension + image = np.expand_dims(image, axis=0) + + return image + + def postprocess_prediction(self, prediction): + """Postprocess the model prediction""" + # Remove batch dimension + prediction = prediction[0] + + # Apply threshold to get binary mask + threshold = 0.5 + binary_mask = (prediction > threshold).astype(np.uint8) * 255 + + return binary_mask + + def segment_wound(self, input_image): + """Main function to segment wound from uploaded image""" + if self.model is None: + return None, "Error: Segmentation model not loaded. Please check the model files." + + if input_image is None: + return None, "Please upload an image." + + try: + # Preprocess the image + processed_image = self.preprocess_image(input_image) + + if processed_image is None: + return None, "Error processing image." + + # Make prediction + prediction = self.model.predict(processed_image, verbose=0) + + # Postprocess the prediction + segmented_mask = self.postprocess_prediction(prediction) + + return segmented_mask, "Segmentation completed successfully!" 
+ + except Exception as e: + return None, f"Error during segmentation: {str(e)}" + +# Initialize the segmentation model +segmentation_model = WoundSegmentationModel() + +# --- PyTorch: Set Device and Load Depth Model --- +map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") +print(f"Using PyTorch device: {map_device}") + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} +} +encoder = 'vitl' +depth_model = DepthAnythingV2(**model_configs[encoder]) +state_dict = torch.load( + f'checkpoints/depth_anything_v2_{encoder}.pth', + map_location=map_device +) +depth_model.load_state_dict(state_dict) +depth_model = depth_model.to(map_device).eval() + + +# --- Custom CSS for unified dark theme --- +css = """ +.gradio-container { + font-family: 'Segoe UI', sans-serif; + background-color: #121212; + color: #ffffff; + padding: 20px; +} +.gr-button { + background-color: #2c3e50; + color: white; + border-radius: 10px; +} +.gr-button:hover { + background-color: #34495e; +} +.gr-html, .gr-html div { + white-space: normal !important; + overflow: visible !important; + text-overflow: unset !important; + word-break: break-word !important; +} +#img-display-container { + max-height: 100vh; +} +#img-display-input { + max-height: 80vh; +} +#img-display-output { + max-height: 80vh; +} +#download { + height: 62px; +} +h1 { + text-align: center; + font-size: 3rem; + font-weight: bold; + margin: 2rem 0; + color: #ffffff; +} +h2 { + color: #ffffff; + text-align: center; + margin: 1rem 0; +} +.gr-tabs { + background-color: #1e1e1e; + border-radius: 10px; + padding: 10px; +} +.gr-tab-nav { + background-color: #2c3e50; + border-radius: 
8px; +} +.gr-tab-nav button { + color: #ffffff !important; +} +.gr-tab-nav button.selected { + background-color: #34495e !important; +} +/* Card styling for consistent heights */ +.wound-card { + min-height: 200px !important; + display: flex !important; + flex-direction: column !important; + justify-content: space-between !important; +} +.wound-card-content { + flex-grow: 1 !important; + display: flex !important; + flex-direction: column !important; + justify-content: center !important; +} +/* Loading animation */ +.loading-spinner { + display: inline-block; + width: 20px; + height: 20px; + border: 3px solid #f3f3f3; + border-top: 3px solid #3498db; + border-radius: 50%; + animation: spin 1s linear infinite; +} +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} +""" + + + + + +# --- Enhanced Wound Severity Estimation Functions --- + +def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0): + """ + Enhanced depth analysis with proper calibration and medical standards + Based on wound depth classification standards: + - Superficial: 0-2mm (epidermis only) + - Partial thickness: 2-4mm (epidermis + partial dermis) + - Full thickness: 4-6mm (epidermis + full dermis) + - Deep: >6mm (involving subcutaneous tissue) + """ + # Convert pixel spacing to mm + pixel_spacing_mm = float(pixel_spacing_mm) + + # Calculate pixel area in cm² + pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2 + + # Extract wound region (binary mask) + wound_mask = (mask > 127).astype(np.uint8) + + # Apply morphological operations to clean the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel) + + # Get depth values only for wound region + wound_depths = depth_map[wound_mask > 0] + + if len(wound_depths) == 0: + return { + 'total_area_cm2': 0, + 'superficial_area_cm2': 0, + 'partial_thickness_area_cm2': 0, + 
def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0):
    """
    Enhanced depth analysis with proper calibration and medical standards.
    Based on wound depth classification standards:
    - Superficial: 0-2mm (epidermis only)
    - Partial thickness: 2-4mm (epidermis + partial dermis)
    - Full thickness: 4-6mm (epidermis + full dermis)
    - Deep: >6mm (involving subcutaneous tissue)

    Args:
        depth_map: 2-D relative depth map.
        mask: wound mask; pixels > 127 count as wound.
        pixel_spacing_mm: physical edge length of one pixel in mm.
        depth_calibration_mm: real-world depth assigned to the deepest point.

    Returns:
        dict of area/depth/volume statistics. Always contains the full key
        set (including 'normalized_depth_map', 'analysis_quality', etc.),
        even when the mask selects no wound pixels.
    """
    # Convert pixel spacing to mm and derive the per-pixel area in cm².
    pixel_spacing_mm = float(pixel_spacing_mm)
    pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2

    # Binarise and clean the wound mask (close small holes).
    wound_mask = (mask > 127).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel)

    # Depth samples restricted to the wound region.
    wound_depths = depth_map[wound_mask > 0]

    if len(wound_depths) == 0:
        # BUG FIX: the original early return omitted keys
        # ('normalized_depth_map', 'analysis_quality', 'depth_consistency',
        # 'depth_distribution', 'wound_pixel_count', ...) that
        # analyze_wound_severity and the Gemini prompt dereference,
        # raising KeyError whenever the mask is empty. Return the full
        # key set with neutral values instead.
        return {
            'total_area_cm2': 0,
            'superficial_area_cm2': 0,
            'partial_thickness_area_cm2': 0,
            'full_thickness_area_cm2': 0,
            'deep_area_cm2': 0,
            'mean_depth_mm': 0,
            'max_depth_mm': 0,
            'depth_std_mm': 0,
            'deep_ratio': 0,
            'wound_volume_cm3': 0,
            'depth_percentiles': {'25': 0, '50': 0, '75': 0},
            'depth_distribution': {'shallow_ratio': 0, 'moderate_ratio': 0, 'deep_ratio': 0},
            'analysis_quality': 'Low',   # no data to grade
            'depth_consistency': 'Low',  # no data to grade
            'wound_pixel_count': 0,
            'nearest_point_coords': None,
            'max_relative_depth': 0,
            'normalized_depth_map': depth_map
        }

    # Re-reference depths to the camera-nearest wound point, then calibrate
    # the relative values to millimetres.
    normalized_depth_map, nearest_point_coords, max_relative_depth = normalize_depth_relative_to_nearest_point(depth_map, wound_mask)
    calibrated_depth_map = calibrate_depth_map(normalized_depth_map, reference_depth_mm=depth_calibration_mm)
    wound_depths_mm = calibrated_depth_map[wound_mask > 0]

    # Medical depth classification bands.
    superficial_mask = wound_depths_mm < 2.0
    partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0)
    full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0)
    deep_mask = wound_depths_mm >= 6.0

    # Areas per band.
    total_pixels = np.sum(wound_mask > 0)
    total_area_cm2 = total_pixels * pixel_area_cm2
    superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2
    partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2
    full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2
    deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2

    # Depth statistics over the wound region.
    mean_depth_mm = np.mean(wound_depths_mm)
    max_depth_mm = np.max(wound_depths_mm)
    depth_std_mm = np.std(wound_depths_mm)

    depth_percentiles = {
        '25': np.percentile(wound_depths_mm, 25),
        '50': np.percentile(wound_depths_mm, 50),
        '75': np.percentile(wound_depths_mm, 75)
    }

    # Depth distribution (NOTE: uses a 5mm "deep" cutoff, unlike the 6mm
    # band above — behaviour preserved from the original). The per-entry
    # emptiness guards are no longer needed: we returned early above when
    # there were no samples.
    n_samples = len(wound_depths_mm)
    depth_distribution = {
        'shallow_ratio': np.sum(wound_depths_mm < 2.0) / n_samples,
        'moderate_ratio': np.sum((wound_depths_mm >= 2.0) & (wound_depths_mm < 5.0)) / n_samples,
        'deep_ratio': np.sum(wound_depths_mm >= 5.0) / n_samples
    }

    # Approximate volume = area * mean depth (mm converted to cm).
    wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0)

    # Fraction of the wound area deeper than the 6mm band.
    deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0

    # Heuristic quality/consistency grades based on sample count / spread.
    wound_pixel_count = n_samples
    analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low"
    depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low"

    return {
        'total_area_cm2': total_area_cm2,
        'superficial_area_cm2': superficial_area_cm2,
        'partial_thickness_area_cm2': partial_thickness_area_cm2,
        'full_thickness_area_cm2': full_thickness_area_cm2,
        'deep_area_cm2': deep_area_cm2,
        'mean_depth_mm': mean_depth_mm,
        'max_depth_mm': max_depth_mm,
        'depth_std_mm': depth_std_mm,
        'deep_ratio': deep_ratio,
        'wound_volume_cm3': wound_volume_cm3,
        'depth_percentiles': depth_percentiles,
        'depth_distribution': depth_distribution,
        'analysis_quality': analysis_quality,
        'depth_consistency': depth_consistency,
        'wound_pixel_count': wound_pixel_count,
        'nearest_point_coords': nearest_point_coords,
        'max_relative_depth': max_relative_depth,
        'normalized_depth_map': normalized_depth_map
    }
max_depth = depth_stats['max_depth_mm'] + wound_volume = depth_stats['wound_volume_cm3'] + deep_ratio = depth_stats['deep_ratio'] + + # Medical severity classification criteria + severity_score = 0 + + # Criterion 1: Maximum depth + if max_depth >= 10.0: + severity_score += 3 # Very severe + elif max_depth >= 6.0: + severity_score += 2 # Severe + elif max_depth >= 4.0: + severity_score += 1 # Moderate + + # Criterion 2: Mean depth + if mean_depth >= 5.0: + severity_score += 2 + elif mean_depth >= 3.0: + severity_score += 1 + + # Criterion 3: Deep tissue involvement ratio + if deep_ratio >= 0.5: + severity_score += 3 # More than 50% deep tissue + elif deep_ratio >= 0.25: + severity_score += 2 # 25-50% deep tissue + elif deep_ratio >= 0.1: + severity_score += 1 # 10-25% deep tissue + + # Criterion 4: Total wound area + if total_area >= 10.0: + severity_score += 2 # Large wound (>10 cm²) + elif total_area >= 5.0: + severity_score += 1 # Medium wound (5-10 cm²) + + # Criterion 5: Wound volume + if wound_volume >= 5.0: + severity_score += 2 # High volume + elif wound_volume >= 2.0: + severity_score += 1 # Medium volume + + # Determine severity based on total score + if severity_score >= 8: + return "Very Severe" + elif severity_score >= 6: + return "Severe" + elif severity_score >= 4: + return "Moderate" + elif severity_score >= 2: + return "Mild" + else: + return "Superficial" + + + + + +def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0): + """Enhanced wound severity analysis based on depth measurements""" + if image is None or depth_map is None or wound_mask is None: + return "❌ Please upload image, depth map, and wound mask." 
+ + # Convert wound mask to grayscale if needed + if len(wound_mask.shape) == 3: + wound_mask = np.mean(wound_mask, axis=2) + + # Ensure depth map and mask have same dimensions + if depth_map.shape[:2] != wound_mask.shape[:2]: + # Resize mask to match depth map + from PIL import Image + mask_pil = Image.fromarray(wound_mask.astype(np.uint8)) + mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0])) + wound_mask = np.array(mask_pil) + + # Compute enhanced statistics with relative depth normalization + stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm) + + # Get severity based on enhanced metrics + severity_level = classify_wound_severity_by_enhanced_metrics(stats) + severity_description = get_enhanced_severity_description(severity_level) + + # Get Gemini AI analysis based on depth data + gemini_analysis = analyze_wound_depth_with_gemini(image, depth_map, stats) + + # Format Gemini analysis for display + formatted_gemini_analysis = format_gemini_depth_analysis(gemini_analysis) + + # Create depth analysis visualization + depth_visualization = create_depth_analysis_visualization( + stats['normalized_depth_map'], wound_mask, + stats['nearest_point_coords'], stats['max_relative_depth'] + ) + + # Enhanced severity color coding + severity_color = { + "Superficial": "#4CAF50", # Green + "Mild": "#8BC34A", # Light Green + "Moderate": "#FF9800", # Orange + "Severe": "#F44336", # Red + "Very Severe": "#9C27B0" # Purple + }.get(severity_level, "#9E9E9E") # Gray for unknown + + # Create comprehensive medical report + report = f""" +
    +
    + 🩹 Enhanced Wound Severity Analysis +
    + +
    +
    + 📊 Depth & Quality Analysis +
    +
    +
    +
    � Basic Measurements
    +
    �📏 Mean Relative Depth: {stats['mean_depth_mm']:.1f} mm
    +
    📐 Max Relative Depth: {stats['max_depth_mm']:.1f} mm
    +
    📊 Depth Std Dev: {stats['depth_std_mm']:.1f} mm
    +
    📦 Wound Volume: {stats['wound_volume_cm3']:.2f} cm³
    +
    🔥 Deep Tissue Ratio: {stats['deep_ratio']*100:.1f}%
    +
    +
    +
    📈 Statistical Analysis
    +
    25th Percentile: {stats['depth_percentiles']['25']:.1f} mm
    +
    📊 Median (50th): {stats['depth_percentiles']['50']:.1f} mm
    +
    📊 75th Percentile: {stats['depth_percentiles']['75']:.1f} mm
    +
    📊 Shallow Areas: {stats['depth_distribution']['shallow_ratio']*100:.1f}%
    +
    📊 Moderate Areas: {stats['depth_distribution']['moderate_ratio']*100:.1f}%
    +
    +
    +
    🔍 Quality Metrics
    +
    🔍 Analysis Quality: {stats['analysis_quality']}
    +
    📏 Depth Consistency: {stats['depth_consistency']}
    +
    📊 Data Points: {stats['wound_pixel_count']:,}
    +
    📊 Deep Areas: {stats['depth_distribution']['deep_ratio']*100:.1f}%
    +
    🎯 Reference Point: Nearest to camera
    +
    +
    +
    + +
    +
    + 📊 Medical Assessment Based on Depth Analysis +
    + {formatted_gemini_analysis} +
    +
    + """ + + return report + +def normalize_depth_relative_to_nearest_point(depth_map, wound_mask): + """ + Normalize depth map relative to the nearest point in the wound area + This assumes a top-down camera perspective where the closest point to camera = 0 depth + + Args: + depth_map: Raw depth map + wound_mask: Binary mask of wound region + + Returns: + normalized_depth: Depth values relative to nearest point (0 = nearest, positive = deeper) + nearest_point_coords: Coordinates of the nearest point + max_relative_depth: Maximum relative depth in the wound + """ + if depth_map is None or wound_mask is None: + return depth_map, None, 0 + + # Convert mask to binary + binary_mask = (wound_mask > 127).astype(np.uint8) + + # Find wound region coordinates + wound_coords = np.where(binary_mask > 0) + + if len(wound_coords[0]) == 0: + return depth_map, None, 0 + + # Get depth values only for wound region + wound_depths = depth_map[wound_coords] + + # Find the nearest point (minimum depth value in wound region) + nearest_depth = np.min(wound_depths) + nearest_indices = np.where(wound_depths == nearest_depth) + + # Get coordinates of the nearest point(s) + nearest_point_coords = (wound_coords[0][nearest_indices[0][0]], + wound_coords[1][nearest_indices[0][0]]) + + # Create normalized depth map (relative to nearest point) + normalized_depth = depth_map.copy() + normalized_depth = normalized_depth - nearest_depth + + # Ensure all values are non-negative (nearest point = 0, others = positive) + normalized_depth = np.maximum(normalized_depth, 0) + + # Calculate maximum relative depth in wound region + wound_normalized_depths = normalized_depth[wound_coords] + max_relative_depth = np.max(wound_normalized_depths) + + return normalized_depth, nearest_point_coords, max_relative_depth + +def calibrate_depth_map(depth_map, reference_depth_mm=10.0): + """ + Calibrate depth map to real-world measurements using reference depth + This helps convert normalized depth values to actual 
millimeters + """ + if depth_map is None: + return depth_map + + # Find the maximum depth value in the depth map + max_depth_value = np.max(depth_map) + min_depth_value = np.min(depth_map) + + if max_depth_value == min_depth_value: + return depth_map + + # Apply calibration to convert to millimeters + # Assuming the maximum depth in the map corresponds to reference_depth_mm + calibrated_depth = (depth_map - min_depth_value) / (max_depth_value - min_depth_value) * reference_depth_mm + + return calibrated_depth + +def create_depth_analysis_visualization(depth_map, wound_mask, nearest_point_coords, max_relative_depth): + """ + Create a visualization showing the depth analysis with nearest point and deepest point highlighted + """ + if depth_map is None or wound_mask is None: + return None + + # Create a copy of the depth map for visualization + vis_depth = depth_map.copy() + + # Apply colormap for better visualization + normalized_depth = (vis_depth - np.min(vis_depth)) / (np.max(vis_depth) - np.min(vis_depth)) + colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(normalized_depth)[:, :, :3] * 255).astype(np.uint8) + + # Convert to RGB if grayscale + if len(colored_depth.shape) == 3 and colored_depth.shape[2] == 1: + colored_depth = cv2.cvtColor(colored_depth, cv2.COLOR_GRAY2RGB) + + # Highlight the nearest point (reference point) with a red circle + if nearest_point_coords is not None: + y, x = nearest_point_coords + cv2.circle(colored_depth, (x, y), 10, (255, 0, 0), 2) # Red circle for nearest point + cv2.putText(colored_depth, "REF", (x+15, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) + + # Find and highlight the deepest point + binary_mask = (wound_mask > 127).astype(np.uint8) + wound_coords = np.where(binary_mask > 0) + + if len(wound_coords[0]) > 0: + # Get depth values for wound region + wound_depths = vis_depth[wound_coords] + max_depth_idx = np.argmax(wound_depths) + deepest_point_coords = (wound_coords[0][max_depth_idx], 
wound_coords[1][max_depth_idx]) + + # Highlight the deepest point with a blue circle + y, x = deepest_point_coords + cv2.circle(colored_depth, (x, y), 12, (0, 0, 255), 3) # Blue circle for deepest point + cv2.putText(colored_depth, "DEEP", (x+15, y+5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + + # Overlay wound mask outline + contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + cv2.drawContours(colored_depth, contours, -1, (0, 255, 0), 2) # Green outline for wound boundary + + return colored_depth + +def get_enhanced_severity_description(severity): + """Get comprehensive medical description for severity level""" + descriptions = { + "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.", + "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.", + "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.", + "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.", + "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.", + "Unknown": "Unable to determine severity due to insufficient data or poor image quality." 
+ } + return descriptions.get(severity, "Severity assessment unavailable.") + +def create_sample_wound_mask(image_shape, center=None, radius=50): + """Create a sample circular wound mask for testing""" + if center is None: + center = (image_shape[1] // 2, image_shape[0] // 2) + + mask = np.zeros(image_shape[:2], dtype=np.uint8) + y, x = np.ogrid[:image_shape[0], :image_shape[1]] + + # Create circular mask + dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) + mask[dist_from_center <= radius] = 255 + + return mask + +def create_realistic_wound_mask(image_shape, method='elliptical'): + """Create a more realistic wound mask with irregular shapes""" + h, w = image_shape[:2] + mask = np.zeros((h, w), dtype=np.uint8) + + if method == 'elliptical': + # Create elliptical wound mask + center = (w // 2, h // 2) + radius_x = min(w, h) // 3 + radius_y = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + # Add some irregularity to make it more realistic + ellipse = ((x - center[0])**2 / (radius_x**2) + + (y - center[1])**2 / (radius_y**2)) <= 1 + + # Add some noise and irregularity + noise = np.random.random((h, w)) > 0.8 + mask = (ellipse | noise).astype(np.uint8) * 255 + + elif method == 'irregular': + # Create irregular wound mask + center = (w // 2, h // 2) + radius = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius + + # Add irregular extensions + extensions = np.zeros_like(base_circle) + for i in range(3): + angle = i * 2 * np.pi / 3 + ext_x = int(center[0] + radius * 0.8 * np.cos(angle)) + ext_y = int(center[1] + radius * 0.8 * np.sin(angle)) + ext_radius = radius // 3 + + ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius + extensions = extensions | ext_circle + + mask = (base_circle | extensions).astype(np.uint8) * 255 + + # Apply morphological operations to smooth the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask = cv2.morphologyEx(mask, 
cv2.MORPH_CLOSE, kernel) + + return mask + +# --- Depth Estimation Functions --- + +def predict_depth(image): + return depth_model.infer_image(image) + +def calculate_max_points(image): + """Calculate maximum points based on image dimensions (3x pixel count)""" + if image is None: + return 10000 # Default value + h, w = image.shape[:2] + max_points = h * w * 3 + # Ensure minimum and reasonable maximum values + return max(1000, min(max_points, 300000)) + +def update_slider_on_image_upload(image): + """Update the points slider when an image is uploaded""" + max_points = calculate_max_points(image) + default_value = min(10000, max_points // 10) # 10% of max points as default + return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, + label=f"Number of 3D points (max: {max_points:,})") + + +def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000): + """Create a point cloud from depth map using camera intrinsics with high detail""" + h, w = depth_map.shape + + # Use smaller step for higher detail (reduced downsampling) + step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + x_cam = (x_coords - w / 2) / focal_length_x + y_cam = (y_coords - h / 2) / focal_length_y + + # Get depth values + depth_values = depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # Flatten arrays + points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors = image_colors.reshape(-1, 3) / 255.0 + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = 
o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + + return pcd + + +def reconstruct_surface_mesh_from_point_cloud(pcd): + """Convert point cloud to a mesh using Poisson reconstruction with very high detail.""" + # Estimate and orient normals with high precision + pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)) + pcd.orient_normals_consistent_tangent_plane(k=50) + + # Create surface mesh with maximum detail (depth=12 for very high resolution) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12) + + # Return mesh without filtering low-density vertices + return mesh + + +def create_enhanced_3d_visualization(image, depth_map, max_points=10000): + """Create an enhanced 3D visualization using proper camera projection""" + h, w = depth_map.shape + + # Downsample to avoid too many points for performance + step = max(1, int(np.sqrt(h * w / max_points))) + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + focal_length = 470.4 # Default focal length + x_cam = (x_coords - w / 2) / focal_length + y_cam = (y_coords - h / 2) / focal_length + + # Get depth values + depth_values = depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # Flatten arrays + x_flat = x_3d.flatten() + y_flat = y_3d.flatten() + z_flat = z_3d.flatten() + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors_flat = image_colors.reshape(-1, 3) + + # Create 3D scatter plot with proper camera projection + fig = go.Figure(data=[go.Scatter3d( + x=x_flat, + y=y_flat, + z=z_flat, + mode='markers', + marker=dict( + size=1.5, + color=colors_flat, + opacity=0.9 + ), + hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
    ' + + 'Depth: %{z:.2f}
    ' + + '' + )]) + + fig.update_layout( + title="3D Point Cloud Visualization (Camera Projection)", + scene=dict( + xaxis_title="X (meters)", + yaxis_title="Y (meters)", + zaxis_title="Z (meters)", + camera=dict( + eye=dict(x=2.0, y=2.0, z=2.0), + center=dict(x=0, y=0, z=0), + up=dict(x=0, y=0, z=1) + ), + aspectmode='data' + ), + width=700, + height=600 + ) + + return fig + +def on_depth_submit(image, num_points, focal_x, focal_y): + original_image = image.copy() + + h, w = image.shape[:2] + + # Predict depth using the model + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + + # Save raw 16-bit depth + raw_depth = Image.fromarray(depth.astype('uint16')) + tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + raw_depth.save(tmp_raw_depth.name) + + # Normalize and convert to grayscale for display + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + norm_depth = norm_depth.astype(np.uint8) + colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8) + + gray_depth = Image.fromarray(norm_depth) + tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + gray_depth.save(tmp_gray_depth.name) + + # Create point cloud + pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points) + + # Reconstruct mesh from point cloud + mesh = reconstruct_surface_mesh_from_point_cloud(pcd) + + # Save mesh with faces as .ply + tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) + o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh) + + # Create enhanced 3D scatter plot visualization + depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points) + + return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] + +# --- Actual Wound Segmentation Functions --- +def create_automatic_wound_mask(image, method='deep_learning'): + 
""" + Automatically generate wound mask from image using the actual deep learning model + + Args: + image: Input image (numpy array) + method: Segmentation method (currently only 'deep_learning' supported) + + Returns: + mask: Binary wound mask + """ + if image is None: + return None + + # Use the actual deep learning model for segmentation + if method == 'deep_learning': + mask, _ = segmentation_model.segment_wound(image) + return mask + else: + # Fallback to deep learning if method not recognized + mask, _ = segmentation_model.segment_wound(image) + return mask + +def post_process_wound_mask(mask, min_area=100): + """Post-process the wound mask to remove noise and small objects""" + if mask is None: + return None + + # Convert to binary if needed + if mask.dtype != np.uint8: + mask = mask.astype(np.uint8) + + # Apply morphological operations to clean up + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + # Remove small objects using OpenCV + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + mask_clean = np.zeros_like(mask) + + for contour in contours: + area = cv2.contourArea(contour) + if area >= min_area: + cv2.fillPoly(mask_clean, [contour], 255) + + # Fill holes + mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel) + + return mask_clean + +def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'): + """Analyze wound severity with automatic mask generation using actual segmentation model""" + if image is None or depth_map is None: + return "❌ Please provide both image and depth map." + + # Generate automatic wound mask using the actual model + auto_mask = create_automatic_wound_mask(image, method=segmentation_method) + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask. 
Please check if the segmentation model is loaded." + + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm) + +# --- Main Gradio Interface --- +with gr.Blocks(css=css, title="Wound Analysis System") as demo: + gr.HTML("

    Wound Analysis System

    ") + #gr.Markdown("### Complete workflow: Classification → Depth Estimation → Wound Severity Analysis") + + # Shared states + shared_image = gr.State() + shared_depth_map = gr.State() + + with gr.Tabs(): + + # Tab 1: Wound Classification + with gr.Tab("1. 🔍 Wound Classification & Initial Analysis"): + gr.Markdown("### Step 1: Classify wound type and get initial AI analysis") + #gr.Markdown("Upload an image to identify the wound type and receive detailed analysis from our Vision AI.") + + + with gr.Row(): + # Left Column - Image Upload + with gr.Column(scale=1): + gr.HTML('

    Upload Wound Image

    ') + classification_image_input = gr.Image( + label="", + type="pil", + height=400 + ) + # Place Clear and Analyse buttons side by side + with gr.Row(): + classify_clear_btn = gr.Button( + "Clear", + variant="secondary", + size="lg", + scale=1 + ) + analyse_btn = gr.Button( + "Analyse", + variant="primary", + size="lg", + scale=1 + ) + # Right Column - Classification Results + with gr.Column(scale=1): + gr.HTML('

    Classification Results

    ') + classification_output = gr.Label( + label="", + num_top_classes=5, + show_label=False + ) + + # Second Row - Full Width AI Analysis + with gr.Row(): + with gr.Column(scale=1): + gr.HTML('

    Wound Visual Analysis

    ') + gemini_output = gr.HTML( + value=""" +
    + Upload an image to get AI-powered wound analysis +
    + """ + ) + + # Event handlers for classification tab + classify_clear_btn.click( + fn=lambda: (None, None, """ +
    + Upload an image to get AI-powered wound analysis +
    + """), + inputs=None, + outputs=[classification_image_input, classification_output, gemini_output] + ) + + # Only run classification on image upload + def classify_and_store(image): + result = classify_wound(image) + return result + + classification_image_input.change( + fn=classify_and_store, + inputs=classification_image_input, + outputs=classification_output + ) + + # Store image in shared state for next tabs + def store_shared_image(image): + return image + + classification_image_input.change( + fn=store_shared_image, + inputs=classification_image_input, + outputs=shared_image + ) + + # Run Gemini analysis only when Analyse button is clicked + def run_gemini_on_click(image, classification): + # Get top label + if isinstance(classification, dict) and classification: + top_label = max(classification.items(), key=lambda x: x[1])[0] + else: + top_label = "Unknown" + gemini_analysis = analyze_wound_with_gemini(image, top_label) + formatted_analysis = format_gemini_analysis(gemini_analysis) + return formatted_analysis + + analyse_btn.click( + fn=run_gemini_on_click, + inputs=[classification_image_input, classification_output], + outputs=gemini_output + ) + + # Tab 2: Depth Estimation + with gr.Tab("2. 
📏 Depth Estimation & 3D Visualization"): + gr.Markdown("### Step 2: Generate depth maps and 3D visualizations") + gr.Markdown("This module creates depth maps and 3D point clouds from your images.") + + with gr.Row(): + load_from_classification_btn = gr.Button("🔄 Load Image from Classification Tab", variant="secondary") + + with gr.Row(): + depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') + depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output') + + with gr.Row(): + depth_submit = gr.Button(value="Compute Depth", variant="primary") + + points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000, + label="Number of 3D points (upload image to update max)") + + with gr.Row(): + focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length X (pixels)") + focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length Y (pixels)") + + # Reorganized layout: 2 columns - 3D visualization on left, file outputs stacked on right + with gr.Row(): + with gr.Column(scale=2): + # 3D Visualization + gr.Markdown("### 3D Point Cloud Visualization") + gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.") + depth_3d_plot = gr.Plot(label="3D Point Cloud") + + with gr.Column(scale=1): + gr.Markdown("### Download Files") + gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download") + raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download") + point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download") + + + + # Tab 3: Wound Severity Analysis + with gr.Tab("3. 
🩹 Wound Severity Analysis"): + gr.Markdown("### Step 3: Analyze wound severity using depth maps") + gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.") + + with gr.Row(): + # Load depth map from previous tab + load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary") + + with gr.Row(): + severity_input_image = gr.Image(label="Original Image", type='numpy') + severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy') + + with gr.Row(): + wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy') + + with gr.Row(): + severity_output = gr.HTML( + label="🤖 AI-Powered Medical Assessment", + value=""" +
    +
    + 🩹 Wound Severity Analysis +
    +
    + ⏳ Waiting for Input... +
    +
    + Please upload an image and depth map, then click "🤖 Analyze Severity with Auto-Generated Mask" to begin AI-powered medical assessment. +
    +
    + """ + ) + + gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.") + + with gr.Row(): + auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg") + pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, + label="Pixel Spacing (mm/pixel)") + depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0, + label="Depth Calibration (mm)", + info="Adjust based on expected maximum wound depth") + + #gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.") + #gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. For shallow wounds use 5-10mm, for deep wounds use 15-30mm.") + + #gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.") + + # Update slider when image is uploaded + depth_input_image.change( + fn=update_slider_on_image_upload, + inputs=[depth_input_image], + outputs=[points_slider] + ) + + # Modified depth submit function to store depth map + def on_depth_submit_with_state(image, num_points, focal_x, focal_y): + results = on_depth_submit(image, num_points, focal_x, focal_y) + # Extract depth map from results for severity analysis + depth_map = None + if image is not None: + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + # Normalize depth for severity analysis + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth_map = norm_depth.astype(np.uint8) + return results + [depth_map] + + depth_submit.click(on_depth_submit_with_state, + inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], + outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, shared_depth_map]) + + # Function to load image from 
classification to depth tab + def load_image_from_classification(shared_img): + if shared_img is None: + return None, "❌ No image available from classification tab. Please upload an image in Tab 1 first." + + # Convert PIL image to numpy array for depth estimation + if hasattr(shared_img, 'convert'): + # It's a PIL image, convert to numpy + img_array = np.array(shared_img) + return img_array, "✅ Image loaded from classification tab successfully!" + else: + # Already numpy array + return shared_img, "✅ Image loaded from classification tab successfully!" + + # Connect the load button + load_from_classification_btn.click( + fn=load_image_from_classification, + inputs=shared_image, + outputs=[depth_input_image, gr.HTML()] + ) + + # Load depth map to severity tab and auto-generate mask + def load_depth_to_severity(depth_map, original_image): + if depth_map is None: + return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first." + + # Auto-generate wound mask using segmentation model + if original_image is not None: + auto_mask, _ = segmentation_model.segment_wound(original_image) + if auto_mask is not None: + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + if processed_mask is not None and np.sum(processed_mask > 0) > 0: + return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!" + else: + return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image." + else: + return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image." + else: + return depth_map, original_image, None, "✅ Depth map loaded successfully!" 
+ + load_depth_btn.click( + fn=load_depth_to_severity, + inputs=[shared_depth_map, depth_input_image], + outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()] + ) + + # Loading state function + def show_loading_state(): + return """ +
    +
    + 🩹 Wound Severity Analysis +
    +
    + 🔄 AI Analysis in Progress... +
    +
    + • Generating wound mask with deep learning model
    + • Computing depth measurements and statistics
    + • Analyzing wound characteristics with Gemini AI
    + • Preparing comprehensive medical assessment +
    +
    + +
    + """ + + # Automatic severity analysis function + def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration): + if depth_map is None: + return """ +
    +
    + ❌ Error +
    +
    + Please load depth map from Tab 1 first. +
    +
    + """ + + # Generate automatic wound mask using the actual model + auto_mask = create_automatic_wound_mask(image, method='deep_learning') + + if auto_mask is None: + return """ +
    +
    + ❌ Error +
    +
    + Failed to generate automatic wound mask. Please check if the segmentation model is loaded. +
    +
    + """ + + # Post-process the mask with fixed minimum area + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return """ +
    +
    + ⚠️ No Wound Detected +
    +
    + No wound region detected by the segmentation model. Try uploading a different image or use manual mask. +
    +
    + """ + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration) + + # Connect event handler with loading state + auto_severity_button.click( + fn=show_loading_state, + inputs=[], + outputs=[severity_output] + ).then( + fn=run_auto_severity_analysis, + inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider], + outputs=[severity_output] + ) + + + + # Auto-generate mask when image is uploaded + def auto_generate_mask_on_image_upload(image): + if image is None: + return None, "❌ No image uploaded." + + # Generate automatic wound mask using segmentation model + auto_mask, _ = segmentation_model.segment_wound(image) + if auto_mask is not None: + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + if processed_mask is not None and np.sum(processed_mask > 0) > 0: + return processed_mask, "✅ Wound mask auto-generated using deep learning model!" + else: + return None, "✅ Image uploaded but no wound detected. Try uploading a different image." + else: + return None, "✅ Image uploaded but segmentation failed. Try uploading a different image." 
+ + # Load shared image from classification tab + def load_shared_image(shared_img): + if shared_img is None: + return gr.Image(), "❌ No image available from classification tab" + + # Convert PIL image to numpy array for depth estimation + if hasattr(shared_img, 'convert'): + # It's a PIL image, convert to numpy + img_array = np.array(shared_img) + return img_array, "✅ Image loaded from classification tab" + else: + # Already numpy array + return shared_img, "✅ Image loaded from classification tab" + + # Auto-generate mask when image is uploaded to severity tab + severity_input_image.change( + fn=auto_generate_mask_on_image_upload, + inputs=[severity_input_image], + outputs=[wound_mask_input, gr.HTML()] + ) + + + +if __name__ == '__main__': + demo.queue().launch( + server_name="0.0.0.0", + server_port=7860, + share=True + ) \ No newline at end of file diff --git a/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc b/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..320c66e3854de2ab05a7ec6787958fc9a7bdef84 Binary files /dev/null and b/depth_anything_v2/__pycache__/dinov2.cpython-310.pyc differ diff --git a/depth_anything_v2/__pycache__/dpt.cpython-310.pyc b/depth_anything_v2/__pycache__/dpt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca627d373cf39e8a744e3c402fa0498348cf9d49 Binary files /dev/null and b/depth_anything_v2/__pycache__/dpt.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2.py b/depth_anything_v2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbfc7d24d37796d5310fd966b582bb3773685dc --- /dev/null +++ b/depth_anything_v2/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim 
to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic 
depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating 
point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + 
return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + 
**kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/depth_anything_v2/dinov2_layers/__init__.py b/depth_anything_v2/dinov2_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e59a83eb90512d763b03e4d38536b6ae07e87541 --- /dev/null +++ b/depth_anything_v2/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0efe36e1b6abcdeff4f7611f0cb096ba8606f6f4 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d1812bf2e59d6d5796ee57b16bea7364413731 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..322578e8683fdb1a35bfc8e23af92a87e4546b02 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3bd2ebcf2ead79e30a75a3ecd9b974d92310083 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52af677deee30b6e536f6369647598cc1c089ec9 Binary files /dev/null and 
b/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ea50613125836d15598205f316241a06e12e727 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b92b1b42a848aeb9b26363664e670e99c1593c8 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc b/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afbad238bb2fadccf537cb42e2b77cab95f01482 Binary files /dev/null and b/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/depth_anything_v2/dinov2_layers/attention.py b/depth_anything_v2/dinov2_layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..dea0c82d55f052bf4bcb5896ad8c37158ef523d5 --- /dev/null +++ b/depth_anything_v2/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git 
a/depth_anything_v2/dinov2_layers/block.py b/depth_anything_v2/dinov2_layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..f91f3f07bd15fba91c67068c8dce2bb22d505bf7 --- /dev/null +++ b/depth_anything_v2/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, 
init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the 
residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: 
Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 
+ + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/depth_anything_v2/dinov2_layers/drop_path.py b/depth_anything_v2/dinov2_layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..10c3bea8e40eec258bbe59087770d230a6375481 --- /dev/null +++ b/depth_anything_v2/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/depth_anything_v2/dinov2_layers/layer_scale.py b/depth_anything_v2/dinov2_layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..76a4d0eedb1dc974a45e06fbe77ff3d909e36e55 --- /dev/null +++ b/depth_anything_v2/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/depth_anything_v2/dinov2_layers/mlp.py b/depth_anything_v2/dinov2_layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..504987b635c9cd582a352fb2381228c9e6cd043c --- /dev/null +++ b/depth_anything_v2/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/depth_anything_v2/dinov2_layers/patch_embed.py b/depth_anything_v2/dinov2_layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..f880c042ee6a33ef520c6a8c8a686c1d065b8f49 --- /dev/null +++ b/depth_anything_v2/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = 
self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/depth_anything_v2/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..155a3dd9f6f1a7d0f7bdf9c8f1981e58acb3b19c --- /dev/null +++ b/depth_anything_v2/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = 
True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/depth_anything_v2/dpt.py b/depth_anything_v2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..acef20bfcf80318709dcf6c5e8c19b117394a06b --- /dev/null +++ b/depth_anything_v2/dpt.py @@ -0,0 +1,221 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from .dinov2 import DINOv2 +from .util.blocks import FeatureFusionBlock, _make_scratch +from .util.transform import Resize, NormalizeImage, PrepareForNet + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + 
out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + 
class DepthAnythingV2(nn.Module):
    """Depth Anything V2: a DINOv2 ViT backbone feeding a DPT decoder head.

    Produces a non-negative relative depth map for an input image batch.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=[256, 512, 1024, 1024],
        use_bn=False,
        use_clstoken=False
    ):
        super(DepthAnythingV2, self).__init__()

        # Which transformer blocks to tap for multi-scale features,
        # keyed by backbone size.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitb': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23],
            'vitg': [9, 19, 29, 39]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        self.depth_head = DPTHead(
            self.pretrained.embed_dim,
            features,
            use_bn,
            out_channels=out_channels,
            use_clstoken=use_clstoken,
        )

    def forward(self, x):
        """Run backbone + head; returns depth of shape (B, H', W'), clamped >= 0."""
        # ViT patch size is 14, so the token grid is (H // 14, W // 14).
        grid_h = x.shape[-2] // 14
        grid_w = x.shape[-1] // 14

        taps = self.pretrained.get_intermediate_layers(
            x, self.intermediate_layer_idx[self.encoder], return_class_token=True)

        pred = self.depth_head(taps, grid_h, grid_w)
        return F.relu(pred).squeeze(1)

    @torch.no_grad()
    def infer_image(self, raw_image, input_size=518):
        """Predict depth for one BGR uint8 image; returns an (h, w) numpy array."""
        tensor, (h, w) = self.image2tensor(raw_image, input_size)

        pred = self.forward(tensor)

        # Upsample the prediction back to the original image resolution.
        pred = F.interpolate(pred[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]

        return pred.cpu().numpy()

    def image2tensor(self, raw_image, input_size=518):
        """Preprocess a BGR uint8 image into a normalized (1, 3, H', W') tensor.

        Returns the tensor (moved to the best available device) together with
        the original (h, w) so the caller can resize the output back.
        """
        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        h, w = raw_image.shape[:2]

        rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
        rgb = transform({'image': rgb})['image']
        tensor = torch.from_numpy(rgb).unsqueeze(0)

        # NOTE(review): device is chosen per call from the environment; the
        # model itself is assumed to already live on the same device.
        device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        return tensor.to(device), (h, w)
class ResidualConvUnit(nn.Module):
    """Pre-activation residual block: two 3x3 convs (optionally BatchNorm'd),
    added back onto the input through a quantization-friendly FloatFunctional.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of input/output channels
            activation: activation module applied before each conv
            bn (bool): insert BatchNorm2d after each conv when True
        """
        super().__init__()

        self.bn = bn
        self.groups = 1  # fixed at 1 here; kept for parity with upstream DPT code

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # FloatFunctional so the residual add survives quantization.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply act -> conv (-> bn) twice, then add the skip connection.

        Args:
            x (tensor): input of shape (B, features, H, W)

        Returns:
            tensor: output of the same shape
        """
        out = self.conv1(self.activation(x))
        if self.bn:
            out = self.bn1(out)

        out = self.conv2(self.activation(out))
        if self.bn:
            out = self.bn2(out)

        # Dead branch while groups == 1; preserved from the upstream source.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
class Resize(object):
    """Resize sample to given size (width, height).

    Resizes sample["image"] (and, when enabled, the "depth"/"mask" targets)
    to the requested size, honouring aspect-ratio and multiple-of constraints.
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional): also resize depth/mask targets.
                Defaults to True.
            keep_aspect_ratio (bool, optional): preserve the input aspect
                ratio; the exact output size then depends on `resize_method`.
                Defaults to False.
            ensure_multiple_of (int, optional): constrain output width/height
                to multiples of this value. Defaults to 1.
            resize_method (str, optional):
                "lower_bound": output at least as large as the given size.
                "upper_bound": output at most as large as the given size.
                "minimal": scale as little as possible.
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Snap x to the nearest multiple of the configured step within bounds."""
        step = self.__multiple_of
        y = (np.round(x / step) * step).astype(int)

        # Rounding up may overshoot the allowed maximum: fall back to floor.
        if max_val is not None and y > max_val:
            y = (np.floor(x / step) * step).astype(int)

        # Flooring may undershoot the minimum: fall back to ceil.
        if y < min_val:
            y = (np.ceil(x / step) * step).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output (width, height) for an input of the given size."""
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            # Collapse to one scale factor, chosen per resize_method.
            if self.__resize_method == "lower_bound":
                # Output must cover the requested size -> use the larger scale.
                if scale_width > scale_height:
                    chosen = scale_width
                else:
                    chosen = scale_height
            elif self.__resize_method == "upper_bound":
                # Output must fit inside the requested size -> smaller scale.
                if scale_width < scale_height:
                    chosen = scale_width
                else:
                    chosen = scale_height
            elif self.__resize_method == "minimal":
                # Deviate from 1.0 as little as possible.
                if abs(1 - scale_width) < abs(1 - scale_height):
                    chosen = scale_width
                else:
                    chosen = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")
            scale_width = scale_height = chosen

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize sample["image"] (and targets when enabled); returns sample."""
        new_w, new_h = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        sample["image"] = cv2.resize(
            sample["image"], (new_w, new_h),
            interpolation=self.__image_interpolation_method)

        if self.__resize_target:
            # Nearest-neighbour keeps target values (depths/labels) intact.
            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (new_w, new_h), interpolation=cv2.INTER_NEAREST)

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32), (new_w, new_h),
                    interpolation=cv2.INTER_NEAREST)

        return sample


class NormalizeImage(object):
    """Normalize image by given per-channel mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std
        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input (HWC -> CHW, float32)."""

    def __init__(self):
        pass

    def __call__(self, sample):
        # HWC -> CHW, contiguous float32 for the network.
        chw = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)

        if "depth" in sample:
            sample["depth"] = np.ascontiguousarray(sample["depth"].astype(np.float32))

        if "mask" in sample:
            sample["mask"] = np.ascontiguousarray(sample["mask"].astype(np.float32))

        return sample
class SegNet:
    """Simple SegNet-style encoder/decoder for binary segmentation.

    Four Conv+MaxPool encoder stages, a bottleneck conv, then four
    Conv+UpSampling decoder stages ending in a 1-channel sigmoid map.
    """

    def __init__(self, n_filters, input_dim_x, input_dim_y, num_channels):
        # Input spatial size and channel count, plus the base filter width.
        self.input_dim_x = input_dim_x
        self.input_dim_y = input_dim_y
        self.n_filters = n_filters
        self.num_channels = num_channels

    def get_SegNet(self):
        """Build the network; returns (keras Model, model name)."""
        f = self.n_filters
        inputs = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels))

        # Encoder: conv then 2x downsampling at each stage.
        enc1 = Conv2D(f, kernel_size=9, activation='relu', padding='same')(inputs)
        down1 = MaxPooling2D(pool_size=(2, 2))(enc1)
        enc2 = Conv2D(f, kernel_size=5, activation='relu', padding='same')(down1)
        down2 = MaxPooling2D(pool_size=(2, 2))(enc2)
        enc3 = Conv2D(f * 2, kernel_size=5, activation='relu', padding='same')(down2)
        down3 = MaxPooling2D(pool_size=(2, 2))(enc3)
        enc4 = Conv2D(f * 2, kernel_size=5, activation='relu', padding='same')(down3)
        down4 = MaxPooling2D(pool_size=(2, 2))(enc4)

        # Bottleneck.
        mid = Conv2D(f, kernel_size=5, activation='relu', padding='same')(down4)

        # Decoder: 2x upsampling then conv, mirroring the encoder depth.
        dec1 = Conv2D(f, kernel_size=7, activation='relu', padding='same')(
            UpSampling2D(size=(2, 2))(mid))
        dec2 = Conv2D(f, kernel_size=5, activation='relu', padding='same')(
            UpSampling2D(size=(2, 2))(dec1))
        dec3 = Conv2D(f, kernel_size=5, activation='relu', padding='same')(
            UpSampling2D(size=(2, 2))(dec2))
        # Final 1x1 conv + sigmoid yields the per-pixel foreground probability.
        out = Conv2D(1, kernel_size=1, activation='sigmoid', padding='same')(
            UpSampling2D(size=(2, 2))(dec3))

        return Model(outputs=out, inputs=inputs), 'SegNet'
0000000000000000000000000000000000000000..8efdb47a204f521261dadb9743dcca856b522a9f Binary files /dev/null and b/models/__pycache__/FCN.cpython-37.pyc differ diff --git a/models/__pycache__/FCN.cpython-39.pyc b/models/__pycache__/FCN.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..735f22cedf94b409b7347b6f9fc09ad89a5d7348 Binary files /dev/null and b/models/__pycache__/FCN.cpython-39.pyc differ diff --git a/models/__pycache__/SegNet.cpython-37.pyc b/models/__pycache__/SegNet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ebd187de4eb22db3c0d18a9dc966ec1d0880a94 Binary files /dev/null and b/models/__pycache__/SegNet.cpython-37.pyc differ diff --git a/models/__pycache__/SegNet.cpython-39.pyc b/models/__pycache__/SegNet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7816b13e59765cce0b3cc6c4c53720c932ba5ce Binary files /dev/null and b/models/__pycache__/SegNet.cpython-39.pyc differ diff --git a/models/__pycache__/deeplab.cpython-310.pyc b/models/__pycache__/deeplab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad99672800379db699d7dec0668711e21fa14e29 Binary files /dev/null and b/models/__pycache__/deeplab.cpython-310.pyc differ diff --git a/models/__pycache__/deeplab.cpython-313.pyc b/models/__pycache__/deeplab.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a71ee9a65fd1f578047531d9aea9a4e03c8269a3 Binary files /dev/null and b/models/__pycache__/deeplab.cpython-313.pyc differ diff --git a/models/__pycache__/deeplab.cpython-37.pyc b/models/__pycache__/deeplab.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..889940f7b30762d822089c0c8023f0b636235d2e Binary files /dev/null and b/models/__pycache__/deeplab.cpython-37.pyc differ diff --git a/models/__pycache__/deeplab.cpython-39.pyc b/models/__pycache__/deeplab.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..84d996fdf13f9c184ec09947f490e02d3f9315f8 Binary files /dev/null and b/models/__pycache__/deeplab.cpython-39.pyc differ diff --git a/models/__pycache__/unets.cpython-37.pyc b/models/__pycache__/unets.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6876f7142f039468b23019771dbcfd9a48874762 Binary files /dev/null and b/models/__pycache__/unets.cpython-37.pyc differ diff --git a/models/__pycache__/unets.cpython-39.pyc b/models/__pycache__/unets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e2d40a7e10376169b03e12094acd95c25a5a700 Binary files /dev/null and b/models/__pycache__/unets.cpython-39.pyc differ diff --git a/models/deeplab.py b/models/deeplab.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc53345d50c8fedfd693b326258dba2a7e7ebbd --- /dev/null +++ b/models/deeplab.py @@ -0,0 +1,539 @@ +# -*- coding: utf-8 -*- + +""" Deeplabv3+ model for Keras. +This model is based on this repo: +https://github.com/bonlime/keras-deeplab-v3-plus + +MobileNetv2 backbone is based on this repo: +https://github.com/JonathanCMitchell/mobilenet_v2_keras + +# Reference +- [Encoder-Decoder with Atrous Separable Convolution + for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf) +- [Xception: Deep Learning with Depthwise Separable Convolutions] + (https://arxiv.org/abs/1610.02357) +- [Inverted Residuals and Linear Bottlenecks: Mobile Networks for + Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from keras.models import Model +from keras import layers +from keras.layers import Input +from keras.layers import Activation +from keras.layers import Concatenate +from keras.layers import Add +from keras.layers import Dropout +from keras.layers 
import BatchNormalization +from keras.layers import Conv2D +from keras.layers import DepthwiseConv2D +from keras.layers import ZeroPadding2D +from keras.layers import AveragePooling2D +from keras.layers import Layer +from tensorflow.keras.layers import InputSpec +from tensorflow.keras.utils import get_source_inputs +from keras import backend as K +from keras.applications import imagenet_utils +from keras.utils import conv_utils +from keras.utils.data_utils import get_file + +WEIGHTS_PATH_X = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_xception_tf_dim_ordering_tf_kernels.h5" +WEIGHTS_PATH_MOBILE = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5" +WEIGHTS_PATH_X_CS = "https://github.com/rdiazgar/keras-deeplab-v3-plus/releases/download/1.2/deeplabv3_xception_tf_dim_ordering_tf_kernels_cityscapes.h5" +WEIGHTS_PATH_MOBILE_CS = "https://github.com/rdiazgar/keras-deeplab-v3-plus/releases/download/1.2/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels_cityscapes.h5" + +class BilinearUpsampling(Layer): + """Just a simple bilinear upsampling layer. Works only with TF. + Args: + upsampling: tuple of 2 numbers > 0. The upsampling ratio for h and w + output_size: used instead of upsampling arg if passed! 
+ """ + + def __init__(self, upsampling=(2, 2), output_size=None, data_format=None, **kwargs): + + super(BilinearUpsampling, self).__init__(**kwargs) + + self.data_format = K.image_data_format() + self.input_spec = InputSpec(ndim=4) + if output_size: + self.output_size = conv_utils.normalize_tuple( + output_size, 2, 'output_size') + self.upsampling = None + else: + self.output_size = None + self.upsampling = conv_utils.normalize_tuple( + upsampling, 2, 'upsampling') + + def compute_output_shape(self, input_shape): + if self.upsampling: + height = self.upsampling[0] * \ + input_shape[1] if input_shape[1] is not None else None + width = self.upsampling[1] * \ + input_shape[2] if input_shape[2] is not None else None + else: + height = self.output_size[0] + width = self.output_size[1] + return (input_shape[0], + height, + width, + input_shape[3]) + + def call(self, inputs): + if self.upsampling: + return tf.compat.v1.image.resize_bilinear(inputs, (inputs.shape[1] * self.upsampling[0], + inputs.shape[2] * self.upsampling[1]), + align_corners=True) + else: + return tf.compat.v1.image.resize_bilinear(inputs, (self.output_size[0], + self.output_size[1]), + align_corners=True) + + def get_config(self): + config = {'upsampling': self.upsampling, + 'output_size': self.output_size, + 'data_format': self.data_format} + base_config = super(BilinearUpsampling, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +def SepConv_BN(x, filters, prefix, stride=1, kernel_size=3, rate=1, depth_activation=False, epsilon=1e-3): + """ SepConv with BN between depthwise & pointwise. 
Optionally add activation after BN + Implements right "same" padding for even kernel sizes + Args: + x: input tensor + filters: num of filters in pointwise convolution + prefix: prefix before name + stride: stride at depthwise conv + kernel_size: kernel size for depthwise convolution + rate: atrous rate for depthwise convolution + depth_activation: flag to use activation between depthwise & poinwise convs + epsilon: epsilon to use in BN layer + """ + + if stride == 1: + depth_padding = 'same' + else: + kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + x = ZeroPadding2D((pad_beg, pad_end))(x) + depth_padding = 'valid' + + if not depth_activation: + x = Activation('relu')(x) + x = DepthwiseConv2D((kernel_size, kernel_size), strides=(stride, stride), dilation_rate=(rate, rate), + padding=depth_padding, use_bias=False, name=prefix + '_depthwise')(x) + x = BatchNormalization(name=prefix + '_depthwise_BN', epsilon=epsilon)(x) + if depth_activation: + x = Activation('relu')(x) + x = Conv2D(filters, (1, 1), padding='same', + use_bias=False, name=prefix + '_pointwise')(x) + x = BatchNormalization(name=prefix + '_pointwise_BN', epsilon=epsilon)(x) + if depth_activation: + x = Activation('relu')(x) + + return x + + +def _conv2d_same(x, filters, prefix, stride=1, kernel_size=3, rate=1): + """Implements right 'same' padding for even kernel sizes + Without this there is a 1 pixel drift when stride = 2 + Args: + x: input tensor + filters: num of filters in pointwise convolution + prefix: prefix before name + stride: stride at depthwise conv + kernel_size: kernel size for depthwise convolution + rate: atrous rate for depthwise convolution + """ + if stride == 1: + return Conv2D(filters, + (kernel_size, kernel_size), + strides=(stride, stride), + padding='same', use_bias=False, + dilation_rate=(rate, rate), + name=prefix)(x) + else: + kernel_size_effective = 
kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + x = ZeroPadding2D((pad_beg, pad_end))(x) + return Conv2D(filters, + (kernel_size, kernel_size), + strides=(stride, stride), + padding='valid', use_bias=False, + dilation_rate=(rate, rate), + name=prefix)(x) + + +def _xception_block(inputs, depth_list, prefix, skip_connection_type, stride, + rate=1, depth_activation=False, return_skip=False): + """ Basic building block of modified Xception network + Args: + inputs: input tensor + depth_list: number of filters in each SepConv layer. len(depth_list) == 3 + prefix: prefix before name + skip_connection_type: one of {'conv','sum','none'} + stride: stride at last depthwise conv + rate: atrous rate for depthwise convolution + depth_activation: flag to use activation between depthwise & pointwise convs + return_skip: flag to return additional tensor after 2 SepConvs for decoder + """ + residual = inputs + for i in range(3): + residual = SepConv_BN(residual, + depth_list[i], + prefix + '_separable_conv{}'.format(i + 1), + stride=stride if i == 2 else 1, + rate=rate, + depth_activation=depth_activation) + if i == 1: + skip = residual + if skip_connection_type == 'conv': + shortcut = _conv2d_same(inputs, depth_list[-1], prefix + '_shortcut', + kernel_size=1, + stride=stride) + shortcut = BatchNormalization(name=prefix + '_shortcut_BN')(shortcut) + outputs = layers.add([residual, shortcut]) + elif skip_connection_type == 'sum': + outputs = layers.add([residual, inputs]) + elif skip_connection_type == 'none': + outputs = residual + if return_skip: + return outputs, skip + else: + return outputs + + +def relu6(x): + return K.relu(x, max_value=6) + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, skip_connection, rate=1): + in_channels = inputs.shape[-1] + pointwise_conv_filters = int(filters * alpha) + pointwise_filters = _make_divisible(pointwise_conv_filters, 8) + x = inputs + prefix = 'expanded_conv_{}_'.format(block_id) + if block_id: + # Expand + + x = Conv2D(expansion * in_channels, kernel_size=1, padding='same', + use_bias=False, activation=None, + name=prefix + 'expand')(x) + x = BatchNormalization(epsilon=1e-3, momentum=0.999, + name=prefix + 'expand_BN')(x) + x = Activation(relu6, name=prefix + 'expand_relu')(x) + else: + prefix = 'expanded_conv_' + # Depthwise + x = DepthwiseConv2D(kernel_size=3, strides=stride, activation=None, + use_bias=False, padding='same', dilation_rate=(rate, rate), + name=prefix + 'depthwise')(x) + x = BatchNormalization(epsilon=1e-3, momentum=0.999, + name=prefix + 'depthwise_BN')(x) + + x = Activation(relu6, name=prefix + 'depthwise_relu')(x) + + # Project + x = Conv2D(pointwise_filters, + kernel_size=1, padding='same', use_bias=False, activation=None, + name=prefix + 'project')(x) + x = BatchNormalization(epsilon=1e-3, momentum=0.999, + name=prefix + 'project_BN')(x) + + if skip_connection: + return Add(name=prefix + 'add')([inputs, x]) + + # if in_channels == pointwise_filters and stride == 1: + # return Add(name='res_connect_' + str(block_id))([inputs, x]) + + return x + + +def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3), classes=21, backbone='mobilenetv2' + , OS=16, alpha=1.): + """ Instantiates the Deeplabv3+ architecture + + Optionally loads weights pre-trained + on PASCAL VOC. This model is available for TensorFlow only, + and can only be used with inputs following the TensorFlow + data format `(width, height, channels)`. 
+ # Arguments + weights: one of 'pascal_voc' (pre-trained on pascal voc) + or None (random initialization) + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: shape of input image. format HxWxC + PASCAL VOC model was trained on (512,512,3) images + classes: number of desired classes. If classes != 21, + last layer is initialized randomly + backbone: backbone to use. one of {'xception','mobilenetv2'} + OS: determines input_shape/feature_extractor_output ratio. One of {8,16}. + Used only for xception backbone. + alpha: controls the width of the MobileNetV2 network. This is known as the + width multiplier in the MobileNetV2 paper. + - If `alpha` < 1.0, proportionally decreases the number + of filters in each layer. + - If `alpha` > 1.0, proportionally increases the number + of filters in each layer. + - If `alpha` = 1, default number of filters from the paper + are used at each layer. + Used only for mobilenetv2 backbone + + # Returns + A Keras model instance. + + # Raises + RuntimeError: If attempting to run this model with a + backend that does not support separable convolutions. 
+ ValueError: in case of invalid argument for `weights` or `backbone` + + """ + + if not (weights in {'pascal_voc', 'cityscapes', None}): + raise ValueError('The `weights` argument should be either ' + '`None` (random initialization), `pascal_voc`, or `cityscapes` ' + '(pre-trained on PASCAL VOC)') + + if K.backend() != 'tensorflow': + raise RuntimeError('The Deeplabv3+ model is only available with ' + 'the TensorFlow backend.') + + if not (backbone in {'xception', 'mobilenetv2'}): + raise ValueError('The `backbone` argument should be either ' + '`xception` or `mobilenetv2` ') + + if input_tensor is None: + img_input = Input(shape=input_shape) + else: + if not K.is_keras_tensor(input_tensor): + # Input layer + img_input = Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + if backbone == 'xception': + if OS == 8: + entry_block3_stride = 1 + middle_block_rate = 2 # ! Not mentioned in paper, but required + exit_block_rates = (2, 4) + atrous_rates = (12, 24, 36) + else: + entry_block3_stride = 2 + middle_block_rate = 1 + exit_block_rates = (1, 2) + atrous_rates = (6, 12, 18) + + x = Conv2D(32, (3, 3), strides=(2, 2), + name='entry_flow_conv1_1', use_bias=False, padding='same')(img_input) + x = BatchNormalization(name='entry_flow_conv1_1_BN')(x) + x = Activation('relu')(x) + + x = _conv2d_same(x, 64, 'entry_flow_conv1_2', kernel_size=3, stride=1) + x = BatchNormalization(name='entry_flow_conv1_2_BN')(x) + x = Activation('relu')(x) + + x = _xception_block(x, [128, 128, 128], 'entry_flow_block1', + skip_connection_type='conv', stride=2, + depth_activation=False) + x, skip1 = _xception_block(x, [256, 256, 256], 'entry_flow_block2', + skip_connection_type='conv', stride=2, + depth_activation=False, return_skip=True) + + x = _xception_block(x, [728, 728, 728], 'entry_flow_block3', + skip_connection_type='conv', stride=entry_block3_stride, + depth_activation=False) + for i in range(16): + x = _xception_block(x, [728, 728, 728], 
'middle_flow_unit_{}'.format(i + 1), + skip_connection_type='sum', stride=1, rate=middle_block_rate, + depth_activation=False) + + x = _xception_block(x, [728, 1024, 1024], 'exit_flow_block1', + skip_connection_type='conv', stride=1, rate=exit_block_rates[0], + depth_activation=False) + x = _xception_block(x, [1536, 1536, 2048], 'exit_flow_block2', + skip_connection_type='none', stride=1, rate=exit_block_rates[1], + depth_activation=True) + + else: + OS = 8 + first_block_filters = _make_divisible(32 * alpha, 8) + x = Conv2D(first_block_filters, + kernel_size=3, + strides=(2, 2), padding='same', + use_bias=False, name='Conv')(img_input) + x = BatchNormalization( + epsilon=1e-3, momentum=0.999, name='Conv_BN')(x) + x = Activation(relu6, name='Conv_Relu6')(x) + + x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1, + expansion=1, block_id=0, skip_connection=False) + + x = _inverted_res_block(x, filters=24, alpha=alpha, stride=2, + expansion=6, block_id=1, skip_connection=False) + x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1, + expansion=6, block_id=2, skip_connection=True) + + x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2, + expansion=6, block_id=3, skip_connection=False) + x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, + expansion=6, block_id=4, skip_connection=True) + x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, + expansion=6, block_id=5, skip_connection=True) + + # stride in block 6 changed from 2 -> 1, so we need to use rate = 2 + x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, # 1! 
+ expansion=6, block_id=6, skip_connection=False) + x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2, + expansion=6, block_id=7, skip_connection=True) + x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2, + expansion=6, block_id=8, skip_connection=True) + x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, rate=2, + expansion=6, block_id=9, skip_connection=True) + + x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2, + expansion=6, block_id=10, skip_connection=False) + x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2, + expansion=6, block_id=11, skip_connection=True) + x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, rate=2, + expansion=6, block_id=12, skip_connection=True) + + x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=2, # 1! + expansion=6, block_id=13, skip_connection=False) + x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=4, + expansion=6, block_id=14, skip_connection=True) + x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, rate=4, + expansion=6, block_id=15, skip_connection=True) + + x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, rate=4, + expansion=6, block_id=16, skip_connection=False) + + # end of feature extractor + + # branching for Atrous Spatial Pyramid Pooling + + # Image Feature branch + #out_shape = int(np.ceil(input_shape[0] / OS)) + b4 = AveragePooling2D(pool_size=(int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS))))(x) + b4 = Conv2D(256, (1, 1), padding='same', + use_bias=False, name='image_pooling')(b4) + b4 = BatchNormalization(name='image_pooling_BN', epsilon=1e-5)(b4) + b4 = Activation('relu')(b4) + b4 = BilinearUpsampling((int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS))))(b4) + + # simple 1x1 + b0 = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp0')(x) + b0 = BatchNormalization(name='aspp0_BN', epsilon=1e-5)(b0) + b0 
= Activation('relu', name='aspp0_activation')(b0) + + # there are only 2 branches in mobilenetV2. not sure why + if backbone == 'xception': + # rate = 6 (12) + b1 = SepConv_BN(x, 256, 'aspp1', + rate=atrous_rates[0], depth_activation=True, epsilon=1e-5) + # rate = 12 (24) + b2 = SepConv_BN(x, 256, 'aspp2', + rate=atrous_rates[1], depth_activation=True, epsilon=1e-5) + # rate = 18 (36) + b3 = SepConv_BN(x, 256, 'aspp3', + rate=atrous_rates[2], depth_activation=True, epsilon=1e-5) + + # concatenate ASPP branches & project + x = Concatenate()([b4, b0, b1, b2, b3]) + else: + x = Concatenate()([b4, b0]) + + x = Conv2D(256, (1, 1), padding='same', + use_bias=False, name='concat_projection')(x) + x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x) + x = Activation('relu')(x) + x = Dropout(0.1)(x) + + # DeepLab v.3+ decoder + + if backbone == 'xception': + # Feature projection + # x4 (x2) block + x = BilinearUpsampling(output_size=(int(np.ceil(input_shape[0] / 4)), + int(np.ceil(input_shape[1] / 4))))(x) + dec_skip1 = Conv2D(48, (1, 1), padding='same', + use_bias=False, name='feature_projection0')(skip1) + dec_skip1 = BatchNormalization( + name='feature_projection0_BN', epsilon=1e-5)(dec_skip1) + dec_skip1 = Activation('relu')(dec_skip1) + x = Concatenate()([x, dec_skip1]) + x = SepConv_BN(x, 256, 'decoder_conv0', + depth_activation=True, epsilon=1e-5) + x = SepConv_BN(x, 256, 'decoder_conv1', + depth_activation=True, epsilon=1e-5) + + # you can use it with arbitary number of classes + if classes == 21: + last_layer_name = 'logits_semantic' + else: + last_layer_name = 'custom_logits_semantic' + + x = Conv2D(classes, (1, 1), padding='same', name=last_layer_name)(x) + x = BilinearUpsampling(output_size=(input_shape[0], input_shape[1]))(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. 
+ if input_tensor is not None: + inputs = get_source_inputs(input_tensor) + else: + inputs = img_input + + model = Model(inputs, x, name='deeplabv3plus') + + # load weights + + if weights == 'pascal_voc': + if backbone == 'xception': + weights_path = get_file('deeplabv3_xception_tf_dim_ordering_tf_kernels.h5', + WEIGHTS_PATH_X, + cache_subdir='models') + else: + weights_path = get_file('deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5', + WEIGHTS_PATH_MOBILE, + cache_subdir='models') + model.load_weights(weights_path, by_name=True) + elif weights == 'cityscapes': + if backbone == 'xception': + weights_path = get_file('deeplabv3_xception_tf_dim_ordering_tf_kernels_cityscapes.h5', + WEIGHTS_PATH_X_CS, + cache_subdir='models') + else: + weights_path = get_file('deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels_cityscapes.h5', + WEIGHTS_PATH_MOBILE_CS, + cache_subdir='models') + model.load_weights(weights_path, by_name=True) + return model + + +def preprocess_input(x): + """Preprocesses a numpy array encoding a batch of images. + # Arguments + x: a 4D numpy array consists of RGB values within [0, 255]. + # Returns + Input array scaled to [-1.,1.] 
+ """ + return imagenet_utils.preprocess_input(x, mode='tf') diff --git a/models/unets.py b/models/unets.py new file mode 100644 index 0000000000000000000000000000000000000000..2879efa5bfea642bfa6bfa30ceaecacdb29f69d3 --- /dev/null +++ b/models/unets.py @@ -0,0 +1,171 @@ +from keras.models import Model +from keras.layers import Input +from keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Dropout, Concatenate, UpSampling2D + + +class Unet2D: + + def __init__(self, n_filters, input_dim_x, input_dim_y, num_channels): + self.input_dim_x = input_dim_x + self.input_dim_y = input_dim_y + self.n_filters = n_filters + self.num_channels = num_channels + + def get_unet_model_5_levels(self): + unet_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels)) + + conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(unet_input) + conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(conv1) + conv1 = BatchNormalization()(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + + conv2 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(pool1) + conv2 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv2) + conv2 = BatchNormalization()(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + + conv3 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(pool2) + conv3 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv3) + conv3 = BatchNormalization()(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) + + conv4 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(pool3) + conv4 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv4) + conv4 = BatchNormalization()(conv4) + drop4 = Dropout(0.5)(conv4) + pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) + + conv5 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(pool4) 
+ conv5 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(conv5) + conv5 = BatchNormalization()(conv5) + drop5 = Dropout(0.5)(conv5) + + up6 = Conv2D(self.n_filters*16, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop5)) + concat6 = Concatenate()([drop4, up6]) + conv6 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(concat6) + conv6 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv6) + conv6 = BatchNormalization()(conv6) + + up7 = Conv2D(self.n_filters*8, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6)) + concat7 = Concatenate()([conv3, up7]) + conv7 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(concat7) + conv7 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv7) + conv7 = BatchNormalization()(conv7) + + up8 = Conv2D(self.n_filters*4, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv7)) + concat8 = Concatenate()([conv2, up8]) + conv8 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(concat8) + conv8 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv8) + conv8 = BatchNormalization()(conv8) + + up9 = Conv2D(self.n_filters*2, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv8)) + concat9 = Concatenate()([conv1, up9]) + conv9 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(concat9) + conv9 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(conv9) + conv9 = BatchNormalization()(conv9) + + conv10 = Conv2D(3, kernel_size=1, activation='sigmoid', padding='same')(conv9) + + return Model(outputs=conv10, inputs=unet_input), 'unet_model_5_levels' + + + def get_unet_model_4_levels(self): + unet_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels)) + + conv1 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', 
padding='same')(unet_input) + conv1 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv1) + conv1 = BatchNormalization()(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + + conv2 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(pool1) + conv2 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv2) + conv2 = BatchNormalization()(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + + conv3 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(pool2) + conv3 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv3) + conv3 = BatchNormalization()(conv3) + drop3 = Dropout(0.5)(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(drop3) + + conv4 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(pool3) + conv4 = Conv2D(self.n_filters*16, kernel_size=3, activation='relu', padding='same')(conv4) + conv4 = BatchNormalization()(conv4) + drop4 = Dropout(0.5)(conv4) + + up5 = Conv2D(self.n_filters*16, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(drop4)) + concat5 = Concatenate()([drop3, up5]) + conv5 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(concat5) + conv5 = Conv2D(self.n_filters*8, kernel_size=3, activation='relu', padding='same')(conv5) + conv5 = BatchNormalization()(conv5) + + up6 = Conv2D(self.n_filters*8, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv5)) + concat6 = Concatenate()([conv2, up6]) + conv6 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(concat6) + conv6 = Conv2D(self.n_filters*4, kernel_size=3, activation='relu', padding='same')(conv6) + conv6 = BatchNormalization()(conv6) + + up7 = Conv2D(self.n_filters*4, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6)) + concat7 = Concatenate()([conv1, up7]) + conv7 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', 
padding='same')(concat7) + conv7 = Conv2D(self.n_filters*2, kernel_size=3, activation='relu', padding='same')(conv7) + conv7 = BatchNormalization()(conv7) + + conv9 = Conv2D(3, kernel_size=1, activation='sigmoid', padding='same')(conv7) + + return Model(outputs=conv9, inputs=unet_input), 'unet_model_4_levels' + + + def get_unet_model_yuanqing(self): + # Model inspired by https://github.com/yuanqing811/ISIC2018 + unet_input = Input(shape=(self.input_dim_x, self.input_dim_y, self.num_channels)) + + conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(unet_input) + conv1 = Conv2D(self.n_filters, kernel_size=3, activation='relu', padding='same')(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + + conv2 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(pool1) + conv2 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + + conv3 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(pool2) + conv3 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv3) + conv3 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) + + conv4 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(pool3) + conv4 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv4) + conv4 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv4) + pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) + + conv5 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(pool4) + conv5 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv5) + conv5 = Conv2D(self.n_filters * 8, kernel_size=3, activation='relu', padding='same')(conv5) + + up6 = Conv2D(self.n_filters * 4, 2, activation='relu', 
padding='same')(UpSampling2D(size=(2, 2))(conv5)) + feature4 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv4) + concat6 = Concatenate()([feature4, up6]) + conv6 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(concat6) + conv6 = Conv2D(self.n_filters * 4, kernel_size=3, activation='relu', padding='same')(conv6) + + up7 = Conv2D(self.n_filters * 2, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv6)) + feature3 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(conv3) + concat7 = Concatenate()([feature3, up7]) + conv7 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(concat7) + conv7 = Conv2D(self.n_filters * 2, kernel_size=3, activation='relu', padding='same')(conv7) + + up8 = Conv2D(self.n_filters * 1, 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv7)) + feature2 = Conv2D(self.n_filters * 1, kernel_size=3, activation='relu', padding='same')(conv2) + concat8 = Concatenate()([feature2, up8]) + conv8 = Conv2D(self.n_filters * 1, kernel_size=3, activation='relu', padding='same')(concat8) + conv8 = Conv2D(self.n_filters * 1, kernel_size=3, activation='relu', padding='same')(conv8) + + up9 = Conv2D(int(self.n_filters / 2), 2, activation='relu', padding='same')(UpSampling2D(size=(2, 2))(conv8)) + feature1 = Conv2D(int(self.n_filters / 2), kernel_size=3, activation='relu', padding='same')(conv1) + concat9 = Concatenate()([feature1, up9]) + conv9 = Conv2D(int(self.n_filters / 2), kernel_size=3, activation='relu', padding='same')(concat9) + conv9 = Conv2D(int(self.n_filters / 2), kernel_size=3, activation='relu', padding='same')(conv9) + conv9 = Conv2D(3, kernel_size=3, activation='relu', padding='same')(conv9) + conv10 = Conv2D(1, kernel_size=1, activation='sigmoid')(conv9) + + return Model(outputs=conv10, inputs=unet_input), 'unet_model_yuanqing' diff --git a/requirements.txt b/requirements.txt new file mode 
100644 index 0000000000000000000000000000000000000000..e5c11d50d7849f326b4c3107dde831783c50d4cb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,151 @@ +absl-py==2.3.1 +aiofiles==24.1.0 +annotated-types==0.7.0 +anyio==4.10.0 +asttokens==3.0.0 +astunparse==1.6.3 +attrs==25.3.0 +beautifulsoup4==4.13.4 +blinker==1.9.0 +Brotli==1.1.0 +cachetools==5.5.2 +certifi==2025.8.3 +charset-normalizer==3.4.2 +click==8.2.1 +colorama==0.4.6 +comm==0.2.3 +ConfigArgParse==1.7.1 +contourpy==1.3.2 +cycler==0.12.1 +dash==3.2.0 +decorator==5.2.1 +exceptiongroup==1.3.0 +executing==2.2.0 +fastapi==0.116.1 +fastjsonschema==2.21.1 +ffmpy==0.6.1 +filelock==3.18.0 +Flask==3.1.1 +flatbuffers==25.2.10 +fonttools==4.59.0 +fsspec==2025.7.0 +gast==0.4.0 +google-generativeai +gdown==5.2.0 +google-auth==2.40.3 +google-auth-oauthlib==0.4.6 +google-pasta==0.2.0 +gradio==5.41.1 +gradio_client==1.11.0 +gradio_imageslider==0.0.20 +groovy==0.1.2 +grpcio==1.74.0 +h11==0.16.0 +h5py==3.14.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.34.3 +idna==3.10 +imageio==2.37.0 +importlib_metadata==8.7.0 +ipython==8.37.0 +ipywidgets==8.1.7 +itsdangerous==2.2.0 +jedi==0.19.2 +Jinja2==3.1.6 +jsonschema==4.25.0 +jsonschema-specifications==2025.4.1 +jupyter_core==5.8.1 +jupyterlab_widgets==3.0.15 +keras==2.10.0 +Keras-Preprocessing==1.1.2 +kiwisolver==1.4.8 +lazy_loader==0.4 +libclang==18.1.1 +Markdown==3.8.2 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.5 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +mpmath==1.3.0 +narwhals==2.0.1 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.4.2 +numpy==1.26.4 +oauthlib==3.3.1 +open3d==0.19.0 +opencv-python==4.11.0.86 +opt_einsum==3.4.0 +orjson==3.11.1 +packaging==25.0 +pandas==2.3.1 +parso==0.8.4 +pillow==11.3.0 +platformdirs==4.3.8 +plotly==6.2.0 +prompt_toolkit==3.0.51 +protobuf==3.19.6 +psutil==5.9.8 +pure_eval==0.2.3 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pydantic==2.10.6 +pydantic_core==2.27.2 +pydub==0.25.1 +Pygments==2.19.2 +pyparsing==3.2.3 +PySocks==1.7.1 
+python-dateutil==2.9.0.post0 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.2 +referencing==0.36.2 +requests==2.32.4 +requests-oauthlib==2.0.0 +retrying==1.4.2 +rich==14.1.0 +rpds-py==0.27.0 +rsa==4.9.1 +ruff==0.12.7 +safehttpx==0.1.6 +scikit-image==0.25.2 +scipy==1.15.3 +semantic-version==2.10.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.7 +spaces==0.39.0 +stack-data==0.6.3 +starlette==0.47.2 +sympy==1.14.0 +tensorboard==2.10.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +tensorflow==2.10.1 +tensorflow-estimator==2.10.0 +tensorflow-hub==0.16.1 +tensorflow-io-gcs-filesystem==0.31.0 +termcolor==3.1.0 +tf-keras==2.15.0 +tifffile==2025.5.10 +tomlkit==0.13.3 +torch==2.8.0 +torchvision==0.23.0 +tqdm==4.67.1 +traitlets==5.14.3 +typer==0.16.0 +typing-inspection==0.4.1 +typing_extensions==4.14.1 +tzdata==2025.2 +urllib3==2.5.0 +uvicorn==0.35.0 +wcwidth==0.2.13 +websockets==15.0.1 +Werkzeug==3.1.3 +widgetsnbextension==4.0.14 +wrapt==1.17.2 +zipp==3.23.0 +transformers \ No newline at end of file diff --git a/temp_files/Final_workig_cpu.txt b/temp_files/Final_workig_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..186da1027d63eb08523dbd57c115344994a0bb7a --- /dev/null +++ b/temp_files/Final_workig_cpu.txt @@ -0,0 +1,1000 @@ +import glob +import gradio as gr +import matplotlib +import numpy as np +from PIL import Image +import torch +import tempfile +from gradio_imageslider import ImageSlider +import plotly.graph_objects as go +import plotly.express as px +import open3d as o3d +from depth_anything_v2.dpt import DepthAnythingV2 +import os +import tensorflow as tf +from tensorflow.keras.models import load_model +from tensorflow.keras.preprocessing import image as keras_image +import base64 +from io import BytesIO +import gdown +import spaces +import cv2 + +# Import actual segmentation model components +from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling +from 
utils.learning.metrics import dice_coef, precision, recall +from utils.io.data import normalize + +# Define path and file ID +checkpoint_dir = "checkpoints" +os.makedirs(checkpoint_dir, exist_ok=True) + +model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth") +gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5" + +# Download if not already present +if not os.path.exists(model_file): + print("Downloading model from Google Drive...") + gdown.download(gdrive_url, model_file, quiet=False) + +# --- TensorFlow: Check GPU Availability --- +gpus = tf.config.list_physical_devices('GPU') +if gpus: + print("TensorFlow is using GPU") +else: + print("TensorFlow is using CPU") + +# --- Load Wound Classification Model and Class Labels --- +wound_model = load_model("keras_model.h5") +with open("labels.txt", "r") as f: + class_labels = [line.strip().split(maxsplit=1)[1] for line in f] + +# --- Load Actual Wound Segmentation Model --- +class WoundSegmentationModel: + def __init__(self): + self.input_dim_x = 224 + self.input_dim_y = 224 + self.model = None + self.load_model() + + def load_model(self): + """Load the trained wound segmentation model""" + try: + # Try to load the most recent model + weight_file_name = '2025-08-07_16-25-27.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Segmentation model loaded successfully from {model_path}") + except Exception as e: + print(f"Error loading segmentation model: {e}") + # Fallback to the older model + try: + weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': 
dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Segmentation model loaded successfully from {model_path}") + except Exception as e2: + print(f"Error loading fallback segmentation model: {e2}") + self.model = None + + def preprocess_image(self, image): + """Preprocess the uploaded image for model input""" + if image is None: + return None + + # Convert to RGB if needed + if len(image.shape) == 3 and image.shape[2] == 3: + # Convert BGR to RGB if needed + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Resize to model input size + image = cv2.resize(image, (self.input_dim_x, self.input_dim_y)) + + # Normalize the image + image = image.astype(np.float32) / 255.0 + + # Add batch dimension + image = np.expand_dims(image, axis=0) + + return image + + def postprocess_prediction(self, prediction): + """Postprocess the model prediction""" + # Remove batch dimension + prediction = prediction[0] + + # Apply threshold to get binary mask + threshold = 0.5 + binary_mask = (prediction > threshold).astype(np.uint8) * 255 + + return binary_mask + + def segment_wound(self, input_image): + """Main function to segment wound from uploaded image""" + if self.model is None: + return None, "Error: Segmentation model not loaded. Please check the model files." + + if input_image is None: + return None, "Please upload an image." + + try: + # Preprocess the image + processed_image = self.preprocess_image(input_image) + + if processed_image is None: + return None, "Error processing image." + + # Make prediction + prediction = self.model.predict(processed_image, verbose=0) + + # Postprocess the prediction + segmented_mask = self.postprocess_prediction(prediction) + + return segmented_mask, "Segmentation completed successfully!" 
+ + except Exception as e: + return None, f"Error during segmentation: {str(e)}" + +# Initialize the segmentation model +segmentation_model = WoundSegmentationModel() + +# --- PyTorch: Set Device and Load Depth Model --- +map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") +print(f"Using PyTorch device: {map_device}") + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} +} +encoder = 'vitl' +depth_model = DepthAnythingV2(**model_configs[encoder]) +state_dict = torch.load( + f'checkpoints/depth_anything_v2_{encoder}.pth', + map_location=map_device +) +depth_model.load_state_dict(state_dict) +depth_model = depth_model.to(map_device).eval() + + +# --- Custom CSS for unified dark theme --- +css = """ +.gradio-container { + font-family: 'Segoe UI', sans-serif; + background-color: #121212; + color: #ffffff; + padding: 20px; +} +.gr-button { + background-color: #2c3e50; + color: white; + border-radius: 10px; +} +.gr-button:hover { + background-color: #34495e; +} +.gr-html, .gr-html div { + white-space: normal !important; + overflow: visible !important; + text-overflow: unset !important; + word-break: break-word !important; +} +#img-display-container { + max-height: 100vh; +} +#img-display-input { + max-height: 80vh; +} +#img-display-output { + max-height: 80vh; +} +#download { + height: 62px; +} +h1 { + text-align: center; + font-size: 3rem; + font-weight: bold; + margin: 2rem 0; + color: #ffffff; +} +h2 { + color: #ffffff; + text-align: center; + margin: 1rem 0; +} +.gr-tabs { + background-color: #1e1e1e; + border-radius: 10px; + padding: 10px; +} +.gr-tab-nav { + background-color: #2c3e50; + border-radius: 
8px; +} +.gr-tab-nav button { + color: #ffffff !important; +} +.gr-tab-nav button.selected { + background-color: #34495e !important; +} +""" + +# --- Wound Classification Functions --- +def preprocess_input(img): + img = img.resize((224, 224)) + arr = keras_image.img_to_array(img) + arr = arr / 255.0 + return np.expand_dims(arr, axis=0) + +def get_reasoning_from_gemini(img, prediction): + try: + # For now, return a simple explanation without Gemini API to avoid typing issues + # In production, you would implement the proper Gemini API call here + explanations = { + "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.", + "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.", + "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.", + "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.", + "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues." + } + + return explanations.get(prediction, f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.") + + except Exception as e: + return f"(Reasoning unavailable: {str(e)})" + +@spaces.GPU +def classify_wound_image(img): + if img is None: + return "
    No image provided
    ", "" + + img_array = preprocess_input(img) + predictions = wound_model.predict(img_array, verbose=0)[0] + pred_idx = int(np.argmax(predictions)) + pred_class = class_labels[pred_idx] + + # Get reasoning from Gemini + reasoning_text = get_reasoning_from_gemini(img, pred_class) + + # Prediction Card + predicted_card = f""" +
    +
    + Predicted Wound Type +
    +
    + {pred_class} +
    +
    + """ + + # Reasoning Card + reasoning_card = f""" +
    +
    + Reasoning +
    +
    + {reasoning_text} +
    +
    + """ + + return predicted_card, reasoning_card + +# --- Wound Severity Estimation Functions --- +@spaces.GPU +def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5): + """Compute area statistics for different depth regions""" + pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2 + + # Extract only wound region + wound_mask = (mask > 127) + wound_depths = depth_map[wound_mask] + total_area = np.sum(wound_mask) * pixel_area_cm2 + + # Categorize depth regions + shallow = wound_depths < 3 + moderate = (wound_depths >= 3) & (wound_depths < 6) + deep = wound_depths >= 6 + + shallow_area = np.sum(shallow) * pixel_area_cm2 + moderate_area = np.sum(moderate) * pixel_area_cm2 + deep_area = np.sum(deep) * pixel_area_cm2 + + deep_ratio = deep_area / total_area if total_area > 0 else 0 + + return { + 'total_area_cm2': total_area, + 'shallow_area_cm2': shallow_area, + 'moderate_area_cm2': moderate_area, + 'deep_area_cm2': deep_area, + 'deep_ratio': deep_ratio, + 'max_depth': np.max(wound_depths) if len(wound_depths) > 0 else 0 + } + +def classify_wound_severity_by_area(depth_stats): + """Classify wound severity based on area and depth distribution""" + total = depth_stats['total_area_cm2'] + deep = depth_stats['deep_area_cm2'] + moderate = depth_stats['moderate_area_cm2'] + + if total == 0: + return "Unknown" + + # Severity classification rules + if deep > 2 or (deep / total) > 0.3: + return "Severe" + elif moderate > 1.5 or (moderate / total) > 0.4: + return "Moderate" + else: + return "Mild" + +def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5): + """Analyze wound severity from depth map and wound mask""" + if image is None or depth_map is None or wound_mask is None: + return "❌ Please upload image, depth map, and wound mask." 
+ + # Convert wound mask to grayscale if needed + if len(wound_mask.shape) == 3: + wound_mask = np.mean(wound_mask, axis=2) + + # Ensure depth map and mask have same dimensions + if depth_map.shape[:2] != wound_mask.shape[:2]: + # Resize mask to match depth map + from PIL import Image + mask_pil = Image.fromarray(wound_mask.astype(np.uint8)) + mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0])) + wound_mask = np.array(mask_pil) + + # Compute statistics + stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm) + severity = classify_wound_severity_by_area(stats) + + # Create severity report with color coding + severity_color = { + "Mild": "#4CAF50", # Green + "Moderate": "#FF9800", # Orange + "Severe": "#F44336" # Red + }.get(severity, "#9E9E9E") # Gray for unknown + + report = f""" +
    +
    + 🩹 Wound Severity Analysis +
    + +
    +
    +
    + 📏 Area Measurements +
    +
    +
    🟢 Total Area: {stats['total_area_cm2']:.2f} cm²
    +
    🟩 Shallow (0-3mm): {stats['shallow_area_cm2']:.2f} cm²
    +
    🟨 Moderate (3-6mm): {stats['moderate_area_cm2']:.2f} cm²
    +
    🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²
    +
    +
    + +
    +
    + 📊 Depth Analysis +
    +
    +
    🔥 Deep Coverage: {stats['deep_ratio']*100:.1f}%
    +
    📏 Max Depth: {stats['max_depth']:.1f} mm
    +
    Pixel Spacing: {pixel_spacing_mm} mm
    +
    +
    +
    + +
    +
    + 🎯 Predicted Severity: {severity} +
    +
    + {get_severity_description(severity)} +
    +
    +
    + """ + + return report + +def get_severity_description(severity): + """Get description for severity level""" + descriptions = { + "Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.", + "Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.", + "Severe": "Deep tissue damage requiring immediate medical attention and specialized care.", + "Unknown": "Unable to determine severity due to insufficient data." + } + return descriptions.get(severity, "Severity assessment unavailable.") + +def create_sample_wound_mask(image_shape, center=None, radius=50): + """Create a sample circular wound mask for testing""" + if center is None: + center = (image_shape[1] // 2, image_shape[0] // 2) + + mask = np.zeros(image_shape[:2], dtype=np.uint8) + y, x = np.ogrid[:image_shape[0], :image_shape[1]] + + # Create circular mask + dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) + mask[dist_from_center <= radius] = 255 + + return mask + +def create_realistic_wound_mask(image_shape, method='elliptical'): + """Create a more realistic wound mask with irregular shapes""" + h, w = image_shape[:2] + mask = np.zeros((h, w), dtype=np.uint8) + + if method == 'elliptical': + # Create elliptical wound mask + center = (w // 2, h // 2) + radius_x = min(w, h) // 3 + radius_y = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + # Add some irregularity to make it more realistic + ellipse = ((x - center[0])**2 / (radius_x**2) + + (y - center[1])**2 / (radius_y**2)) <= 1 + + # Add some noise and irregularity + noise = np.random.random((h, w)) > 0.8 + mask = (ellipse | noise).astype(np.uint8) * 255 + + elif method == 'irregular': + # Create irregular wound mask + center = (w // 2, h // 2) + radius = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius + + # Add irregular extensions + extensions = np.zeros_like(base_circle) + for i in range(3): + angle 
= i * 2 * np.pi / 3 + ext_x = int(center[0] + radius * 0.8 * np.cos(angle)) + ext_y = int(center[1] + radius * 0.8 * np.sin(angle)) + ext_radius = radius // 3 + + ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius + extensions = extensions | ext_circle + + mask = (base_circle | extensions).astype(np.uint8) * 255 + + # Apply morphological operations to smooth the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + + return mask + +# --- Depth Estimation Functions --- +@spaces.GPU +def predict_depth(image): + return depth_model.infer_image(image) + +def calculate_max_points(image): + """Calculate maximum points based on image dimensions (3x pixel count)""" + if image is None: + return 10000 # Default value + h, w = image.shape[:2] + max_points = h * w * 3 + # Ensure minimum and reasonable maximum values + return max(1000, min(max_points, 300000)) + +def update_slider_on_image_upload(image): + """Update the points slider when an image is uploaded""" + max_points = calculate_max_points(image) + default_value = min(10000, max_points // 10) # 10% of max points as default + return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, + label=f"Number of 3D points (max: {max_points:,})") + +@spaces.GPU +def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000): + """Create a point cloud from depth map using camera intrinsics with high detail""" + h, w = depth_map.shape + + # Use smaller step for higher detail (reduced downsampling) + step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + x_cam = (x_coords - w / 2) / focal_length_x + y_cam = (y_coords - h / 2) / focal_length_y + + # Get depth values + depth_values = 
depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # Flatten arrays + points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors = image_colors.reshape(-1, 3) / 255.0 + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + + return pcd + +@spaces.GPU +def reconstruct_surface_mesh_from_point_cloud(pcd): + """Convert point cloud to a mesh using Poisson reconstruction with very high detail.""" + # Estimate and orient normals with high precision + pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)) + pcd.orient_normals_consistent_tangent_plane(k=50) + + # Create surface mesh with maximum detail (depth=12 for very high resolution) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12) + + # Return mesh without filtering low-density vertices + return mesh + +@spaces.GPU +def create_enhanced_3d_visualization(image, depth_map, max_points=10000): + """Create an enhanced 3D visualization using proper camera projection""" + h, w = depth_map.shape + + # Downsample to avoid too many points for performance + step = max(1, int(np.sqrt(h * w / max_points))) + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + focal_length = 470.4 # Default focal length + x_cam = (x_coords - w / 2) / focal_length + y_cam = (y_coords - h / 2) / focal_length + + # Get depth values + depth_values = depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # 
Flatten arrays + x_flat = x_3d.flatten() + y_flat = y_3d.flatten() + z_flat = z_3d.flatten() + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors_flat = image_colors.reshape(-1, 3) + + # Create 3D scatter plot with proper camera projection + fig = go.Figure(data=[go.Scatter3d( + x=x_flat, + y=y_flat, + z=z_flat, + mode='markers', + marker=dict( + size=1.5, + color=colors_flat, + opacity=0.9 + ), + hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
    ' + + 'Depth: %{z:.2f}
    ' + + '' + )]) + + fig.update_layout( + title="3D Point Cloud Visualization (Camera Projection)", + scene=dict( + xaxis_title="X (meters)", + yaxis_title="Y (meters)", + zaxis_title="Z (meters)", + camera=dict( + eye=dict(x=2.0, y=2.0, z=2.0), + center=dict(x=0, y=0, z=0), + up=dict(x=0, y=0, z=1) + ), + aspectmode='data' + ), + width=700, + height=600 + ) + + return fig + +def on_depth_submit(image, num_points, focal_x, focal_y): + original_image = image.copy() + + h, w = image.shape[:2] + + # Predict depth using the model + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + + # Save raw 16-bit depth + raw_depth = Image.fromarray(depth.astype('uint16')) + tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + raw_depth.save(tmp_raw_depth.name) + + # Normalize and convert to grayscale for display + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + norm_depth = norm_depth.astype(np.uint8) + colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8) + + gray_depth = Image.fromarray(norm_depth) + tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + gray_depth.save(tmp_gray_depth.name) + + # Create point cloud + pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points) + + # Reconstruct mesh from point cloud + mesh = reconstruct_surface_mesh_from_point_cloud(pcd) + + # Save mesh with faces as .ply + tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) + o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh) + + # Create enhanced 3D scatter plot visualization + depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points) + + return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] + +# --- Actual Wound Segmentation Functions --- +def create_automatic_wound_mask(image, method='deep_learning'): + 
""" + Automatically generate wound mask from image using the actual deep learning model + + Args: + image: Input image (numpy array) + method: Segmentation method (currently only 'deep_learning' supported) + + Returns: + mask: Binary wound mask + """ + if image is None: + return None + + # Use the actual deep learning model for segmentation + if method == 'deep_learning': + mask, _ = segmentation_model.segment_wound(image) + return mask + else: + # Fallback to deep learning if method not recognized + mask, _ = segmentation_model.segment_wound(image) + return mask + +def post_process_wound_mask(mask, min_area=100): + """Post-process the wound mask to remove noise and small objects""" + if mask is None: + return None + + # Convert to binary if needed + if mask.dtype != np.uint8: + mask = mask.astype(np.uint8) + + # Apply morphological operations to clean up + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + # Remove small objects using OpenCV + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + mask_clean = np.zeros_like(mask) + + for contour in contours: + area = cv2.contourArea(contour) + if area >= min_area: + cv2.fillPoly(mask_clean, [contour], 255) + + # Fill holes + mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel) + + return mask_clean + +def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'): + """Analyze wound severity with automatic mask generation using actual segmentation model""" + if image is None or depth_map is None: + return "❌ Please provide both image and depth map." + + # Generate automatic wound mask using the actual model + auto_mask = create_automatic_wound_mask(image, method=segmentation_method) + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask. 
Please check if the segmentation model is loaded." + + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm) + +# --- Main Gradio Interface --- +with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo: + gr.HTML("

    Wound Analysis & Depth Estimation System

    ") + gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities") + + # Shared image state + shared_image = gr.State() + + with gr.Tabs(): + # Tab 1: Wound Classification + with gr.Tab("1. Wound Classification"): + gr.Markdown("### Step 1: Upload and classify your wound image") + gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.") + + with gr.Row(): + with gr.Column(scale=1): + wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350) + + with gr.Column(scale=1): + wound_prediction_box = gr.HTML() + wound_reasoning_box = gr.HTML() + + # Button to pass image to depth estimation + with gr.Row(): + pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg") + pass_status = gr.HTML("") + + wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input, + outputs=[wound_prediction_box, wound_reasoning_box]) + + # Store image when uploaded for classification + wound_image_input.change( + fn=lambda img: img, + inputs=[wound_image_input], + outputs=[shared_image] + ) + + # Tab 2: Depth Estimation + with gr.Tab("2. 
Depth Estimation & 3D Visualization"): + gr.Markdown("### Step 2: Generate depth maps and 3D visualizations") + gr.Markdown("This module creates depth maps and 3D point clouds from your images.") + + with gr.Row(): + depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') + depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output') + + with gr.Row(): + depth_submit = gr.Button(value="Compute Depth", variant="primary") + load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary") + points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000, + label="Number of 3D points (upload image to update max)") + + with gr.Row(): + focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length X (pixels)") + focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length Y (pixels)") + + with gr.Row(): + gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download") + raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download") + point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download") + + # 3D Visualization + gr.Markdown("### 3D Point Cloud Visualization") + gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.") + depth_3d_plot = gr.Plot(label="3D Point Cloud") + + # Store depth map for severity analysis + depth_map_state = gr.State() + + # Tab 3: Wound Severity Analysis + with gr.Tab("3. 
🩹 Wound Severity Analysis"): + gr.Markdown("### Step 3: Analyze wound severity using depth maps") + gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.") + + with gr.Row(): + severity_input_image = gr.Image(label="Original Image", type='numpy') + severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy') + + with gr.Row(): + wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy') + severity_output = gr.HTML(label="Severity Analysis Report") + + gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.") + + with gr.Row(): + auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg") + manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg") + pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, + label="Pixel Spacing (mm/pixel)") + + gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. 
Default is 0.5 mm/pixel.") + + with gr.Row(): + # Load depth map from previous tab + load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary") + + gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.") + + # Update slider when image is uploaded + depth_input_image.change( + fn=update_slider_on_image_upload, + inputs=[depth_input_image], + outputs=[points_slider] + ) + + # Modified depth submit function to store depth map + def on_depth_submit_with_state(image, num_points, focal_x, focal_y): + results = on_depth_submit(image, num_points, focal_x, focal_y) + # Extract depth map from results for severity analysis + depth_map = None + if image is not None: + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + # Normalize depth for severity analysis + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth_map = norm_depth.astype(np.uint8) + return results + [depth_map] + + depth_submit.click(on_depth_submit_with_state, + inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], + outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state]) + + # Load depth map to severity tab and auto-generate mask + def load_depth_to_severity(depth_map, original_image): + if depth_map is None: + return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first." + + # Auto-generate wound mask using segmentation model + if original_image is not None: + auto_mask, _ = segmentation_model.segment_wound(original_image) + if auto_mask is not None: + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + if processed_mask is not None and np.sum(processed_mask > 0) > 0: + return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!" 
+ else: + return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image." + else: + return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image." + else: + return depth_map, original_image, None, "✅ Depth map loaded successfully!" + + load_depth_btn.click( + fn=load_depth_to_severity, + inputs=[depth_map_state, depth_input_image], + outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()] + ) + + # Automatic severity analysis function + def run_auto_severity_analysis(image, depth_map, pixel_spacing): + if depth_map is None: + return "❌ Please load depth map from Tab 2 first." + + # Generate automatic wound mask using the actual model + auto_mask = create_automatic_wound_mask(image, method='deep_learning') + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded." + + # Post-process the mask with fixed minimum area + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing) + + # Manual severity analysis function + def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing): + if depth_map is None: + return "❌ Please load depth map from Tab 2 first." + if wound_mask is None: + return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)." 
+ + return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing) + + # Connect event handlers + auto_severity_button.click( + fn=run_auto_severity_analysis, + inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider], + outputs=[severity_output] + ) + + manual_severity_button.click( + fn=run_manual_severity_analysis, + inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider], + outputs=[severity_output] + ) + + + + # Auto-generate mask when image is uploaded + def auto_generate_mask_on_image_upload(image): + if image is None: + return None, "❌ No image uploaded." + + # Generate automatic wound mask using segmentation model + auto_mask, _ = segmentation_model.segment_wound(image) + if auto_mask is not None: + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + if processed_mask is not None and np.sum(processed_mask > 0) > 0: + return processed_mask, "✅ Wound mask auto-generated using deep learning model!" + else: + return None, "✅ Image uploaded but no wound detected. Try uploading a different image." + else: + return None, "✅ Image uploaded but segmentation failed. Try uploading a different image." 
+ + # Load shared image from classification tab + def load_shared_image(shared_img): + if shared_img is None: + return gr.Image(), "❌ No image available from classification tab" + + # Convert PIL image to numpy array for depth estimation + if hasattr(shared_img, 'convert'): + # It's a PIL image, convert to numpy + img_array = np.array(shared_img) + return img_array, "✅ Image loaded from classification tab" + else: + # Already numpy array + return shared_img, "✅ Image loaded from classification tab" + + # Auto-generate mask when image is uploaded to severity tab + severity_input_image.change( + fn=auto_generate_mask_on_image_upload, + inputs=[severity_input_image], + outputs=[wound_mask_input, gr.HTML()] + ) + + load_shared_btn.click( + fn=load_shared_image, + inputs=[shared_image], + outputs=[depth_input_image, gr.HTML()] + ) + + # Pass image to depth tab function + def pass_image_to_depth(img): + if img is None: + return "❌ No image uploaded in classification tab" + return "✅ Image ready for depth analysis! 
Switch to tab 2 and click 'Load Image from Classification'" + + pass_to_depth_btn.click( + fn=pass_image_to_depth, + inputs=[shared_image], + outputs=[pass_status] + ) + +if __name__ == '__main__': + demo.queue().launch( + server_name="0.0.0.0", + server_port=7860, + share=True + ) \ No newline at end of file diff --git a/temp_files/README.md b/temp_files/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8b38f5541ad8472e1343f988edda78b9491014af --- /dev/null +++ b/temp_files/README.md @@ -0,0 +1,12 @@ +--- +title: Wound Analysis V22 +emoji: 📉 +colorFrom: purple +colorTo: green +sdk: gradio +sdk_version: 5.41.1 +app_file: app.py +pinned: false +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/temp_files/fw2.txt b/temp_files/fw2.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9835bcb758f150c00ec9dcf322a5e92f0923c81 --- /dev/null +++ b/temp_files/fw2.txt @@ -0,0 +1,1175 @@ +import glob +import gradio as gr +import matplotlib +import numpy as np +from PIL import Image +import torch +import tempfile +from gradio_imageslider import ImageSlider +import plotly.graph_objects as go +import plotly.express as px +import open3d as o3d +from depth_anything_v2.dpt import DepthAnythingV2 +import os +import tensorflow as tf +from tensorflow.keras.models import load_model +from tensorflow.keras.preprocessing import image as keras_image +import base64 +from io import BytesIO +import gdown +import spaces +import cv2 + +# Import actual segmentation model components +from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling +from utils.learning.metrics import dice_coef, precision, recall +from utils.io.data import normalize + +# Define path and file ID +checkpoint_dir = "checkpoints" +os.makedirs(checkpoint_dir, exist_ok=True) + +model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth") +gdrive_url = 
"https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5" + +# Download if not already present +if not os.path.exists(model_file): + print("Downloading model from Google Drive...") + gdown.download(gdrive_url, model_file, quiet=False) + +# --- TensorFlow: Check GPU Availability --- +gpus = tf.config.list_physical_devices('GPU') +if gpus: + print("TensorFlow is using GPU") +else: + print("TensorFlow is using CPU") + +# --- Load Wound Classification Model and Class Labels --- +wound_model = load_model("keras_model.h5") +with open("labels.txt", "r") as f: + class_labels = [line.strip().split(maxsplit=1)[1] for line in f] + +# --- Load Actual Wound Segmentation Model --- +class WoundSegmentationModel: + def __init__(self): + self.input_dim_x = 224 + self.input_dim_y = 224 + self.model = None + self.load_model() + + def load_model(self): + """Load the trained wound segmentation model""" + try: + # Try to load the most recent model + weight_file_name = '2025-08-07_16-25-27.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Segmentation model loaded successfully from {model_path}") + except Exception as e: + print(f"Error loading segmentation model: {e}") + # Fallback to the older model + try: + weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Segmentation model loaded successfully from {model_path}") + except Exception as e2: + print(f"Error loading fallback segmentation model: {e2}") + self.model = None + + 
def preprocess_image(self, image): + """Preprocess the uploaded image for model input""" + if image is None: + return None + + # Convert to RGB if needed + if len(image.shape) == 3 and image.shape[2] == 3: + # Convert BGR to RGB if needed + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Resize to model input size + image = cv2.resize(image, (self.input_dim_x, self.input_dim_y)) + + # Normalize the image + image = image.astype(np.float32) / 255.0 + + # Add batch dimension + image = np.expand_dims(image, axis=0) + + return image + + def postprocess_prediction(self, prediction): + """Postprocess the model prediction""" + # Remove batch dimension + prediction = prediction[0] + + # Apply threshold to get binary mask + threshold = 0.5 + binary_mask = (prediction > threshold).astype(np.uint8) * 255 + + return binary_mask + + def segment_wound(self, input_image): + """Main function to segment wound from uploaded image""" + if self.model is None: + return None, "Error: Segmentation model not loaded. Please check the model files." + + if input_image is None: + return None, "Please upload an image." + + try: + # Preprocess the image + processed_image = self.preprocess_image(input_image) + + if processed_image is None: + return None, "Error processing image." + + # Make prediction + prediction = self.model.predict(processed_image, verbose=0) + + # Postprocess the prediction + segmented_mask = self.postprocess_prediction(prediction) + + return segmented_mask, "Segmentation completed successfully!" 
+ + except Exception as e: + return None, f"Error during segmentation: {str(e)}" + +# Initialize the segmentation model +segmentation_model = WoundSegmentationModel() + +# --- PyTorch: Set Device and Load Depth Model --- +map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") +print(f"Using PyTorch device: {map_device}") + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} +} +encoder = 'vitl' +depth_model = DepthAnythingV2(**model_configs[encoder]) +state_dict = torch.load( + f'checkpoints/depth_anything_v2_{encoder}.pth', + map_location=map_device +) +depth_model.load_state_dict(state_dict) +depth_model = depth_model.to(map_device).eval() + + +# --- Custom CSS for unified dark theme --- +css = """ +.gradio-container { + font-family: 'Segoe UI', sans-serif; + background-color: #121212; + color: #ffffff; + padding: 20px; +} +.gr-button { + background-color: #2c3e50; + color: white; + border-radius: 10px; +} +.gr-button:hover { + background-color: #34495e; +} +.gr-html, .gr-html div { + white-space: normal !important; + overflow: visible !important; + text-overflow: unset !important; + word-break: break-word !important; +} +#img-display-container { + max-height: 100vh; +} +#img-display-input { + max-height: 80vh; +} +#img-display-output { + max-height: 80vh; +} +#download { + height: 62px; +} +h1 { + text-align: center; + font-size: 3rem; + font-weight: bold; + margin: 2rem 0; + color: #ffffff; +} +h2 { + color: #ffffff; + text-align: center; + margin: 1rem 0; +} +.gr-tabs { + background-color: #1e1e1e; + border-radius: 10px; + padding: 10px; +} +.gr-tab-nav { + background-color: #2c3e50; + border-radius: 
8px; +} +.gr-tab-nav button { + color: #ffffff !important; +} +.gr-tab-nav button.selected { + background-color: #34495e !important; +} +""" + +# --- Wound Classification Functions --- +def preprocess_input(img): + img = img.resize((224, 224)) + arr = keras_image.img_to_array(img) + arr = arr / 255.0 + return np.expand_dims(arr, axis=0) + +def get_reasoning_from_gemini(img, prediction): + try: + # For now, return a simple explanation without Gemini API to avoid typing issues + # In production, you would implement the proper Gemini API call here + explanations = { + "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.", + "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.", + "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.", + "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.", + "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues." + } + + return explanations.get(prediction, f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.") + + except Exception as e: + return f"(Reasoning unavailable: {str(e)})" + +@spaces.GPU +def classify_wound_image(img): + if img is None: + return "
    No image provided
    ", "" + + img_array = preprocess_input(img) + predictions = wound_model.predict(img_array, verbose=0)[0] + pred_idx = int(np.argmax(predictions)) + pred_class = class_labels[pred_idx] + + # Get reasoning from Gemini + reasoning_text = get_reasoning_from_gemini(img, pred_class) + + # Prediction Card + predicted_card = f""" +
    +
    + Predicted Wound Type +
    +
    + {pred_class} +
    +
    + """ + + # Reasoning Card + reasoning_card = f""" +
    +
    + Reasoning +
    +
    + {reasoning_text} +
    +
    + """ + + return predicted_card, reasoning_card + +# --- Enhanced Wound Severity Estimation Functions --- +@spaces.GPU +def compute_enhanced_depth_statistics(depth_map, mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0): + """ + Enhanced depth analysis with proper calibration and medical standards + Based on wound depth classification standards: + - Superficial: 0-2mm (epidermis only) + - Partial thickness: 2-4mm (epidermis + partial dermis) + - Full thickness: 4-6mm (epidermis + full dermis) + - Deep: >6mm (involving subcutaneous tissue) + """ + # Convert pixel spacing to mm + pixel_spacing_mm = float(pixel_spacing_mm) + + # Calculate pixel area in cm² + pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2 + + # Extract wound region (binary mask) + wound_mask = (mask > 127).astype(np.uint8) + + # Apply morphological operations to clean the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + wound_mask = cv2.morphologyEx(wound_mask, cv2.MORPH_CLOSE, kernel) + + # Get depth values only for wound region + wound_depths = depth_map[wound_mask > 0] + + if len(wound_depths) == 0: + return { + 'total_area_cm2': 0, + 'superficial_area_cm2': 0, + 'partial_thickness_area_cm2': 0, + 'full_thickness_area_cm2': 0, + 'deep_area_cm2': 0, + 'mean_depth_mm': 0, + 'max_depth_mm': 0, + 'depth_std_mm': 0, + 'deep_ratio': 0, + 'wound_volume_cm3': 0, + 'depth_percentiles': {'25': 0, '50': 0, '75': 0} + } + + # Calibrate depth map for more accurate measurements + calibrated_depth_map = calibrate_depth_map(depth_map, reference_depth_mm=depth_calibration_mm) + + # Get calibrated depth values for wound region + wound_depths_mm = calibrated_depth_map[wound_mask > 0] + + # Medical depth classification + superficial_mask = wound_depths_mm < 2.0 + partial_thickness_mask = (wound_depths_mm >= 2.0) & (wound_depths_mm < 4.0) + full_thickness_mask = (wound_depths_mm >= 4.0) & (wound_depths_mm < 6.0) + deep_mask = wound_depths_mm >= 6.0 + + # Calculate areas + total_pixels = 
np.sum(wound_mask > 0) + total_area_cm2 = total_pixels * pixel_area_cm2 + + superficial_area_cm2 = np.sum(superficial_mask) * pixel_area_cm2 + partial_thickness_area_cm2 = np.sum(partial_thickness_mask) * pixel_area_cm2 + full_thickness_area_cm2 = np.sum(full_thickness_mask) * pixel_area_cm2 + deep_area_cm2 = np.sum(deep_mask) * pixel_area_cm2 + + # Calculate depth statistics + mean_depth_mm = np.mean(wound_depths_mm) + max_depth_mm = np.max(wound_depths_mm) + depth_std_mm = np.std(wound_depths_mm) + + # Calculate depth percentiles + depth_percentiles = { + '25': np.percentile(wound_depths_mm, 25), + '50': np.percentile(wound_depths_mm, 50), + '75': np.percentile(wound_depths_mm, 75) + } + + # Calculate wound volume (approximate) + # Volume = area * average depth + wound_volume_cm3 = total_area_cm2 * (mean_depth_mm / 10.0) + + # Deep tissue ratio + deep_ratio = deep_area_cm2 / total_area_cm2 if total_area_cm2 > 0 else 0 + + # Calculate analysis quality metrics + wound_pixel_count = len(wound_depths_mm) + analysis_quality = "High" if wound_pixel_count > 1000 else "Medium" if wound_pixel_count > 500 else "Low" + + # Calculate depth consistency (lower std dev = more consistent) + depth_consistency = "High" if depth_std_mm < 2.0 else "Medium" if depth_std_mm < 4.0 else "Low" + + return { + 'total_area_cm2': total_area_cm2, + 'superficial_area_cm2': superficial_area_cm2, + 'partial_thickness_area_cm2': partial_thickness_area_cm2, + 'full_thickness_area_cm2': full_thickness_area_cm2, + 'deep_area_cm2': deep_area_cm2, + 'mean_depth_mm': mean_depth_mm, + 'max_depth_mm': max_depth_mm, + 'depth_std_mm': depth_std_mm, + 'deep_ratio': deep_ratio, + 'wound_volume_cm3': wound_volume_cm3, + 'depth_percentiles': depth_percentiles, + 'analysis_quality': analysis_quality, + 'depth_consistency': depth_consistency, + 'wound_pixel_count': wound_pixel_count + } + +def classify_wound_severity_by_enhanced_metrics(depth_stats): + """ + Enhanced wound severity classification based on 
medical standards + Uses multiple criteria: depth, area, volume, and tissue involvement + """ + if depth_stats['total_area_cm2'] == 0: + return "Unknown" + + # Extract key metrics + total_area = depth_stats['total_area_cm2'] + deep_area = depth_stats['deep_area_cm2'] + full_thickness_area = depth_stats['full_thickness_area_cm2'] + mean_depth = depth_stats['mean_depth_mm'] + max_depth = depth_stats['max_depth_mm'] + wound_volume = depth_stats['wound_volume_cm3'] + deep_ratio = depth_stats['deep_ratio'] + + # Medical severity classification criteria + severity_score = 0 + + # Criterion 1: Maximum depth + if max_depth >= 10.0: + severity_score += 3 # Very severe + elif max_depth >= 6.0: + severity_score += 2 # Severe + elif max_depth >= 4.0: + severity_score += 1 # Moderate + + # Criterion 2: Mean depth + if mean_depth >= 5.0: + severity_score += 2 + elif mean_depth >= 3.0: + severity_score += 1 + + # Criterion 3: Deep tissue involvement ratio + if deep_ratio >= 0.5: + severity_score += 3 # More than 50% deep tissue + elif deep_ratio >= 0.25: + severity_score += 2 # 25-50% deep tissue + elif deep_ratio >= 0.1: + severity_score += 1 # 10-25% deep tissue + + # Criterion 4: Total wound area + if total_area >= 10.0: + severity_score += 2 # Large wound (>10 cm²) + elif total_area >= 5.0: + severity_score += 1 # Medium wound (5-10 cm²) + + # Criterion 5: Wound volume + if wound_volume >= 5.0: + severity_score += 2 # High volume + elif wound_volume >= 2.0: + severity_score += 1 # Medium volume + + # Determine severity based on total score + if severity_score >= 8: + return "Very Severe" + elif severity_score >= 6: + return "Severe" + elif severity_score >= 4: + return "Moderate" + elif severity_score >= 2: + return "Mild" + else: + return "Superficial" + +def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5, depth_calibration_mm=15.0): + """Enhanced wound severity analysis with medical-grade metrics""" + if image is None or depth_map is None or 
wound_mask is None: + return "❌ Please upload image, depth map, and wound mask." + + # Convert wound mask to grayscale if needed + if len(wound_mask.shape) == 3: + wound_mask = np.mean(wound_mask, axis=2) + + # Ensure depth map and mask have same dimensions + if depth_map.shape[:2] != wound_mask.shape[:2]: + # Resize mask to match depth map + from PIL import Image + mask_pil = Image.fromarray(wound_mask.astype(np.uint8)) + mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0])) + wound_mask = np.array(mask_pil) + + # Compute enhanced statistics + stats = compute_enhanced_depth_statistics(depth_map, wound_mask, pixel_spacing_mm, depth_calibration_mm) + severity = classify_wound_severity_by_enhanced_metrics(stats) + + # Enhanced severity color coding + severity_color = { + "Superficial": "#4CAF50", # Green + "Mild": "#8BC34A", # Light Green + "Moderate": "#FF9800", # Orange + "Severe": "#F44336", # Red + "Very Severe": "#9C27B0" # Purple + }.get(severity, "#9E9E9E") # Gray for unknown + + # Create comprehensive medical report + report = f""" +
    +
    + 🩹 Enhanced Wound Severity Analysis +
    + +
    +
    +
    + 📏 Tissue Involvement Analysis +
    +
    +
    🟢 Superficial (0-2mm): {stats['superficial_area_cm2']:.2f} cm²
    +
    🟡 Partial Thickness (2-4mm): {stats['partial_thickness_area_cm2']:.2f} cm²
    +
    🟠 Full Thickness (4-6mm): {stats['full_thickness_area_cm2']:.2f} cm²
    +
    🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²
    +
    📊 Total Area: {stats['total_area_cm2']:.2f} cm²
    +
    +
    + +
    +
    + 📊 Depth Statistics +
    +
    +
    📏 Mean Depth: {stats['mean_depth_mm']:.1f} mm
    +
    📐 Max Depth: {stats['max_depth_mm']:.1f} mm
    +
    📊 Depth Std Dev: {stats['depth_std_mm']:.1f} mm
    +
    📦 Wound Volume: {stats['wound_volume_cm3']:.2f} cm³
    +
    🔥 Deep Tissue Ratio: {stats['deep_ratio']*100:.1f}%
    +
    +
    +
    + +
    +
    + 📈 Depth Percentiles & Quality Metrics +
    +
    +
    +
    📊 25th Percentile: {stats['depth_percentiles']['25']:.1f} mm
    +
    📊 Median (50th): {stats['depth_percentiles']['50']:.1f} mm
    +
    📊 75th Percentile: {stats['depth_percentiles']['75']:.1f} mm
    +
    +
    +
    🔍 Analysis Quality: {stats['analysis_quality']}
    +
    📏 Depth Consistency: {stats['depth_consistency']}
    +
    📊 Data Points: {stats['wound_pixel_count']:,}
    +
    +
    +
    + +
    +
    + 🎯 Medical Severity Assessment: {severity} +
    +
    + {get_enhanced_severity_description(severity)} +
    +
    +
    + """ + + return report + +def calibrate_depth_map(depth_map, reference_depth_mm=10.0): + """ + Calibrate depth map to real-world measurements using reference depth + This helps convert normalized depth values to actual millimeters + """ + if depth_map is None: + return depth_map + + # Find the maximum depth value in the depth map + max_depth_value = np.max(depth_map) + min_depth_value = np.min(depth_map) + + if max_depth_value == min_depth_value: + return depth_map + + # Apply calibration to convert to millimeters + # Assuming the maximum depth in the map corresponds to reference_depth_mm + calibrated_depth = (depth_map - min_depth_value) / (max_depth_value - min_depth_value) * reference_depth_mm + + return calibrated_depth + +def get_enhanced_severity_description(severity): + """Get comprehensive medical description for severity level""" + descriptions = { + "Superficial": "Epidermis-only damage. Minimal tissue loss, typically heals within 1-2 weeks with basic wound care.", + "Mild": "Superficial to partial thickness wound. Limited tissue involvement, good healing potential with proper care.", + "Moderate": "Partial to full thickness involvement. Requires careful monitoring and may need advanced wound care techniques.", + "Severe": "Full thickness with deep tissue involvement. High risk of complications, requires immediate medical attention.", + "Very Severe": "Extensive deep tissue damage. Critical condition requiring immediate surgical intervention and specialized care.", + "Unknown": "Unable to determine severity due to insufficient data or poor image quality." 
+ } + return descriptions.get(severity, "Severity assessment unavailable.") + +def create_sample_wound_mask(image_shape, center=None, radius=50): + """Create a sample circular wound mask for testing""" + if center is None: + center = (image_shape[1] // 2, image_shape[0] // 2) + + mask = np.zeros(image_shape[:2], dtype=np.uint8) + y, x = np.ogrid[:image_shape[0], :image_shape[1]] + + # Create circular mask + dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) + mask[dist_from_center <= radius] = 255 + + return mask + +def create_realistic_wound_mask(image_shape, method='elliptical'): + """Create a more realistic wound mask with irregular shapes""" + h, w = image_shape[:2] + mask = np.zeros((h, w), dtype=np.uint8) + + if method == 'elliptical': + # Create elliptical wound mask + center = (w // 2, h // 2) + radius_x = min(w, h) // 3 + radius_y = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + # Add some irregularity to make it more realistic + ellipse = ((x - center[0])**2 / (radius_x**2) + + (y - center[1])**2 / (radius_y**2)) <= 1 + + # Add some noise and irregularity + noise = np.random.random((h, w)) > 0.8 + mask = (ellipse | noise).astype(np.uint8) * 255 + + elif method == 'irregular': + # Create irregular wound mask + center = (w // 2, h // 2) + radius = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius + + # Add irregular extensions + extensions = np.zeros_like(base_circle) + for i in range(3): + angle = i * 2 * np.pi / 3 + ext_x = int(center[0] + radius * 0.8 * np.cos(angle)) + ext_y = int(center[1] + radius * 0.8 * np.sin(angle)) + ext_radius = radius // 3 + + ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius + extensions = extensions | ext_circle + + mask = (base_circle | extensions).astype(np.uint8) * 255 + + # Apply morphological operations to smooth the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask = cv2.morphologyEx(mask, 
cv2.MORPH_CLOSE, kernel) + + return mask + +# --- Depth Estimation Functions --- +@spaces.GPU +def predict_depth(image): + return depth_model.infer_image(image) + +def calculate_max_points(image): + """Calculate maximum points based on image dimensions (3x pixel count)""" + if image is None: + return 10000 # Default value + h, w = image.shape[:2] + max_points = h * w * 3 + # Ensure minimum and reasonable maximum values + return max(1000, min(max_points, 300000)) + +def update_slider_on_image_upload(image): + """Update the points slider when an image is uploaded""" + max_points = calculate_max_points(image) + default_value = min(10000, max_points // 10) # 10% of max points as default + return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, + label=f"Number of 3D points (max: {max_points:,})") + +@spaces.GPU +def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000): + """Create a point cloud from depth map using camera intrinsics with high detail""" + h, w = depth_map.shape + + # Use smaller step for higher detail (reduced downsampling) + step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + x_cam = (x_coords - w / 2) / focal_length_x + y_cam = (y_coords - h / 2) / focal_length_y + + # Get depth values + depth_values = depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # Flatten arrays + points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors = image_colors.reshape(-1, 3) / 255.0 + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = 
o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + + return pcd + +@spaces.GPU +def reconstruct_surface_mesh_from_point_cloud(pcd): + """Convert point cloud to a mesh using Poisson reconstruction with very high detail.""" + # Estimate and orient normals with high precision + pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)) + pcd.orient_normals_consistent_tangent_plane(k=50) + + # Create surface mesh with maximum detail (depth=12 for very high resolution) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12) + + # Return mesh without filtering low-density vertices + return mesh + +@spaces.GPU +def create_enhanced_3d_visualization(image, depth_map, max_points=10000): + """Create an enhanced 3D visualization using proper camera projection""" + h, w = depth_map.shape + + # Downsample to avoid too many points for performance + step = max(1, int(np.sqrt(h * w / max_points))) + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + focal_length = 470.4 # Default focal length + x_cam = (x_coords - w / 2) / focal_length + y_cam = (y_coords - h / 2) / focal_length + + # Get depth values + depth_values = depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # Flatten arrays + x_flat = x_3d.flatten() + y_flat = y_3d.flatten() + z_flat = z_3d.flatten() + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors_flat = image_colors.reshape(-1, 3) + + # Create 3D scatter plot with proper camera projection + fig = go.Figure(data=[go.Scatter3d( + x=x_flat, + y=y_flat, + z=z_flat, + mode='markers', + marker=dict( + size=1.5, + color=colors_flat, + opacity=0.9 + ), + hovertemplate='3D Position: (%{x:.3f}, 
%{y:.3f}, %{z:.3f})
    ' + + 'Depth: %{z:.2f}
    ' + + '' + )]) + + fig.update_layout( + title="3D Point Cloud Visualization (Camera Projection)", + scene=dict( + xaxis_title="X (meters)", + yaxis_title="Y (meters)", + zaxis_title="Z (meters)", + camera=dict( + eye=dict(x=2.0, y=2.0, z=2.0), + center=dict(x=0, y=0, z=0), + up=dict(x=0, y=0, z=1) + ), + aspectmode='data' + ), + width=700, + height=600 + ) + + return fig + +def on_depth_submit(image, num_points, focal_x, focal_y): + original_image = image.copy() + + h, w = image.shape[:2] + + # Predict depth using the model + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + + # Save raw 16-bit depth + raw_depth = Image.fromarray(depth.astype('uint16')) + tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + raw_depth.save(tmp_raw_depth.name) + + # Normalize and convert to grayscale for display + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + norm_depth = norm_depth.astype(np.uint8) + colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8) + + gray_depth = Image.fromarray(norm_depth) + tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + gray_depth.save(tmp_gray_depth.name) + + # Create point cloud + pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points) + + # Reconstruct mesh from point cloud + mesh = reconstruct_surface_mesh_from_point_cloud(pcd) + + # Save mesh with faces as .ply + tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) + o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh) + + # Create enhanced 3D scatter plot visualization + depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points) + + return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] + +# --- Actual Wound Segmentation Functions --- +def create_automatic_wound_mask(image, method='deep_learning'): + 
""" + Automatically generate wound mask from image using the actual deep learning model + + Args: + image: Input image (numpy array) + method: Segmentation method (currently only 'deep_learning' supported) + + Returns: + mask: Binary wound mask + """ + if image is None: + return None + + # Use the actual deep learning model for segmentation + if method == 'deep_learning': + mask, _ = segmentation_model.segment_wound(image) + return mask + else: + # Fallback to deep learning if method not recognized + mask, _ = segmentation_model.segment_wound(image) + return mask + +def post_process_wound_mask(mask, min_area=100): + """Post-process the wound mask to remove noise and small objects""" + if mask is None: + return None + + # Convert to binary if needed + if mask.dtype != np.uint8: + mask = mask.astype(np.uint8) + + # Apply morphological operations to clean up + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + # Remove small objects using OpenCV + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + mask_clean = np.zeros_like(mask) + + for contour in contours: + area = cv2.contourArea(contour) + if area >= min_area: + cv2.fillPoly(mask_clean, [contour], 255) + + # Fill holes + mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel) + + return mask_clean + +def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='deep_learning'): + """Analyze wound severity with automatic mask generation using actual segmentation model""" + if image is None or depth_map is None: + return "❌ Please provide both image and depth map." + + # Generate automatic wound mask using the actual model + auto_mask = create_automatic_wound_mask(image, method=segmentation_method) + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask. 
Please check if the segmentation model is loaded." + + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm) + +# --- Main Gradio Interface --- +with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo: + gr.HTML("

    Wound Analysis & Depth Estimation System

    ") + gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities") + + # Shared image state + shared_image = gr.State() + + with gr.Tabs(): + # Tab 1: Wound Classification + with gr.Tab("1. Wound Classification"): + gr.Markdown("### Step 1: Upload and classify your wound image") + gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.") + + with gr.Row(): + with gr.Column(scale=1): + wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350) + + with gr.Column(scale=1): + wound_prediction_box = gr.HTML() + wound_reasoning_box = gr.HTML() + + # Button to pass image to depth estimation + with gr.Row(): + pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg") + pass_status = gr.HTML("") + + wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input, + outputs=[wound_prediction_box, wound_reasoning_box]) + + # Store image when uploaded for classification + wound_image_input.change( + fn=lambda img: img, + inputs=[wound_image_input], + outputs=[shared_image] + ) + + # Tab 2: Depth Estimation + with gr.Tab("2. 
Depth Estimation & 3D Visualization"): + gr.Markdown("### Step 2: Generate depth maps and 3D visualizations") + gr.Markdown("This module creates depth maps and 3D point clouds from your images.") + + with gr.Row(): + depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') + depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output') + + with gr.Row(): + depth_submit = gr.Button(value="Compute Depth", variant="primary") + load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary") + points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000, + label="Number of 3D points (upload image to update max)") + + with gr.Row(): + focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length X (pixels)") + focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length Y (pixels)") + + with gr.Row(): + gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download") + raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download") + point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download") + + # 3D Visualization + gr.Markdown("### 3D Point Cloud Visualization") + gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.") + depth_3d_plot = gr.Plot(label="3D Point Cloud") + + # Store depth map for severity analysis + depth_map_state = gr.State() + + # Tab 3: Wound Severity Analysis + with gr.Tab("3. 
🩹 Wound Severity Analysis"): + gr.Markdown("### Step 3: Analyze wound severity using depth maps") + gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.") + + with gr.Row(): + severity_input_image = gr.Image(label="Original Image", type='numpy') + severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy') + + with gr.Row(): + wound_mask_input = gr.Image(label="Auto-Generated Wound Mask", type='numpy') + severity_output = gr.HTML(label="Severity Analysis Report") + + gr.Markdown("**Note:** The deep learning segmentation model will automatically generate a wound mask when you upload an image or load a depth map.") + + with gr.Row(): + auto_severity_button = gr.Button("🤖 Analyze Severity with Auto-Generated Mask", variant="primary", size="lg") + manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg") + pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, + label="Pixel Spacing (mm/pixel)") + depth_calibration_slider = gr.Slider(minimum=5.0, maximum=30.0, value=15.0, step=1.0, + label="Depth Calibration (mm)", + info="Adjust based on expected maximum wound depth") + + gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.") + gr.Markdown("**Depth Calibration:** Adjust the maximum expected wound depth to improve measurement accuracy. 
For shallow wounds use 5-10mm, for deep wounds use 15-30mm.") + + with gr.Row(): + # Load depth map from previous tab + load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary") + + gr.Markdown("**Note:** When you load a depth map or upload an image, the segmentation model will automatically generate a wound mask.") + + # Update slider when image is uploaded + depth_input_image.change( + fn=update_slider_on_image_upload, + inputs=[depth_input_image], + outputs=[points_slider] + ) + + # Modified depth submit function to store depth map + def on_depth_submit_with_state(image, num_points, focal_x, focal_y): + results = on_depth_submit(image, num_points, focal_x, focal_y) + # Extract depth map from results for severity analysis + depth_map = None + if image is not None: + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + # Normalize depth for severity analysis + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth_map = norm_depth.astype(np.uint8) + return results + [depth_map] + + depth_submit.click(on_depth_submit_with_state, + inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], + outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state]) + + # Load depth map to severity tab and auto-generate mask + def load_depth_to_severity(depth_map, original_image): + if depth_map is None: + return None, None, None, "❌ No depth map available. Please compute depth in Tab 2 first." + + # Auto-generate wound mask using segmentation model + if original_image is not None: + auto_mask, _ = segmentation_model.segment_wound(original_image) + if auto_mask is not None: + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + if processed_mask is not None and np.sum(processed_mask > 0) > 0: + return depth_map, original_image, processed_mask, "✅ Depth map loaded and wound mask auto-generated!" 
+ else: + return depth_map, original_image, None, "✅ Depth map loaded but no wound detected. Try uploading a different image." + else: + return depth_map, original_image, None, "✅ Depth map loaded but segmentation failed. Try uploading a different image." + else: + return depth_map, original_image, None, "✅ Depth map loaded successfully!" + + load_depth_btn.click( + fn=load_depth_to_severity, + inputs=[depth_map_state, depth_input_image], + outputs=[severity_depth_map, severity_input_image, wound_mask_input, gr.HTML()] + ) + + # Automatic severity analysis function + def run_auto_severity_analysis(image, depth_map, pixel_spacing, depth_calibration): + if depth_map is None: + return "❌ Please load depth map from Tab 2 first." + + # Generate automatic wound mask using the actual model + auto_mask = create_automatic_wound_mask(image, method='deep_learning') + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask. Please check if the segmentation model is loaded." + + # Post-process the mask with fixed minimum area + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected by the segmentation model. Try uploading a different image or use manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing, depth_calibration) + + # Manual severity analysis function + def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing, depth_calibration): + if depth_map is None: + return "❌ Please load depth map from Tab 2 first." + if wound_mask is None: + return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)." 
+ + return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing, depth_calibration) + + # Connect event handlers + auto_severity_button.click( + fn=run_auto_severity_analysis, + inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, depth_calibration_slider], + outputs=[severity_output] + ) + + manual_severity_button.click( + fn=run_manual_severity_analysis, + inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider, depth_calibration_slider], + outputs=[severity_output] + ) + + + + # Auto-generate mask when image is uploaded + def auto_generate_mask_on_image_upload(image): + if image is None: + return None, "❌ No image uploaded." + + # Generate automatic wound mask using segmentation model + auto_mask, _ = segmentation_model.segment_wound(image) + if auto_mask is not None: + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + if processed_mask is not None and np.sum(processed_mask > 0) > 0: + return processed_mask, "✅ Wound mask auto-generated using deep learning model!" + else: + return None, "✅ Image uploaded but no wound detected. Try uploading a different image." + else: + return None, "✅ Image uploaded but segmentation failed. Try uploading a different image." 
+ + # Load shared image from classification tab + def load_shared_image(shared_img): + if shared_img is None: + return gr.Image(), "❌ No image available from classification tab" + + # Convert PIL image to numpy array for depth estimation + if hasattr(shared_img, 'convert'): + # It's a PIL image, convert to numpy + img_array = np.array(shared_img) + return img_array, "✅ Image loaded from classification tab" + else: + # Already numpy array + return shared_img, "✅ Image loaded from classification tab" + + # Auto-generate mask when image is uploaded to severity tab + severity_input_image.change( + fn=auto_generate_mask_on_image_upload, + inputs=[severity_input_image], + outputs=[wound_mask_input, gr.HTML()] + ) + + load_shared_btn.click( + fn=load_shared_image, + inputs=[shared_image], + outputs=[depth_input_image, gr.HTML()] + ) + + # Pass image to depth tab function + def pass_image_to_depth(img): + if img is None: + return "❌ No image uploaded in classification tab" + return "✅ Image ready for depth analysis! 
Switch to tab 2 and click 'Load Image from Classification'" + + pass_to_depth_btn.click( + fn=pass_image_to_depth, + inputs=[shared_image], + outputs=[pass_status] + ) + +if __name__ == '__main__': + demo.queue().launch( + server_name="0.0.0.0", + server_port=7860, + share=True + ) \ No newline at end of file diff --git a/temp_files/predict.py b/temp_files/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a55295b68ea55691886e8ca57ad65cebf95bc08d --- /dev/null +++ b/temp_files/predict.py @@ -0,0 +1,64 @@ +import cv2 +from keras.models import load_model +from keras.utils.generic_utils import CustomObjectScope + +from models.unets import Unet2D +from models.deeplab import Deeplabv3, relu6, BilinearUpsampling, DepthwiseConv2D +from models.FCN import FCN_Vgg16_16s + +from utils.learning.metrics import dice_coef, precision, recall +from utils.BilinearUpSampling import BilinearUpSampling2D +from utils.io.data import load_data, save_results, save_rgb_results, save_history, load_test_images, DataGen + + +# settings +input_dim_x = 224 +input_dim_y = 224 +color_space = 'rgb' +path = './data/Medetec_foot_ulcer_224/' +weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5' +pred_save_path = '2019-12-19 01%3A53%3A15.480800/' + +data_gen = DataGen(path, split_ratio=0.0, x=input_dim_x, y=input_dim_y, color_space=color_space) +x_test, test_label_filenames_list = load_test_images(path) + +# ### get unet model +# unet2d = Unet2D(n_filters=64, input_dim_x=input_dim_x, input_dim_y=input_dim_y, num_channels=3) +# model = unet2d.get_unet_model_yuanqing() +# model = load_model('./azh_wound_care_center_diabetic_foot_training_history/' + weight_file_name +# , custom_objects={'recall':recall, +# 'precision':precision, +# 'dice_coef': dice_coef, +# 'relu6':relu6, +# 'DepthwiseConv2D':DepthwiseConv2D, +# 'BilinearUpsampling':BilinearUpsampling}) + +# ### get separable unet model +# sep_unet = Separable_Unet2D(n_filters=64, input_dim_x=input_dim_x, 
input_dim_y=input_dim_y, num_channels=3) +# model, model_name = sep_unet.get_sep_unet_v2() +# model = load_model('./azh_wound_care_center_diabetic_foot_training_history/' + weight_file_name +# , custom_objects={'dice_coef': dice_coef, +# 'relu6':relu6, +# 'DepthwiseConv2D':DepthwiseConv2D, +# 'BilinearUpsampling':BilinearUpsampling}) + +# ### get VGG16 model +# model, model_name = FCN_Vgg16_16s(input_shape=(input_dim_x, input_dim_y, 3)) +# with CustomObjectScope({'BilinearUpSampling2D':BilinearUpSampling2D}): +# model = load_model('./azh_wound_care_center_diabetic_foot_training_history/' + weight_file_name +# , custom_objects={'dice_coef': dice_coef}) + +# ### get mobilenetv2 model +model = Deeplabv3(input_shape=(input_dim_x, input_dim_y, 3), classes=1) +model = load_model('./training_history/' + weight_file_name + , custom_objects={'recall':recall, + 'precision':precision, + 'dice_coef': dice_coef, + 'relu6':relu6, + 'DepthwiseConv2D':DepthwiseConv2D, + 'BilinearUpsampling':BilinearUpsampling}) + +for image_batch, label_batch in data_gen.generate_data(batch_size=len(x_test), test=True): + prediction = model.predict(image_batch, verbose=1) + save_results(prediction, 'rgb', path + 'test/predictions/' + pred_save_path, test_label_filenames_list) + break diff --git a/temp_files/requirements.txt b/temp_files/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d2f4d0195e0af4bb2dd200924865b738c3475c1 --- /dev/null +++ b/temp_files/requirements.txt @@ -0,0 +1,109 @@ +aiofiles +annotated-types +anyio +asttokens +attrs +blinker +certifi +charset-normalizer +click +colorama +comm +ConfigArgParse +contourpy +cycler +dash +decorator +executing +fastapi +fastjsonschema +ffmpy +filelock +Flask +fonttools +fsspec +gdown +gradio +gradio_client +gradio_imageslider +groovy +h11 +httpcore +httpx +huggingface-hub +idna +importlib_metadata +itsdangerous +jedi +Jinja2 +jsonschema +jsonschema-specifications +jupyter_core +jupyterlab_widgets +kiwisolver 
+markdown-it-py +MarkupSafe +matplotlib +matplotlib-inline +mdurl +mpmath +narwhals +nbformat +nest-asyncio +networkx +numpy<2 +open3d +opencv-python +orjson +packaging +pandas +parso +pillow +platformdirs +plotly +prompt_toolkit +pure_eval +pydantic_core +pydub +Pygments +pyparsing +python-dateutil +python-multipart +pytz +PyYAML +referencing +requests +retrying +rich +rpds-py +ruff +safehttpx +scikit-image +semantic-version +setuptools +shellingham +six +sniffio +stack-data +starlette +sympy +tensorflow<2.11 +tensorflow_hub +tomlkit +torch +torchvision +tqdm +traitlets +typer +typing-inspection +typing_extensions +tzdata +urllib3 +uvicorn +wcwidth +websockets +Werkzeug +wheel +widgetsnbextension +zipp +pydantic==2.10.6 \ No newline at end of file diff --git a/temp_files/run_gradio_app.py b/temp_files/run_gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3dec33e4b79363537cb8d5d9ae25a4dee9b9ba --- /dev/null +++ b/temp_files/run_gradio_app.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +""" +Simple launcher for the Wound Segmentation Gradio App +""" + +import sys +import os + +def check_dependencies(): + """Check if required dependencies are installed""" + required_packages = ['gradio', 'tensorflow', 'cv2', 'numpy'] + missing_packages = [] + + for package in required_packages: + try: + if package == 'cv2': + import cv2 + else: + __import__(package) + except ImportError: + missing_packages.append(package) + + if missing_packages: + print("❌ Missing required packages:") + for package in missing_packages: + print(f" - {package}") + print("\n📦 Install missing packages with:") + print(" pip install -r requirements.txt") + return False + + print("✅ All required packages are installed!") + return True + +def check_model_files(): + """Check if model files exist""" + model_files = [ + 'training_history/2025-08-07_12-30-43.hdf5', + 'training_history/2019-12-19 01%3A53%3A15.480800.hdf5' + ] + + existing_models = [] + for model_file in 
model_files: + if os.path.exists(model_file): + existing_models.append(model_file) + + if not existing_models: + print("❌ No model files found!") + print(" Please ensure you have trained models in the training_history/ directory") + return False + + print(f"✅ Found {len(existing_models)} model file(s):") + for model in existing_models: + print(f" - {model}") + return True + +def main(): + """Main function to launch the Gradio app""" + print("🚀 Starting Wound Segmentation Gradio App...") + print("=" * 50) + + # Check dependencies + if not check_dependencies(): + sys.exit(1) + + # Check model files + if not check_model_files(): + sys.exit(1) + + print("\n🎯 Launching Gradio interface...") + print(" The app will be available at: http://localhost:7860") + print(" Press Ctrl+C to stop the server") + print("=" * 50) + + try: + # Import and run the Gradio app + from gradio_app import create_gradio_interface + + interface = create_gradio_interface() + interface.launch( + server_name="0.0.0.0", + server_port=7860, + share=True, + show_error=True + ) + except KeyboardInterrupt: + print("\n👋 Gradio app stopped by user") + except Exception as e: + print(f"\n❌ Error launching Gradio app: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/temp_files/segmentation_app.py b/temp_files/segmentation_app.py new file mode 100644 index 0000000000000000000000000000000000000000..4e690f1a117d637234f7d6487138f3e4b584f5f9 --- /dev/null +++ b/temp_files/segmentation_app.py @@ -0,0 +1,222 @@ +import gradio as gr +import cv2 +import numpy as np +import tensorflow as tf +from tensorflow import keras +from keras.models import load_model +from keras.utils.generic_utils import CustomObjectScope + +# Import custom modules +from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling +from utils.learning.metrics import dice_coef, precision, recall +from utils.io.data import normalize + +class WoundSegmentationApp: + def 
__init__(self): + self.input_dim_x = 224 + self.input_dim_y = 224 + self.model = None + self.load_model() + + def load_model(self): + """Load the trained wound segmentation model""" + try: + # Load the model with custom objects + weight_file_name = '2025-08-07_12-30-43.hdf5' # Use the most recent model + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Model loaded successfully from {model_path}") + except Exception as e: + print(f"Error loading model: {e}") + # Fallback to the older model if the newer one fails + try: + weight_file_name = '2019-12-19 01%3A53%3A15.480800.hdf5' + model_path = f'./training_history/{weight_file_name}' + + self.model = load_model(model_path, + custom_objects={ + 'recall': recall, + 'precision': precision, + 'dice_coef': dice_coef, + 'relu6': relu6, + 'DepthwiseConv2D': DepthwiseConv2D, + 'BilinearUpsampling': BilinearUpsampling + }) + print(f"Model loaded successfully from {model_path}") + except Exception as e2: + print(f"Error loading fallback model: {e2}") + self.model = None + + def preprocess_image(self, image): + """Preprocess the uploaded image for model input""" + if image is None: + return None + + # Convert to RGB if needed + if len(image.shape) == 3 and image.shape[2] == 3: + # Convert BGR to RGB if needed + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Resize to model input size + image = cv2.resize(image, (self.input_dim_x, self.input_dim_y)) + + # Normalize the image + image = image.astype(np.float32) / 255.0 + + # Add batch dimension + image = np.expand_dims(image, axis=0) + + return image + + def postprocess_prediction(self, prediction): + """Postprocess the model prediction""" + # Remove batch dimension + prediction = prediction[0] + + # Apply threshold to get binary 
mask + threshold = 0.5 + binary_mask = (prediction > threshold).astype(np.uint8) * 255 + + # Convert to 3-channel image for visualization + mask_rgb = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2RGB) + + return mask_rgb + + def segment_wound(self, input_image): + """Main function to segment wound from uploaded image""" + if self.model is None: + return None, "Error: Model not loaded. Please check the model files." + + if input_image is None: + return None, "Please upload an image." + + try: + # Preprocess the image + processed_image = self.preprocess_image(input_image) + + if processed_image is None: + return None, "Error processing image." + + # Make prediction + prediction = self.model.predict(processed_image, verbose=0) + + # Postprocess the prediction + segmented_mask = self.postprocess_prediction(prediction) + + # Create overlay image (original image with segmentation overlay) + original_resized = cv2.resize(input_image, (self.input_dim_x, self.input_dim_y)) + if len(original_resized.shape) == 3: + original_resized = cv2.cvtColor(original_resized, cv2.COLOR_RGB2BGR) + + # Create overlay with red segmentation + overlay = original_resized.copy() + mask_red = np.zeros_like(original_resized) + mask_red[:, :, 2] = segmented_mask[:, :, 0] # Red channel + + # Blend overlay with original image + alpha = 0.6 + overlay = cv2.addWeighted(overlay, 1-alpha, mask_red, alpha, 0) + + return segmented_mask, overlay + + except Exception as e: + return None, f"Error during segmentation: {str(e)}" + +def create_gradio_interface(): + """Create and return the Gradio interface""" + + # Initialize the app + app = WoundSegmentationApp() + + # Define the interface + with gr.Blocks(title="Wound Segmentation Tool", theme=gr.themes.Soft()) as interface: + gr.Markdown( + """ + # 🩹 Wound Segmentation Tool + + Upload an image of a wound to get an automated segmentation mask. + The model will identify and highlight the wound area in the image. + + **Instructions:** + 1. 
Upload an image of a wound + 2. Click "Segment Wound" to process the image + 3. View the segmentation mask and overlay results + """ + ) + + with gr.Row(): + with gr.Column(): + input_image = gr.Image( + label="Upload Wound Image", + type="numpy", + height=400 + ) + + segment_btn = gr.Button( + "🔍 Segment Wound", + variant="primary", + size="lg" + ) + + with gr.Column(): + mask_output = gr.Image( + label="Segmentation Mask", + height=400 + ) + + overlay_output = gr.Image( + label="Overlay Result", + height=400 + ) + + # Status message + status_msg = gr.Textbox( + label="Status", + interactive=False, + placeholder="Ready to process images..." + ) + + # Example images + gr.Markdown("### 📸 Example Images") + gr.Markdown("You can test the tool with wound images from the dataset.") + + # Connect the button to the segmentation function + def process_image(image): + mask, overlay = app.segment_wound(image) + if mask is None: + return None, None, overlay # overlay contains error message + return mask, overlay, "Segmentation completed successfully!" 
+ + segment_btn.click( + fn=process_image, + inputs=[input_image], + outputs=[mask_output, overlay_output, status_msg] + ) + + # Auto-process when image is uploaded + input_image.change( + fn=process_image, + inputs=[input_image], + outputs=[mask_output, overlay_output, status_msg] + ) + + return interface + +if __name__ == "__main__": + # Create and launch the interface + interface = create_gradio_interface() + interface.launch( + server_name="0.0.0.0", + server_port=7860, + share=True, + show_error=True + ) \ No newline at end of file diff --git a/temp_files/test1.txt b/temp_files/test1.txt new file mode 100644 index 0000000000000000000000000000000000000000..c37b7fe936a455becac41a6bd7eb58e3275d229f --- /dev/null +++ b/temp_files/test1.txt @@ -0,0 +1,843 @@ +import glob +import gradio as gr +import matplotlib +import numpy as np +from PIL import Image +import torch +import tempfile +from gradio_imageslider import ImageSlider +import plotly.graph_objects as go +import plotly.express as px +import open3d as o3d +from depth_anything_v2.dpt import DepthAnythingV2 +import os +import tensorflow as tf +from tensorflow.keras.models import load_model +from tensorflow.keras.preprocessing import image as keras_image +import base64 +from io import BytesIO +import gdown +import spaces +import cv2 +from skimage import filters, morphology, measure +from skimage.segmentation import clear_border + +# --- LINEAR INITIALIZATION - NO MODULAR FUNCTIONS --- +print("Starting linear initialization for ZeroGPU compatibility...") + +# Define path and file ID +checkpoint_dir = "checkpoints" +os.makedirs(checkpoint_dir, exist_ok=True) + +model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth") +gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5" + +# Download if not already present +if not os.path.exists(model_file): + print("Downloading model from Google Drive...") + gdown.download(gdrive_url, model_file, quiet=False) + +# --- TensorFlow: Check GPU 
Availability --- +gpus = tf.config.list_physical_devices('GPU') +if gpus: + print("TensorFlow is using GPU") +else: + print("TensorFlow is using CPU") + +# --- Load Wound Classification Model and Class Labels --- +wound_model = load_model("/home/user/app/keras_model.h5") +with open("/home/user/app/labels.txt", "r") as f: + class_labels = [line.strip().split(maxsplit=1)[1] for line in f] + +# --- PyTorch: Set Device and Load Depth Model --- +print("Initializing PyTorch device...") +map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") +print(f"Using PyTorch device: {map_device}") + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} +} +encoder = 'vitl' +depth_model = DepthAnythingV2(**model_configs[encoder]) +state_dict = torch.load( + f'/home/user/app/checkpoints/depth_anything_v2_{encoder}.pth', + map_location=map_device +) +depth_model.load_state_dict(state_dict) +depth_model = depth_model.to(map_device).eval() + +# --- Custom CSS for unified dark theme --- +css = """ +.gradio-container { + font-family: 'Segoe UI', sans-serif; + background-color: #121212; + color: #ffffff; + padding: 20px; +} +.gr-button { + background-color: #2c3e50; + color: white; + border-radius: 10px; +} +.gr-button:hover { + background-color: #34495e; +} +.gr-html, .gr-html div { + white-space: normal !important; + overflow: visible !important; + text-overflow: unset !important; + word-break: break-word !important; +} +#img-display-container { + max-height: 100vh; +} +#img-display-input { + max-height: 80vh; +} +#img-display-output { + max-height: 80vh; +} +#download { + height: 62px; +} +h1 { + text-align: center; + font-size: 
3rem; + font-weight: bold; + margin: 2rem 0; + color: #ffffff; +} +h2 { + color: #ffffff; + text-align: center; + margin: 1rem 0; +} +.gr-tabs { + background-color: #1e1e1e; + border-radius: 10px; + padding: 10px; +} +.gr-tab-nav { + background-color: #2c3e50; + border-radius: 8px; +} +.gr-tab-nav button { + color: #ffffff !important; +} +.gr-tab-nav button.selected { + background-color: #34495e !important; +} +""" + +# --- LINEAR FUNCTION DEFINITIONS (NO MODULAR CALLS) --- + +# Wound Classification Functions +def preprocess_input(img): + img = img.resize((224, 224)) + arr = keras_image.img_to_array(img) + arr = arr / 255.0 + return np.expand_dims(arr, axis=0) + +def get_reasoning_from_gemini(img, prediction): + try: + explanations = { + "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.", + "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.", + "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.", + "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.", + "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues." + } + return explanations.get(prediction, f"This wound has been classified as {prediction}. Please consult with a healthcare professional for detailed assessment.") + except Exception as e: + return f"(Reasoning unavailable: {str(e)})" + +@spaces.GPU +def classify_wound_image(img): + if img is None: + return "
    No image provided
    ", "" + + img_array = preprocess_input(img) + predictions = wound_model.predict(img_array, verbose=0)[0] + pred_idx = int(np.argmax(predictions)) + pred_class = class_labels[pred_idx] + + reasoning_text = get_reasoning_from_gemini(img, pred_class) + + predicted_card = f""" +
    +
    + Predicted Wound Type +
    +
    + {pred_class} +
    +
    + """ + + reasoning_card = f""" +
    +
    + Reasoning +
    +
    + {reasoning_text} +
    +
    + """ + + return predicted_card, reasoning_card + +# Depth Estimation Functions +@spaces.GPU +def predict_depth(image): + return depth_model.infer_image(image) + +def calculate_max_points(image): + if image is None: + return 10000 + h, w = image.shape[:2] + max_points = h * w * 3 + return max(1000, min(max_points, 300000)) + +def update_slider_on_image_upload(image): + max_points = calculate_max_points(image) + default_value = min(10000, max_points // 10) + return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, + label=f"Number of 3D points (max: {max_points:,})") + +@spaces.GPU +def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000): + h, w = depth_map.shape + step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) + + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + x_cam = (x_coords - w / 2) / focal_length_x + y_cam = (y_coords - h / 2) / focal_length_y + depth_values = depth_map[::step, ::step] + + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) + image_colors = image[::step, ::step, :] + colors = image_colors.reshape(-1, 3) / 255.0 + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + + return pcd + +@spaces.GPU +def reconstruct_surface_mesh_from_point_cloud(pcd): + pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)) + pcd.orient_normals_consistent_tangent_plane(k=50) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12) + return mesh + +@spaces.GPU +def create_enhanced_3d_visualization(image, depth_map, max_points=10000): + h, w = depth_map.shape + step = max(1, int(np.sqrt(h * w / max_points))) + + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + focal_length = 470.4 + x_cam = (x_coords - w / 2) / 
focal_length + y_cam = (y_coords - h / 2) / focal_length + depth_values = depth_map[::step, ::step] + + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + x_flat = x_3d.flatten() + y_flat = y_3d.flatten() + z_flat = z_3d.flatten() + + image_colors = image[::step, ::step, :] + colors_flat = image_colors.reshape(-1, 3) + + fig = go.Figure(data=[go.Scatter3d( + x=x_flat, + y=y_flat, + z=z_flat, + mode='markers', + marker=dict( + size=1.5, + color=colors_flat, + opacity=0.9 + ), + hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
    ' + + 'Depth: %{z:.2f}
    ' + + '' + )]) + + fig.update_layout( + title="3D Point Cloud Visualization (Camera Projection)", + scene=dict( + xaxis_title="X (meters)", + yaxis_title="Y (meters)", + zaxis_title="Z (meters)", + camera=dict( + eye=dict(x=2.0, y=2.0, z=2.0), + center=dict(x=0, y=0, z=0), + up=dict(x=0, y=0, z=1) + ), + aspectmode='data' + ), + width=700, + height=600 + ) + + return fig + +def on_depth_submit(image, num_points, focal_x, focal_y): + original_image = image.copy() + h, w = image.shape[:2] + + depth = predict_depth(image[:, :, ::-1]) + + raw_depth = Image.fromarray(depth.astype('uint16')) + tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + raw_depth.save(tmp_raw_depth.name) + + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + norm_depth = norm_depth.astype(np.uint8) + colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8) + + gray_depth = Image.fromarray(norm_depth) + tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + gray_depth.save(tmp_gray_depth.name) + + pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points) + mesh = reconstruct_surface_mesh_from_point_cloud(pcd) + + tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) + o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh) + + depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points) + + return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] + +# Wound Severity Analysis Functions +@spaces.GPU +def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5): + pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2 + wound_mask = (mask > 127) + wound_depths = depth_map[wound_mask] + total_area = np.sum(wound_mask) * pixel_area_cm2 + + shallow = wound_depths < 3 + moderate = (wound_depths >= 3) & (wound_depths < 6) + deep = 
wound_depths >= 6 + + shallow_area = np.sum(shallow) * pixel_area_cm2 + moderate_area = np.sum(moderate) * pixel_area_cm2 + deep_area = np.sum(deep) * pixel_area_cm2 + deep_ratio = deep_area / total_area if total_area > 0 else 0 + + return { + 'total_area_cm2': total_area, + 'shallow_area_cm2': shallow_area, + 'moderate_area_cm2': moderate_area, + 'deep_area_cm2': deep_area, + 'deep_ratio': deep_ratio, + 'max_depth': np.max(wound_depths) if len(wound_depths) > 0 else 0 + } + +def classify_wound_severity_by_area(depth_stats): + total = depth_stats['total_area_cm2'] + deep = depth_stats['deep_area_cm2'] + moderate = depth_stats['moderate_area_cm2'] + + if total == 0: + return "Unknown" + + if deep > 2 or (deep / total) > 0.3: + return "Severe" + elif moderate > 1.5 or (moderate / total) > 0.4: + return "Moderate" + else: + return "Mild" + +def get_severity_description(severity): + descriptions = { + "Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.", + "Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.", + "Severe": "Deep tissue damage requiring immediate medical attention and specialized care.", + "Unknown": "Unable to determine severity due to insufficient data." + } + return descriptions.get(severity, "Severity assessment unavailable.") + +def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5): + if image is None or depth_map is None or wound_mask is None: + return "❌ Please upload image, depth map, and wound mask." 
def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5):
    """Build an HTML severity report from a depth map and wound mask.

    Parameters:
        image: original RGB image (only checked for presence).
        depth_map: 2D depth array (the normalized map from Tab 2).
        wound_mask: 2D or 3D mask; nonzero/white = wound.
        pixel_spacing_mm: camera calibration, mm per pixel.

    Returns:
        An HTML string report, or an error message when inputs are missing.
    """
    if image is None or depth_map is None or wound_mask is None:
        return "❌ Please upload image, depth map, and wound mask."

    # Collapse RGB masks to a single channel.
    if len(wound_mask.shape) == 3:
        wound_mask = np.mean(wound_mask, axis=2)

    # Match the mask to the depth map's resolution. NEAREST keeps the mask
    # binary — the previous default (bicubic) filter smeared intermediate
    # values across the wound edge. `Image` is the module-level PIL import;
    # the old function-local re-import shadowed it needlessly.
    if depth_map.shape[:2] != wound_mask.shape[:2]:
        mask_pil = Image.fromarray(wound_mask.astype(np.uint8))
        mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0]),
                                   resample=Image.NEAREST)
        wound_mask = np.array(mask_pil)

    stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm)
    severity = classify_wound_severity_by_area(stats)

    # Color the verdict: green / orange / red, gray for unknown.
    severity_color = {
        "Mild": "#4CAF50",
        "Moderate": "#FF9800",
        "Severe": "#F44336"
    }.get(severity, "#9E9E9E")

    report = f"""
    <div style="font-family:'Segoe UI',sans-serif;color:#ffffff;">
      <h2>🩹 Wound Severity Analysis</h2>
      <div>
        <h3>📏 Area Measurements</h3>
        <p>🟢 Total Area: {stats['total_area_cm2']:.2f} cm²</p>
        <p>🟩 Shallow (0-3mm): {stats['shallow_area_cm2']:.2f} cm²</p>
        <p>🟨 Moderate (3-6mm): {stats['moderate_area_cm2']:.2f} cm²</p>
        <p>🟥 Deep (&gt;6mm): {stats['deep_area_cm2']:.2f} cm²</p>
      </div>
      <div>
        <h3>📊 Depth Analysis</h3>
        <p>🔥 Deep Coverage: {stats['deep_ratio']*100:.1f}%</p>
        <p>📏 Max Depth: {stats['max_depth']:.1f} mm</p>
        <p>Pixel Spacing: {pixel_spacing_mm} mm</p>
      </div>
      <div>
        <h3 style="color:{severity_color};">🎯 Predicted Severity: {severity}</h3>
        <p>{get_severity_description(severity)}</p>
      </div>
    </div>
    """

    return report

# Automatic Wound Mask Generation Functions
def create_automatic_wound_mask(image, method='adaptive'):
    """Dispatch to one of the automatic segmentation strategies.

    Parameters:
        image: RGB (or already-gray) image array; None returns None.
        method: 'adaptive' | 'otsu' | 'color' | 'combined'; unknown values
            fall back to 'adaptive'.

    Returns:
        uint8 binary mask (0/255) or None when no image was given.
    """
    if image is None:
        return None

    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image.copy()

    if method == 'otsu':
        return otsu_threshold_segmentation(gray)
    if method == 'color':
        return color_based_segmentation(image)
    if method == 'combined':
        return combined_segmentation(image, gray)
    # 'adaptive' and any unrecognized method.
    return adaptive_threshold_segmentation(gray)

def adaptive_threshold_segmentation(gray):
    """Segment dark regions via Gaussian adaptive thresholding.

    Blur -> inverse adaptive threshold -> close/open morphology -> keep only
    contours larger than 1000 px to drop speckle.
    """
    blurred = cv2.GaussianBlur(gray, (15, 15), 0)
    thresh = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5
    )
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > 1000:
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean

def otsu_threshold_segmentation(gray):
    """Segment via global Otsu thresholding (inverse), then clean up.

    Same morphology pipeline as the adaptive variant but with a smaller
    structuring element and an 800 px minimum contour area.
    """
    blurred = cv2.GaussianBlur(gray, (15, 15), 0)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > 800:
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean
def color_based_segmentation(image):
    """Segment wound-like regions by hue (reds, yellows, browns) in HSV.

    Red wraps around the hue axis, so it needs two ranges. The per-channel
    masks are merged with cv2.bitwise_or — NOT uint8 '+': the hue ranges
    overlap (red1 0-15, brown 10-20, yellow 15-35), and adding overlapping
    255-valued masks wraps modulo 256 instead of saturating.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([160, 30, 30])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    red_mask = cv2.bitwise_or(mask1, mask2)

    lower_yellow = np.array([15, 30, 30])
    upper_yellow = np.array([35, 255, 255])
    yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow)

    lower_brown = np.array([10, 50, 20])
    upper_brown = np.array([20, 255, 200])
    brown_mask = cv2.inRange(hsv, lower_brown, upper_brown)

    color_mask = cv2.bitwise_or(red_mask, cv2.bitwise_or(yellow_mask, brown_mask))

    # Smooth the combined mask and drop small speckle contours (< 600 px).
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, kernel)
    color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(color_mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > 600:
            cv2.fillPoly(mask_clean, [contour], 255)

    return mask_clean

def combined_segmentation(image, gray):
    """Union of all three strategies, with a synthetic fallback.

    OR-combines the adaptive, Otsu, and color masks, closes gaps with a large
    kernel, keeps contours >= 500 px, and — when absolutely nothing survives —
    falls back to a synthetic elliptical mask so downstream analysis still
    has a region to work with.
    """
    adaptive_mask = adaptive_threshold_segmentation(gray)
    otsu_mask = otsu_threshold_segmentation(gray)
    color_mask = color_based_segmentation(image)

    combined_mask = cv2.bitwise_or(adaptive_mask, otsu_mask)
    combined_mask = cv2.bitwise_or(combined_mask, color_mask)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)

    contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(combined_mask)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > 500:
            cv2.fillPoly(mask_clean, [contour], 255)

    if np.sum(mask_clean) == 0:
        mask_clean = create_realistic_wound_mask(combined_mask.shape, method='elliptical')

    return mask_clean
def create_realistic_wound_mask(image_shape, method='elliptical'):
    """Generate a synthetic wound mask for testing/fallback.

    Parameters:
        image_shape: (h, w) or (h, w, c) — only the first two dims are used.
        method: 'elliptical' (roughened ellipse) or 'irregular' (circle with
            three lobes).

    Returns:
        uint8 mask (0/255), lightly smoothed with a morphological close.

    Note: output is intentionally randomized (no fixed seed) — it is a
    demo/fallback mask, not a reproducible fixture.
    """
    h, w = image_shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)

    if method == 'elliptical':
        center = (w // 2, h // 2)
        radius_x = min(w, h) // 3
        radius_y = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        norm_dist = ((x - center[0])**2 / (radius_x**2) +
                     (y - center[1])**2 / (radius_y**2))
        ellipse = norm_dist <= 1

        # Roughen only the rim: random speckle in a thin band just outside
        # the ellipse. (The previous version sprinkled speckle over 20% of
        # the ENTIRE image, marking pixels far from the wound as wound.)
        rim = (norm_dist > 1) & (norm_dist <= 1.3)
        noise = (np.random.random((h, w)) > 0.8) & rim
        mask = (ellipse | noise).astype(np.uint8) * 255

    elif method == 'irregular':
        center = (w // 2, h // 2)
        radius = min(w, h) // 4

        y, x = np.ogrid[:h, :w]
        base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius

        # Three smaller lobes spaced 120° apart around the base circle.
        extensions = np.zeros_like(base_circle)
        for i in range(3):
            angle = i * 2 * np.pi / 3
            ext_x = int(center[0] + radius * 0.8 * np.cos(angle))
            ext_y = int(center[1] + radius * 0.8 * np.sin(angle))
            ext_radius = radius // 3

            ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius
            extensions = extensions | ext_circle

        mask = (base_circle | extensions).astype(np.uint8) * 255

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask

def post_process_wound_mask(mask, min_area=100):
    """Clean a raw segmentation mask.

    Close/open morphology, then keep only contours with area >= min_area,
    then a final close. Returns None for a None input.
    """
    if mask is None:
        return None

    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    mask_clean = np.zeros_like(mask)

    for contour in contours:
        area = cv2.contourArea(contour)
        if area >= min_area:
            cv2.fillPoly(mask_clean, [contour], 255)

    mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel)

    return mask_clean

def create_sample_wound_mask(image_shape, center=None, radius=50):
    """Create a filled circular test mask (0/255 uint8).

    Parameters:
        image_shape: (h, w) or (h, w, c); only the first two dims are used.
        center: (x, y) circle center; defaults to the image center.
        radius: circle radius in pixels.
    """
    if center is None:
        center = (image_shape[1] // 2, image_shape[0] // 2)

    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    y, x = np.ogrid[:image_shape[0], :image_shape[1]]

    dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2)
    mask[dist_from_center <= radius] = 255

    return mask
# --- MAIN GRADIO INTERFACE (LINEAR EXECUTION) ---
print("Creating Gradio interface...")

with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo:
    gr.HTML("<h1>Wound Analysis & Depth Estimation System</h1>")
    gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities")

    # Image handed from the classification tab to the depth tab.
    shared_image = gr.State()

    with gr.Tabs():
        # Tab 1: Wound Classification
        with gr.Tab("1. Wound Classification"):
            gr.Markdown("### Step 1: Upload and classify your wound image")
            gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.")

            with gr.Row():
                with gr.Column(scale=1):
                    wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350)
                with gr.Column(scale=1):
                    wound_prediction_box = gr.HTML()
                    wound_reasoning_box = gr.HTML()

            with gr.Row():
                pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg")
                pass_status = gr.HTML("")

            wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input,
                                     outputs=[wound_prediction_box, wound_reasoning_box])
            # Mirror every upload into the shared state for Tab 2.
            wound_image_input.change(fn=lambda img: img, inputs=[wound_image_input],
                                     outputs=[shared_image])

        # Tab 2: Depth Estimation
        with gr.Tab("2. Depth Estimation & 3D Visualization"):
            gr.Markdown("### Step 2: Generate depth maps and 3D visualizations")
            gr.Markdown("This module creates depth maps and 3D point clouds from your images.")

            with gr.Row():
                depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
                depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

            with gr.Row():
                depth_submit = gr.Button(value="Compute Depth", variant="primary")
                load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary")
                points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                          label="Number of 3D points (upload image to update max)")

            # Status line for the "load from classification" action. Placed in
            # the layout so its messages actually render (passing a fresh
            # gr.HTML() directly in `outputs=` creates an orphan component
            # that is never displayed).
            load_status = gr.HTML("")

            with gr.Row():
                focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length X (pixels)")
                focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                           label="Focal Length Y (pixels)")

            with gr.Row():
                gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
                raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
                point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

            gr.Markdown("### 3D Point Cloud Visualization")
            gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
            depth_3d_plot = gr.Plot(label="3D Point Cloud")

            # Normalized depth map cached for the severity tab.
            depth_map_state = gr.State()

        # Tab 3: Wound Severity Analysis
        with gr.Tab("3. 🩹 Wound Severity Analysis"):
            gr.Markdown("### Step 3: Analyze wound severity using depth maps")
            gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.")

            with gr.Row():
                severity_input_image = gr.Image(label="Original Image", type='numpy')
                severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy')

            with gr.Row():
                wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy')
                severity_output = gr.HTML(label="Severity Analysis Report")

            gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.")

            with gr.Row():
                auto_severity_button = gr.Button("🤖 Auto-Analyze Severity", variant="primary", size="lg")
                manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg")
                pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                                                 label="Pixel Spacing (mm/pixel)")

            gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. Default is 0.5 mm/pixel.")

            with gr.Row():
                segmentation_method = gr.Dropdown(
                    choices=["combined", "adaptive", "otsu", "color"],
                    value="combined",
                    label="Segmentation Method",
                    info="Choose automatic segmentation method"
                )
                min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100,
                                            label="Minimum Area (pixels)",
                                            info="Minimum wound area to detect")

            with gr.Row():
                load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary")
                sample_mask_btn = gr.Button("🎯 Generate Sample Mask", variant="secondary")
                realistic_mask_btn = gr.Button("🏥 Generate Realistic Mask", variant="secondary")
                preview_mask_btn = gr.Button("👁️ Preview Auto Mask", variant="secondary")

            # Shared status line for the mask/depth helper buttons (same
            # orphan-component fix as in Tab 2).
            mask_status = gr.HTML("")

            gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.")

    # --- Event handlers ---
    def generate_sample_mask(image):
        """Create a circular test mask sized to the current image."""
        if image is None:
            return None, "❌ Please load an image first."
        sample_mask = create_sample_wound_mask(image.shape)
        return sample_mask, "✅ Sample circular wound mask generated!"

    def generate_realistic_mask(image):
        """Create an irregular elliptical test mask sized to the current image."""
        if image is None:
            return None, "❌ Please load an image first."
        realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical')
        return realistic_mask, "✅ Realistic elliptical wound mask generated!"

    def load_depth_to_severity(depth_map, original_image):
        """Copy Tab 2's cached depth map and image into the severity tab."""
        if depth_map is None:
            return None, None, "❌ No depth map available. Please compute depth in Tab 2 first."
        return depth_map, original_image, "✅ Depth map loaded successfully!"

    def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area):
        """Segment the wound automatically, then run the severity report."""
        if depth_map is None:
            return "❌ Please load depth map from Tab 2 first."

        auto_mask = create_automatic_wound_mask(image, method=seg_method)
        if auto_mask is None:
            return "❌ Failed to generate automatic wound mask."

        processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return "❌ No wound region detected. Try adjusting segmentation parameters or use manual mask."

        return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing)

    def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing):
        """Run the severity report with a user-supplied mask."""
        if depth_map is None:
            return "❌ Please load depth map from Tab 2 first."
        if wound_mask is None:
            return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)."
        return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing)

    def preview_auto_mask(image, seg_method, min_area):
        """Show the automatic mask without running the severity report."""
        if image is None:
            return None, "❌ Please load an image first."
        auto_mask = create_automatic_wound_mask(image, method=seg_method)
        if auto_mask is None:
            return None, "❌ Failed to generate automatic wound mask."
        processed_mask = post_process_wound_mask(auto_mask, min_area=min_area)
        if processed_mask is None or np.sum(processed_mask > 0) == 0:
            return None, "❌ No wound region detected. Try adjusting parameters."
        return processed_mask, f"✅ Auto mask generated using {seg_method} method!"

    def load_shared_image(shared_img):
        """Move the classification tab's image into the depth tab input."""
        if shared_img is None:
            return gr.Image(), "❌ No image available from classification tab"
        # A PIL image (has .convert) must become a numpy array for this input.
        if hasattr(shared_img, 'convert'):
            return np.array(shared_img), "✅ Image loaded from classification tab"
        return shared_img, "✅ Image loaded from classification tab"

    def pass_image_to_depth(img):
        """Confirm that the shared image is ready for Tab 2."""
        if img is None:
            return "❌ No image uploaded in classification tab"
        return "✅ Image ready for depth analysis! Switch to tab 2 and click 'Load Image from Classification'"

    def on_depth_submit_with_state(image, num_points, focal_x, focal_y):
        """Run the depth pipeline once and cache the normalized depth map.

        The previous version called predict_depth() a second time just to
        rebuild the normalized map; instead we reload the grayscale depth
        PNG on_depth_submit already wrote — it is the same uint8 map.
        """
        results = on_depth_submit(image, num_points, focal_x, focal_y)
        depth_map = None
        if image is not None:
            depth_map = np.array(Image.open(results[1]))
        return results + [depth_map]

    # --- Wire up events ---
    sample_mask_btn.click(fn=generate_sample_mask, inputs=[severity_input_image],
                          outputs=[wound_mask_input, mask_status])
    realistic_mask_btn.click(fn=generate_realistic_mask, inputs=[severity_input_image],
                             outputs=[wound_mask_input, mask_status])
    depth_input_image.change(fn=update_slider_on_image_upload, inputs=[depth_input_image],
                             outputs=[points_slider])
    depth_submit.click(on_depth_submit_with_state,
                       inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y],
                       outputs=[depth_image_slider, gray_depth_file, raw_file,
                                point_cloud_file, depth_3d_plot, depth_map_state])
    load_depth_btn.click(fn=load_depth_to_severity, inputs=[depth_map_state, depth_input_image],
                         outputs=[severity_depth_map, severity_input_image, mask_status])
    auto_severity_button.click(fn=run_auto_severity_analysis,
                               inputs=[severity_input_image, severity_depth_map,
                                       pixel_spacing_slider, segmentation_method, min_area_slider],
                               outputs=[severity_output])
    manual_severity_button.click(fn=run_manual_severity_analysis,
                                 inputs=[severity_input_image, severity_depth_map,
                                         wound_mask_input, pixel_spacing_slider],
                                 outputs=[severity_output])
    preview_mask_btn.click(fn=preview_auto_mask,
                           inputs=[severity_input_image, segmentation_method, min_area_slider],
                           outputs=[wound_mask_input, mask_status])
    load_shared_btn.click(fn=load_shared_image, inputs=[shared_image],
                          outputs=[depth_input_image, load_status])
    pass_to_depth_btn.click(fn=pass_image_to_depth, inputs=[shared_image],
                            outputs=[pass_status])

print("Gradio interface created successfully!")

if __name__ == '__main__':
    print("Launching app...")
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
os.path.exists(model_file): + print("Downloading model from Google Drive...") + gdown.download(gdrive_url, model_file, quiet=False) + +# --- TensorFlow: Check GPU Availability --- +gpus = tf.config.list_physical_devices('GPU') +if gpus: + print("TensorFlow is using GPU") +else: + print("TensorFlow is using CPU") + +# --- Load Wound Classification Model and Class Labels --- +wound_model = load_model("/home/user/app/keras_model.h5") +with open("/home/user/app/labels.txt", "r") as f: + class_labels = [line.strip().split(maxsplit=1)[1] for line in f] + +# --- PyTorch: Set Device and Load Depth Model --- +map_device = torch.device("cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu") +print(f"Using PyTorch device: {map_device}") + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} +} +encoder = 'vitl' +depth_model = DepthAnythingV2(**model_configs[encoder]) +state_dict = torch.load( + f'/home/user/app/checkpoints/depth_anything_v2_{encoder}.pth', + map_location=map_device +) +depth_model.load_state_dict(state_dict) +depth_model = depth_model.to(map_device).eval() + + +# --- Custom CSS for unified dark theme --- +css = """ +.gradio-container { + font-family: 'Segoe UI', sans-serif; + background-color: #121212; + color: #ffffff; + padding: 20px; +} +.gr-button { + background-color: #2c3e50; + color: white; + border-radius: 10px; +} +.gr-button:hover { + background-color: #34495e; +} +.gr-html, .gr-html div { + white-space: normal !important; + overflow: visible !important; + text-overflow: unset !important; + word-break: break-word !important; +} +#img-display-container { + max-height: 100vh; +} +#img-display-input { + max-height: 80vh; 
+} +#img-display-output { + max-height: 80vh; +} +#download { + height: 62px; +} +h1 { + text-align: center; + font-size: 3rem; + font-weight: bold; + margin: 2rem 0; + color: #ffffff; +} +h2 { + color: #ffffff; + text-align: center; + margin: 1rem 0; +} +.gr-tabs { + background-color: #1e1e1e; + border-radius: 10px; + padding: 10px; +} +.gr-tab-nav { + background-color: #2c3e50; + border-radius: 8px; +} +.gr-tab-nav button { + color: #ffffff !important; +} +.gr-tab-nav button.selected { + background-color: #34495e !important; +} +""" + +# --- Wound Classification Functions --- +def preprocess_input(img): + img = img.resize((224, 224)) + arr = keras_image.img_to_array(img) + arr = arr / 255.0 + return np.expand_dims(arr, axis=0) + +def get_reasoning_from_gemini(img, prediction): + try: + # For now, return a simple explanation without Gemini API to avoid typing issues + # In production, you would implement the proper Gemini API call here + explanations = { + "Abrasion": "This appears to be an abrasion wound, characterized by superficial damage to the skin surface. The wound shows typical signs of friction or scraping injury.", + "Burn": "This wound exhibits characteristics consistent with a burn injury, showing tissue damage from heat, chemicals, or radiation exposure.", + "Laceration": "This wound displays the irregular edges and tissue tearing typical of a laceration, likely caused by blunt force trauma.", + "Puncture": "This wound shows a small, deep entry point characteristic of puncture wounds, often caused by sharp, pointed objects.", + "Ulcer": "This wound exhibits the characteristics of an ulcer, showing tissue breakdown and potential underlying vascular or pressure issues." + } + + return explanations.get(prediction, f"This wound has been classified as {prediction}. 
Please consult with a healthcare professional for detailed assessment.") + + except Exception as e: + return f"(Reasoning unavailable: {str(e)})" + +@spaces.GPU +def classify_wound_image(img): + if img is None: + return "
    No image provided
    ", "" + + img_array = preprocess_input(img) + predictions = wound_model.predict(img_array, verbose=0)[0] + pred_idx = int(np.argmax(predictions)) + pred_class = class_labels[pred_idx] + + # Get reasoning from Gemini + reasoning_text = get_reasoning_from_gemini(img, pred_class) + + # Prediction Card + predicted_card = f""" +
    +
    + Predicted Wound Type +
    +
    + {pred_class} +
    +
    + """ + + # Reasoning Card + reasoning_card = f""" +
    +
    + Reasoning +
    +
    + {reasoning_text} +
    +
    + """ + + return predicted_card, reasoning_card + +# --- Wound Severity Estimation Functions --- +@spaces.GPU +def compute_depth_area_statistics(depth_map, mask, pixel_spacing_mm=0.5): + """Compute area statistics for different depth regions""" + pixel_area_cm2 = (pixel_spacing_mm / 10.0) ** 2 + + # Extract only wound region + wound_mask = (mask > 127) + wound_depths = depth_map[wound_mask] + total_area = np.sum(wound_mask) * pixel_area_cm2 + + # Categorize depth regions + shallow = wound_depths < 3 + moderate = (wound_depths >= 3) & (wound_depths < 6) + deep = wound_depths >= 6 + + shallow_area = np.sum(shallow) * pixel_area_cm2 + moderate_area = np.sum(moderate) * pixel_area_cm2 + deep_area = np.sum(deep) * pixel_area_cm2 + + deep_ratio = deep_area / total_area if total_area > 0 else 0 + + return { + 'total_area_cm2': total_area, + 'shallow_area_cm2': shallow_area, + 'moderate_area_cm2': moderate_area, + 'deep_area_cm2': deep_area, + 'deep_ratio': deep_ratio, + 'max_depth': np.max(wound_depths) if len(wound_depths) > 0 else 0 + } + +def classify_wound_severity_by_area(depth_stats): + """Classify wound severity based on area and depth distribution""" + total = depth_stats['total_area_cm2'] + deep = depth_stats['deep_area_cm2'] + moderate = depth_stats['moderate_area_cm2'] + + if total == 0: + return "Unknown" + + # Severity classification rules + if deep > 2 or (deep / total) > 0.3: + return "Severe" + elif moderate > 1.5 or (moderate / total) > 0.4: + return "Moderate" + else: + return "Mild" + +def analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing_mm=0.5): + """Analyze wound severity from depth map and wound mask""" + if image is None or depth_map is None or wound_mask is None: + return "❌ Please upload image, depth map, and wound mask." 
+ + # Convert wound mask to grayscale if needed + if len(wound_mask.shape) == 3: + wound_mask = np.mean(wound_mask, axis=2) + + # Ensure depth map and mask have same dimensions + if depth_map.shape[:2] != wound_mask.shape[:2]: + # Resize mask to match depth map + from PIL import Image + mask_pil = Image.fromarray(wound_mask.astype(np.uint8)) + mask_pil = mask_pil.resize((depth_map.shape[1], depth_map.shape[0])) + wound_mask = np.array(mask_pil) + + # Compute statistics + stats = compute_depth_area_statistics(depth_map, wound_mask, pixel_spacing_mm) + severity = classify_wound_severity_by_area(stats) + + # Create severity report with color coding + severity_color = { + "Mild": "#4CAF50", # Green + "Moderate": "#FF9800", # Orange + "Severe": "#F44336" # Red + }.get(severity, "#9E9E9E") # Gray for unknown + + report = f""" +
    +
    + 🩹 Wound Severity Analysis +
    + +
    +
    +
    + 📏 Area Measurements +
    +
    +
    🟢 Total Area: {stats['total_area_cm2']:.2f} cm²
    +
    🟩 Shallow (0-3mm): {stats['shallow_area_cm2']:.2f} cm²
    +
    🟨 Moderate (3-6mm): {stats['moderate_area_cm2']:.2f} cm²
    +
    🟥 Deep (>6mm): {stats['deep_area_cm2']:.2f} cm²
    +
    +
    + +
    +
    + 📊 Depth Analysis +
    +
    +
    🔥 Deep Coverage: {stats['deep_ratio']*100:.1f}%
    +
    📏 Max Depth: {stats['max_depth']:.1f} mm
    +
    Pixel Spacing: {pixel_spacing_mm} mm
    +
    +
    +
    + +
    +
    + 🎯 Predicted Severity: {severity} +
    +
    + {get_severity_description(severity)} +
    +
    +
    + """ + + return report + +def get_severity_description(severity): + """Get description for severity level""" + descriptions = { + "Mild": "Superficial wound with minimal tissue damage. Usually heals well with basic care.", + "Moderate": "Moderate tissue involvement requiring careful monitoring and proper treatment.", + "Severe": "Deep tissue damage requiring immediate medical attention and specialized care.", + "Unknown": "Unable to determine severity due to insufficient data." + } + return descriptions.get(severity, "Severity assessment unavailable.") + +def create_sample_wound_mask(image_shape, center=None, radius=50): + """Create a sample circular wound mask for testing""" + if center is None: + center = (image_shape[1] // 2, image_shape[0] // 2) + + mask = np.zeros(image_shape[:2], dtype=np.uint8) + y, x = np.ogrid[:image_shape[0], :image_shape[1]] + + # Create circular mask + dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) + mask[dist_from_center <= radius] = 255 + + return mask + +def create_realistic_wound_mask(image_shape, method='elliptical'): + """Create a more realistic wound mask with irregular shapes""" + h, w = image_shape[:2] + mask = np.zeros((h, w), dtype=np.uint8) + + if method == 'elliptical': + # Create elliptical wound mask + center = (w // 2, h // 2) + radius_x = min(w, h) // 3 + radius_y = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + # Add some irregularity to make it more realistic + ellipse = ((x - center[0])**2 / (radius_x**2) + + (y - center[1])**2 / (radius_y**2)) <= 1 + + # Add some noise and irregularity + noise = np.random.random((h, w)) > 0.8 + mask = (ellipse | noise).astype(np.uint8) * 255 + + elif method == 'irregular': + # Create irregular wound mask + center = (w // 2, h // 2) + radius = min(w, h) // 4 + + y, x = np.ogrid[:h, :w] + base_circle = np.sqrt((x - center[0])**2 + (y - center[1])**2) <= radius + + # Add irregular extensions + extensions = np.zeros_like(base_circle) + for i in range(3): + angle 
= i * 2 * np.pi / 3 + ext_x = int(center[0] + radius * 0.8 * np.cos(angle)) + ext_y = int(center[1] + radius * 0.8 * np.sin(angle)) + ext_radius = radius // 3 + + ext_circle = np.sqrt((x - ext_x)**2 + (y - ext_y)**2) <= ext_radius + extensions = extensions | ext_circle + + mask = (base_circle | extensions).astype(np.uint8) * 255 + + # Apply morphological operations to smooth the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + + return mask + +# --- Depth Estimation Functions --- +@spaces.GPU +def predict_depth(image): + return depth_model.infer_image(image) + +def calculate_max_points(image): + """Calculate maximum points based on image dimensions (3x pixel count)""" + if image is None: + return 10000 # Default value + h, w = image.shape[:2] + max_points = h * w * 3 + # Ensure minimum and reasonable maximum values + return max(1000, min(max_points, 300000)) + +def update_slider_on_image_upload(image): + """Update the points slider when an image is uploaded""" + max_points = calculate_max_points(image) + default_value = min(10000, max_points // 10) # 10% of max points as default + return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000, + label=f"Number of 3D points (max: {max_points:,})") + +@spaces.GPU +def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=30000): + """Create a point cloud from depth map using camera intrinsics with high detail""" + h, w = depth_map.shape + + # Use smaller step for higher detail (reduced downsampling) + step = max(1, int(np.sqrt(h * w / max_points) * 0.5)) # Reduce step size for more detail + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + x_cam = (x_coords - w / 2) / focal_length_x + y_cam = (y_coords - h / 2) / focal_length_y + + # Get depth values + depth_values = 
depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # Flatten arrays + points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1) + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors = image_colors.reshape(-1, 3) / 255.0 + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + pcd.colors = o3d.utility.Vector3dVector(colors) + + return pcd + +@spaces.GPU +def reconstruct_surface_mesh_from_point_cloud(pcd): + """Convert point cloud to a mesh using Poisson reconstruction with very high detail.""" + # Estimate and orient normals with high precision + pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.005, max_nn=50)) + pcd.orient_normals_consistent_tangent_plane(k=50) + + # Create surface mesh with maximum detail (depth=12 for very high resolution) + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=12) + + # Return mesh without filtering low-density vertices + return mesh + +@spaces.GPU +def create_enhanced_3d_visualization(image, depth_map, max_points=10000): + """Create an enhanced 3D visualization using proper camera projection""" + h, w = depth_map.shape + + # Downsample to avoid too many points for performance + step = max(1, int(np.sqrt(h * w / max_points))) + + # Create mesh grid for camera coordinates + y_coords, x_coords = np.mgrid[0:h:step, 0:w:step] + + # Convert to camera coordinates (normalized by focal length) + focal_length = 470.4 # Default focal length + x_cam = (x_coords - w / 2) / focal_length + y_cam = (y_coords - h / 2) / focal_length + + # Get depth values + depth_values = depth_map[::step, ::step] + + # Calculate 3D points: (x_cam * depth, y_cam * depth, depth) + x_3d = x_cam * depth_values + y_3d = y_cam * depth_values + z_3d = depth_values + + # 
Flatten arrays + x_flat = x_3d.flatten() + y_flat = y_3d.flatten() + z_flat = z_3d.flatten() + + # Get corresponding image colors + image_colors = image[::step, ::step, :] + colors_flat = image_colors.reshape(-1, 3) + + # Create 3D scatter plot with proper camera projection + fig = go.Figure(data=[go.Scatter3d( + x=x_flat, + y=y_flat, + z=z_flat, + mode='markers', + marker=dict( + size=1.5, + color=colors_flat, + opacity=0.9 + ), + hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})
    ' + + 'Depth: %{z:.2f}
    ' + + '' + )]) + + fig.update_layout( + title="3D Point Cloud Visualization (Camera Projection)", + scene=dict( + xaxis_title="X (meters)", + yaxis_title="Y (meters)", + zaxis_title="Z (meters)", + camera=dict( + eye=dict(x=2.0, y=2.0, z=2.0), + center=dict(x=0, y=0, z=0), + up=dict(x=0, y=0, z=1) + ), + aspectmode='data' + ), + width=700, + height=600 + ) + + return fig + +def on_depth_submit(image, num_points, focal_x, focal_y): + original_image = image.copy() + + h, w = image.shape[:2] + + # Predict depth using the model + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + + # Save raw 16-bit depth + raw_depth = Image.fromarray(depth.astype('uint16')) + tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + raw_depth.save(tmp_raw_depth.name) + + # Normalize and convert to grayscale for display + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + norm_depth = norm_depth.astype(np.uint8) + colored_depth = (matplotlib.colormaps.get_cmap('Spectral_r')(norm_depth)[:, :, :3] * 255).astype(np.uint8) + + gray_depth = Image.fromarray(norm_depth) + tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + gray_depth.save(tmp_gray_depth.name) + + # Create point cloud + pcd = create_point_cloud(original_image, norm_depth, focal_x, focal_y, max_points=num_points) + + # Reconstruct mesh from point cloud + mesh = reconstruct_surface_mesh_from_point_cloud(pcd) + + # Save mesh with faces as .ply + tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False) + o3d.io.write_triangle_mesh(tmp_pointcloud.name, mesh) + + # Create enhanced 3D scatter plot visualization + depth_3d = create_enhanced_3d_visualization(original_image, norm_depth, max_points=num_points) + + return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d] + +# --- Automatic Wound Mask Generation Functions --- +import cv2 +from skimage import filters, morphology, 
measure +from skimage.segmentation import clear_border + +def create_automatic_wound_mask(image, method='adaptive'): + """ + Automatically generate wound mask from image using various segmentation methods + + Args: + image: Input image (numpy array) + method: Segmentation method ('adaptive', 'otsu', 'color', 'combined') + + Returns: + mask: Binary wound mask + """ + if image is None: + return None + + # Convert to grayscale if needed + if len(image.shape) == 3: + gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + else: + gray = image.copy() + + # Apply different segmentation methods + if method == 'adaptive': + mask = adaptive_threshold_segmentation(gray) + elif method == 'otsu': + mask = otsu_threshold_segmentation(gray) + elif method == 'color': + mask = color_based_segmentation(image) + elif method == 'combined': + mask = combined_segmentation(image, gray) + else: + mask = adaptive_threshold_segmentation(gray) + + return mask + +def adaptive_threshold_segmentation(gray): + """Use adaptive thresholding for wound segmentation""" + # Apply Gaussian blur to reduce noise + blurred = cv2.GaussianBlur(gray, (15, 15), 0) + + # Adaptive thresholding with larger block size + thresh = cv2.adaptiveThreshold( + blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 5 + ) + + # Morphological operations to clean up the mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) + mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + # Find contours and keep only the largest ones + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Create a new mask with only large contours + mask_clean = np.zeros_like(mask) + for contour in contours: + area = cv2.contourArea(contour) + if area > 1000: # Minimum area threshold + cv2.fillPoly(mask_clean, [contour], 255) + + return mask_clean + +def otsu_threshold_segmentation(gray): + """Use Otsu's thresholding for wound 
segmentation""" + # Apply Gaussian blur + blurred = cv2.GaussianBlur(gray, (15, 15), 0) + + # Otsu's thresholding + _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + # Morphological operations + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) + mask = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + # Find contours and keep only the largest ones + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Create a new mask with only large contours + mask_clean = np.zeros_like(mask) + for contour in contours: + area = cv2.contourArea(contour) + if area > 800: # Minimum area threshold + cv2.fillPoly(mask_clean, [contour], 255) + + return mask_clean + +def color_based_segmentation(image): + """Use color-based segmentation for wound detection""" + # Convert to different color spaces + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + + # Create masks for different color ranges (wound-like colors) + # Reddish/brownish wound colors in HSV - broader ranges + lower_red1 = np.array([0, 30, 30]) + upper_red1 = np.array([15, 255, 255]) + lower_red2 = np.array([160, 30, 30]) + upper_red2 = np.array([180, 255, 255]) + + mask1 = cv2.inRange(hsv, lower_red1, upper_red1) + mask2 = cv2.inRange(hsv, lower_red2, upper_red2) + red_mask = mask1 + mask2 + + # Yellowish wound colors - broader range + lower_yellow = np.array([15, 30, 30]) + upper_yellow = np.array([35, 255, 255]) + yellow_mask = cv2.inRange(hsv, lower_yellow, upper_yellow) + + # Brownish wound colors + lower_brown = np.array([10, 50, 20]) + upper_brown = np.array([20, 255, 200]) + brown_mask = cv2.inRange(hsv, lower_brown, upper_brown) + + # Combine color masks + color_mask = red_mask + yellow_mask + brown_mask + + # Clean up the mask with larger kernels + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) + color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_CLOSE, 
kernel) + color_mask = cv2.morphologyEx(color_mask, cv2.MORPH_OPEN, kernel) + + # Find contours and keep only the largest ones + contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Create a new mask with only large contours + mask_clean = np.zeros_like(color_mask) + for contour in contours: + area = cv2.contourArea(contour) + if area > 600: # Minimum area threshold + cv2.fillPoly(mask_clean, [contour], 255) + + return mask_clean + +def combined_segmentation(image, gray): + """Combine multiple segmentation methods for better results""" + # Get masks from different methods + adaptive_mask = adaptive_threshold_segmentation(gray) + otsu_mask = otsu_threshold_segmentation(gray) + color_mask = color_based_segmentation(image) + + # Combine masks (union) + combined_mask = cv2.bitwise_or(adaptive_mask, otsu_mask) + combined_mask = cv2.bitwise_or(combined_mask, color_mask) + + # Apply additional morphological operations to clean up + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20)) + combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel) + + # Find contours and keep only the largest ones + contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Create a new mask with only large contours + mask_clean = np.zeros_like(combined_mask) + for contour in contours: + area = cv2.contourArea(contour) + if area > 500: # Minimum area threshold + cv2.fillPoly(mask_clean, [contour], 255) + + # If no large contours found, create a realistic wound mask + if np.sum(mask_clean) == 0: + mask_clean = create_realistic_wound_mask(combined_mask.shape, method='elliptical') + + return mask_clean + +def post_process_wound_mask(mask, min_area=100): + """Post-process the wound mask to remove noise and small objects""" + if mask is None: + return None + + # Convert to binary if needed + if mask.dtype != np.uint8: + mask = mask.astype(np.uint8) + + # Apply morphological operations to clean up 
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + + # Remove small objects using OpenCV + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + mask_clean = np.zeros_like(mask) + + for contour in contours: + area = cv2.contourArea(contour) + if area >= min_area: + cv2.fillPoly(mask_clean, [contour], 255) + + # Fill holes + mask_clean = cv2.morphologyEx(mask_clean, cv2.MORPH_CLOSE, kernel) + + return mask_clean + +def analyze_wound_severity_auto(image, depth_map, pixel_spacing_mm=0.5, segmentation_method='combined'): + """Analyze wound severity with automatic mask generation""" + if image is None or depth_map is None: + return "❌ Please provide both image and depth map." + + # Generate automatic wound mask + auto_mask = create_automatic_wound_mask(image, method=segmentation_method) + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask." + + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=500) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected. Try adjusting segmentation parameters or upload a manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing_mm) + +# --- Main Gradio Interface --- +with gr.Blocks(css=css, title="Wound Analysis & Depth Estimation") as demo: + gr.HTML("

    Wound Analysis & Depth Estimation System

    ") + gr.Markdown("### Comprehensive wound analysis with classification and 3D depth mapping capabilities") + + # Shared image state + shared_image = gr.State() + + with gr.Tabs(): + # Tab 1: Wound Classification + with gr.Tab("1. Wound Classification"): + gr.Markdown("### Step 1: Upload and classify your wound image") + gr.Markdown("This module analyzes wound images and provides classification with AI-powered reasoning.") + + with gr.Row(): + with gr.Column(scale=1): + wound_image_input = gr.Image(label="Upload Wound Image", type="pil", height=350) + + with gr.Column(scale=1): + wound_prediction_box = gr.HTML() + wound_reasoning_box = gr.HTML() + + # Button to pass image to depth estimation + with gr.Row(): + pass_to_depth_btn = gr.Button("📊 Pass Image to Depth Analysis", variant="secondary", size="lg") + pass_status = gr.HTML("") + + wound_image_input.change(fn=classify_wound_image, inputs=wound_image_input, + outputs=[wound_prediction_box, wound_reasoning_box]) + + # Store image when uploaded for classification + wound_image_input.change( + fn=lambda img: img, + inputs=[wound_image_input], + outputs=[shared_image] + ) + + # Tab 2: Depth Estimation + with gr.Tab("2. 
Depth Estimation & 3D Visualization"): + gr.Markdown("### Step 2: Generate depth maps and 3D visualizations") + gr.Markdown("This module creates depth maps and 3D point clouds from your images.") + + with gr.Row(): + depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input') + depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output') + + with gr.Row(): + depth_submit = gr.Button(value="Compute Depth", variant="primary") + load_shared_btn = gr.Button("🔄 Load Image from Classification", variant="secondary") + points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000, + label="Number of 3D points (upload image to update max)") + + with gr.Row(): + focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length X (pixels)") + focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10, + label="Focal Length Y (pixels)") + + with gr.Row(): + gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download") + raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download") + point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download") + + # 3D Visualization + gr.Markdown("### 3D Point Cloud Visualization") + gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.") + depth_3d_plot = gr.Plot(label="3D Point Cloud") + + # Store depth map for severity analysis + depth_map_state = gr.State() + + # Tab 3: Wound Severity Analysis + with gr.Tab("3. 
🩹 Wound Severity Analysis"): + gr.Markdown("### Step 3: Analyze wound severity using depth maps") + gr.Markdown("This module analyzes wound severity based on depth distribution and area measurements.") + + with gr.Row(): + severity_input_image = gr.Image(label="Original Image", type='numpy') + severity_depth_map = gr.Image(label="Depth Map (from Tab 2)", type='numpy') + + with gr.Row(): + wound_mask_input = gr.Image(label="Wound Mask (Optional)", type='numpy') + severity_output = gr.HTML(label="Severity Analysis Report") + + gr.Markdown("**Note:** You can either upload a manual mask or use automatic mask generation.") + + with gr.Row(): + auto_severity_button = gr.Button("🤖 Auto-Analyze Severity", variant="primary", size="lg") + manual_severity_button = gr.Button("🔍 Manual Mask Analysis", variant="secondary", size="lg") + pixel_spacing_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, + label="Pixel Spacing (mm/pixel)") + + gr.Markdown("**Pixel Spacing:** Adjust based on your camera calibration. 
Default is 0.5 mm/pixel.") + + with gr.Row(): + segmentation_method = gr.Dropdown( + choices=["combined", "adaptive", "otsu", "color"], + value="combined", + label="Segmentation Method", + info="Choose automatic segmentation method" + ) + min_area_slider = gr.Slider(minimum=100, maximum=2000, value=500, step=100, + label="Minimum Area (pixels)", + info="Minimum wound area to detect") + + with gr.Row(): + # Load depth map from previous tab + load_depth_btn = gr.Button("🔄 Load Depth Map from Tab 2", variant="secondary") + sample_mask_btn = gr.Button("🎯 Generate Sample Mask", variant="secondary") + realistic_mask_btn = gr.Button("🏥 Generate Realistic Mask", variant="secondary") + preview_mask_btn = gr.Button("👁️ Preview Auto Mask", variant="secondary") + + gr.Markdown("**Options:** Load depth map, generate sample mask, or preview automatic segmentation.") + + # Generate sample mask function + def generate_sample_mask(image): + if image is None: + return None, "❌ Please load an image first." + + sample_mask = create_sample_wound_mask(image.shape) + return sample_mask, "✅ Sample circular wound mask generated!" + + # Generate realistic mask function + def generate_realistic_mask(image): + if image is None: + return None, "❌ Please load an image first." + + realistic_mask = create_realistic_wound_mask(image.shape, method='elliptical') + return realistic_mask, "✅ Realistic elliptical wound mask generated!" 
+ + sample_mask_btn.click( + fn=generate_sample_mask, + inputs=[severity_input_image], + outputs=[wound_mask_input, gr.HTML()] + ) + + realistic_mask_btn.click( + fn=generate_realistic_mask, + inputs=[severity_input_image], + outputs=[wound_mask_input, gr.HTML()] + ) + + # Update slider when image is uploaded + depth_input_image.change( + fn=update_slider_on_image_upload, + inputs=[depth_input_image], + outputs=[points_slider] + ) + + # Modified depth submit function to store depth map + def on_depth_submit_with_state(image, num_points, focal_x, focal_y): + results = on_depth_submit(image, num_points, focal_x, focal_y) + # Extract depth map from results for severity analysis + depth_map = None + if image is not None: + depth = predict_depth(image[:, :, ::-1]) # RGB to BGR if needed + # Normalize depth for severity analysis + norm_depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth_map = norm_depth.astype(np.uint8) + return results + [depth_map] + + depth_submit.click(on_depth_submit_with_state, + inputs=[depth_input_image, points_slider, focal_length_x, focal_length_y], + outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot, depth_map_state]) + + # Load depth map to severity tab + def load_depth_to_severity(depth_map, original_image): + if depth_map is None: + return None, None, "❌ No depth map available. Please compute depth in Tab 2 first." + return depth_map, original_image, "✅ Depth map loaded successfully!" + + load_depth_btn.click( + fn=load_depth_to_severity, + inputs=[depth_map_state, depth_input_image], + outputs=[severity_depth_map, severity_input_image, gr.HTML()] + ) + + # Automatic severity analysis function + def run_auto_severity_analysis(image, depth_map, pixel_spacing, seg_method, min_area): + if depth_map is None: + return "❌ Please load depth map from Tab 2 first." 
+ + # Update post-processing with user-defined minimum area + def post_process_with_area(mask): + return post_process_wound_mask(mask, min_area=min_area) + + # Generate automatic wound mask + auto_mask = create_automatic_wound_mask(image, method=seg_method) + + if auto_mask is None: + return "❌ Failed to generate automatic wound mask." + + # Post-process the mask + processed_mask = post_process_with_area(auto_mask) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return "❌ No wound region detected. Try adjusting segmentation parameters or use manual mask." + + # Analyze severity using the automatic mask + return analyze_wound_severity(image, depth_map, processed_mask, pixel_spacing) + + # Manual severity analysis function + def run_manual_severity_analysis(image, depth_map, wound_mask, pixel_spacing): + if depth_map is None: + return "❌ Please load depth map from Tab 2 first." + if wound_mask is None: + return "❌ Please upload a wound mask (binary image where white pixels represent the wound area)." + + return analyze_wound_severity(image, depth_map, wound_mask, pixel_spacing) + + # Preview automatic mask function + def preview_auto_mask(image, seg_method, min_area): + if image is None: + return None, "❌ Please load an image first." + + # Generate automatic wound mask + auto_mask = create_automatic_wound_mask(image, method=seg_method) + + if auto_mask is None: + return None, "❌ Failed to generate automatic wound mask." + + # Post-process the mask + processed_mask = post_process_wound_mask(auto_mask, min_area=min_area) + + if processed_mask is None or np.sum(processed_mask > 0) == 0: + return None, "❌ No wound region detected. Try adjusting parameters." + + return processed_mask, f"✅ Auto mask generated using {seg_method} method!" 
+ + # Connect event handlers + auto_severity_button.click( + fn=run_auto_severity_analysis, + inputs=[severity_input_image, severity_depth_map, pixel_spacing_slider, + segmentation_method, min_area_slider], + outputs=[severity_output] + ) + + manual_severity_button.click( + fn=run_manual_severity_analysis, + inputs=[severity_input_image, severity_depth_map, wound_mask_input, pixel_spacing_slider], + outputs=[severity_output] + ) + + preview_mask_btn.click( + fn=preview_auto_mask, + inputs=[severity_input_image, segmentation_method, min_area_slider], + outputs=[wound_mask_input, gr.HTML()] + ) + + # Load shared image from classification tab + def load_shared_image(shared_img): + if shared_img is None: + return gr.Image(), "❌ No image available from classification tab" + + # Convert PIL image to numpy array for depth estimation + if hasattr(shared_img, 'convert'): + # It's a PIL image, convert to numpy + img_array = np.array(shared_img) + return img_array, "✅ Image loaded from classification tab" + else: + # Already numpy array + return shared_img, "✅ Image loaded from classification tab" + + load_shared_btn.click( + fn=load_shared_image, + inputs=[shared_image], + outputs=[depth_input_image, gr.HTML()] + ) + + # Pass image to depth tab function + def pass_image_to_depth(img): + if img is None: + return "❌ No image uploaded in classification tab" + return "✅ Image ready for depth analysis! 
Switch to tab 2 and click 'Load Image from Classification'" + + pass_to_depth_btn.click( + fn=pass_image_to_depth, + inputs=[shared_image], + outputs=[pass_status] + ) + +if __name__ == '__main__': + demo.queue().launch( + server_name="0.0.0.0", + server_port=7860, + share=True + ) \ No newline at end of file diff --git a/temp_files/train.py b/temp_files/train.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac4e808b034f1397392664df5cb0494c37da0ad --- /dev/null +++ b/temp_files/train.py @@ -0,0 +1,69 @@ +from tensorflow.keras.optimizers import Adam +from keras.callbacks import EarlyStopping +from keras.models import load_model +from keras.utils.generic_utils import CustomObjectScope + +from models.unets import Unet2D +from models.deeplab import Deeplabv3, relu6, DepthwiseConv2D, BilinearUpsampling +from models.FCN import FCN_Vgg16_16s +from models.SegNet import SegNet + +from utils.learning.metrics import dice_coef, precision, recall +from utils.learning.losses import dice_coef_loss +from utils.io.data import DataGen, save_results, save_history, load_data + + +# manually set cuda 10.0 path +#os.system('export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}') +#os.system('export PATH=/usr/local/cuda-10.0/bin:/usr/local/cuda-10.0/NsightCompute-1.0${PATH:+:${PATH}}') + +# Varibales and data generator +input_dim_x=224 +input_dim_y=224 +n_filters = 32 +dataset = 'Medetec_foot_ulcer_224' +data_gen = DataGen('./data/' + dataset + '/', split_ratio=0.2, x=input_dim_x, y=input_dim_y) + +######### Get the deep learning models ######### + +######### Unet ########## +# unet2d = Unet2D(n_filters=n_filters, input_dim_x=None, input_dim_y=None, num_channels=3) +# model, model_name = unet2d.get_unet_model_yuanqing() + +######### MobilenetV2 ########## +model = Deeplabv3(input_shape=(input_dim_x, input_dim_y, 3), classes=1) +model_name = 'MobilenetV2' +with CustomObjectScope({'relu6': relu6,'DepthwiseConv2D': DepthwiseConv2D, 
'BilinearUpsampling': BilinearUpsampling}): + model = load_model('training_history/2019-12-19 01%3A53%3A15.480800.hdf5' + , custom_objects={'dice_coef': dice_coef, 'precision':precision, 'recall':recall}) + +######### Vgg16 ########## +# model, model_name = FCN_Vgg16_16s(input_shape=(input_dim_x, input_dim_y, 3)) + +######### SegNet ########## +# segnet = SegNet(n_filters, input_dim_x, input_dim_y, num_channels=3) +# model, model_name = segnet.get_SegNet() + +# plot_model(model, to_file=model_name+'.png') + +# training +batch_size = 2 +epochs = 2000 +learning_rate = 1e-4 +loss = 'binary_crossentropy' + +es = EarlyStopping(monitor='val_dice_coef', patience=200, mode='max', restore_best_weights=True) +#training_history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs +# , validation_split=0.2, verbose=1, callbacks=[]) + +model.summary() +model.compile(optimizer=Adam(lr=learning_rate), loss=loss, metrics=[dice_coef, precision, recall]) +training_history = model.fit_generator(data_gen.generate_data(batch_size=batch_size, train=True), + steps_per_epoch=int(data_gen.get_num_data_points(train=True) / batch_size), + callbacks=[es], + validation_data=data_gen.generate_data(batch_size=batch_size, val=True), + validation_steps=int(data_gen.get_num_data_points(val=True) / batch_size), + epochs=epochs) +### save the model weight file and its training history +save_history(model, model_name, training_history, dataset, n_filters, epochs, learning_rate, loss, color_space='RGB', + path='./training_history/') \ No newline at end of file diff --git a/training_history/2019-12-19 01%3A53%3A15.480800.hdf5 b/training_history/2019-12-19 01%3A53%3A15.480800.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..05a8764ec6b27ffa6284797028e98b2e4f3dc060 --- /dev/null +++ b/training_history/2019-12-19 01%3A53%3A15.480800.hdf5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:6eedcbb8201def963e45e36721b9d23bb604c2f1a2c39331f80e0541187dd287 +size 26126696 diff --git a/training_history/2025-08-07_16-25-27.hdf5 b/training_history/2025-08-07_16-25-27.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..90356f735ba02cbcdf6eb0801b5623efce4fbe2f --- /dev/null +++ b/training_history/2025-08-07_16-25-27.hdf5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7cbd78f8b4dfeb43af1f6ba8167b2650acb4cec2cca7c8a3c71000dcfd63908 +size 26362704 diff --git a/training_history/2025-08-07_16-25-27.json b/training_history/2025-08-07_16-25-27.json new file mode 100644 index 0000000000000000000000000000000000000000..dd1b5a8e7e884a644657c9b483f2c952943e8068 --- /dev/null +++ b/training_history/2025-08-07_16-25-27.json @@ -0,0 +1,3946 @@ +{ + "loss": [ + 0.30647748708724976, + 0.1750967651605606, + 0.1348150670528412, + 0.10996372997760773, + 0.0961669534444809, + 0.08511081337928772, + 0.07905993610620499, + 0.07364209741353989, + 0.06904656440019608, + 0.0664418488740921, + 0.0637221485376358, + 0.06133847311139107, + 0.05951695144176483, + 0.05743888020515442, + 0.05550319701433182, + 0.05368385836482048, + 0.05071243271231651, + 0.053955163806676865, + 0.05913974717259407, + 0.05481218546628952, + 0.053216658532619476, + 0.051806069910526276, + 0.05040912330150604, + 0.04853715002536774, + 0.04627740755677223, + 0.044616397470235825, + 0.043374769389629364, + 0.04218907281756401, + 0.04144686087965965, + 0.04063361510634422, + 0.04707714170217514, + 0.04899223521351814, + 0.04533304274082184, + 0.043027348816394806, + 0.041465017944574356, + 0.040123045444488525, + 0.03773369640111923, + 0.03661240264773369, + 0.03559977933764458, + 0.034923210740089417, + 0.035055261105298996, + 0.03570236638188362, + 0.03435955196619034, + 0.03357187658548355, + 0.0329667367041111, + 0.03242901340126991, + 0.03186548501253128, + 0.03149624541401863, + 0.031010165810585022, + 0.03060271590948105, + 0.030267074704170227, + 
0.03771795332431793, + 0.03534615784883499, + 0.033918071538209915, + 0.03289610520005226, + 0.032201554626226425, + 0.0314098559319973, + 0.03094932995736599, + 0.030538897961378098, + 0.029941964894533157, + 0.029633009806275368, + 0.02943609468638897, + 0.028660012409090996, + 0.028602395206689835, + 0.028005452826619148, + 0.02752712368965149, + 0.027221843600273132, + 0.02680036798119545, + 0.026418132707476616, + 0.02618495002388954, + 0.025956127792596817, + 0.02560962177813053, + 0.025426078587770462, + 0.02505660429596901, + 0.024684013798832893, + 0.024436764419078827, + 0.024140072986483574, + 0.023672718554735184, + 0.023578155785799026, + 0.023109160363674164, + 0.02283332496881485, + 0.02255292609333992, + 0.022325903177261353, + 0.021962279453873634, + 0.021811343729496002, + 0.02172011323273182, + 0.02129892073571682, + 0.021038774400949478, + 0.020723635330796242, + 0.02045297622680664, + 0.0214977003633976, + 0.021355675533413887, + 0.021544352173805237, + 0.020594464614987373, + 0.020267758518457413, + 0.021381760016083717, + 0.020386535674333572, + 0.019776318222284317, + 0.019427292048931122, + 0.019123896956443787, + 0.02659258246421814, + 0.03618019074201584, + 0.027225933969020844, + 0.025075973942875862, + 0.02390117757022381, + 0.023283500224351883, + 0.022889306768774986, + 0.022487878799438477, + 0.02206781506538391, + 0.021616896614432335, + 0.02290518395602703, + 0.02203931286931038, + 0.021298978477716446, + 0.02082144469022751, + 0.020476309582591057, + 0.020369667559862137, + 0.020366981625556946, + 0.020146096125245094, + 0.019704287871718407, + 0.01959238573908806, + 0.01938951388001442, + 0.019204184412956238, + 0.018987014889717102, + 0.018942462280392647, + 0.018814796581864357, + 0.0185571126639843, + 0.01836971566081047, + 0.018259204924106598, + 0.018205296248197556, + 0.018134118989109993, + 0.017913157120347023, + 0.017672620713710785, + 0.017590902745723724, + 0.017439354211091995, + 0.01727035455405712, + 
0.017223920673131943, + 0.018535830080509186, + 0.01764857769012451, + 0.017293628305196762, + 0.017154958099126816, + 0.016896503046154976, + 0.016868583858013153, + 0.016872286796569824, + 0.016433054581284523, + 0.016220392659306526, + 0.01603561080992222, + 0.015778077766299248, + 0.01577850617468357, + 0.015843817964196205, + 0.015809863805770874, + 0.015860065817832947, + 0.015477662906050682, + 0.01531931571662426, + 0.015125595033168793, + 0.01502202544361353, + 0.014976686798036098, + 0.015033239498734474, + 0.015951665118336678, + 0.01540245022624731, + 0.014854921028017998, + 0.014961468987166882, + 0.015013501979410648, + 0.01517416164278984, + 0.01449889037758112, + 0.014303860254585743, + 0.014187045395374298, + 0.014804028905928135, + 0.015959901735186577, + 0.020786097273230553, + 0.01760375313460827, + 0.015779541805386543, + 0.015848513692617416, + 0.015986243262887, + 0.014567948877811432, + 0.013230666518211365, + 0.01453038863837719, + 0.013637007214128971, + 0.013031593523919582, + 0.01634998433291912, + 0.02193347178399563, + 0.016191652044653893, + 0.01508353091776371, + 0.015239519998431206, + 0.015201173722743988, + 0.014139600098133087, + 0.013658924028277397, + 0.013358766213059425, + 0.013011476956307888, + 0.012659418396651745, + 0.012393197976052761, + 0.012266334146261215, + 0.011977892369031906, + 0.011970474384725094, + 0.011703734286129475, + 0.01161114126443863, + 0.01172381266951561, + 0.011709300801157951, + 0.012194477021694183, + 0.013468567281961441, + 0.017431264743208885, + 0.015250993892550468, + 0.013966160826385021, + 0.012959204614162445, + 0.012426641769707203, + 0.012011424638330936, + 0.012029320932924747, + 0.012592227198183537, + 0.011803402565419674, + 0.011517093516886234, + 0.011947079561650753, + 0.012068829499185085, + 0.01242718007415533, + 0.012240450829267502, + 0.011861618608236313, + 0.012352905236184597, + 0.014822003431618214, + 0.019554194062948227, + 0.021305328235030174, + 0.019731786102056503, + 
0.01814408227801323, + 0.01708049699664116, + 0.016473757103085518, + 0.016690775752067566, + 0.02231513150036335, + 0.01974739506840706, + 0.017199721187353134, + 0.016205912455916405, + 0.01561315730214119, + 0.01506352610886097, + 0.014831586740911007, + 0.014372607693076134, + 0.014147626236081123, + 0.013727042824029922, + 0.01352244708687067, + 0.01320627424865961, + 0.013022632338106632, + 0.012837468646466732, + 0.012507705949246883, + 0.01239242497831583, + 0.012630939483642578, + 0.012593026272952557, + 0.012325061485171318, + 0.011994406580924988, + 0.011810517869889736, + 0.012759208679199219, + 0.011985184624791145, + 0.011915159411728382, + 0.013752062804996967, + 0.01360330730676651, + 0.012521203607320786, + 0.012182645499706268, + 0.011708719655871391, + 0.01152396947145462, + 0.011431052349507809, + 0.01117919571697712, + 0.011751057580113411, + 0.011490956880152225, + 0.011127873323857784, + 0.010874041356146336, + 0.0110195092856884, + 0.01136202085763216, + 0.010967769660055637, + 0.011810488998889923, + 0.021257253363728523, + 0.014899597503244877, + 0.013231162913143635, + 0.01252503041177988, + 0.011877506040036678, + 0.011600157245993614, + 0.011508181691169739, + 0.011495032347738743, + 0.011990023776888847, + 0.011151742190122604, + 0.011049785651266575, + 0.010565276257693768, + 0.013197552412748337, + 0.01115148514509201, + 0.029514271765947342, + 0.0241073127835989, + 0.01918421871960163, + 0.017714563757181168, + 0.01667606830596924, + 0.015957504510879517, + 0.015396581031382084, + 0.014803696423768997, + 0.013729263097047806, + 0.013278660364449024, + 0.012890578247606754, + 0.012571698985993862, + 0.0128385741263628, + 0.012980268336832523, + 0.012489119544625282, + 0.012236335314810276, + 0.011874557472765446, + 0.011515924707055092, + 0.01126138586550951, + 0.011987834237515926, + 0.011403180658817291, + 0.011201435700058937, + 0.010884991846978664, + 0.010778272524476051, + 0.010546007193624973, + 0.010356712155044079, + 
0.010204695165157318, + 0.010189410299062729, + 0.010785996913909912, + 0.011168578639626503, + 0.011988673359155655, + 0.013514974154531956, + 0.011950574815273285, + 0.015517404302954674, + 0.012958980165421963, + 0.012304777279496193, + 0.011725402437150478, + 0.01116760354489088, + 0.010796580463647842, + 0.011244084686040878, + 0.011185579001903534, + 0.010852617211639881, + 0.010554634034633636, + 0.010332541540265083, + 0.010245597921311855, + 0.010003700852394104, + 0.009881113655865192, + 0.009757373481988907, + 0.00969171617180109, + 0.00992400012910366, + 0.009720757603645325, + 0.009623757563531399, + 0.009466061368584633, + 0.009499983862042427, + 0.009426280856132507, + 0.009482208639383316, + 0.013372893445193768, + 0.021394873037934303, + 0.015852460637688637, + 0.014674946665763855, + 0.014396724291145802, + 0.01334143802523613, + 0.012487740255892277, + 0.011898113414645195, + 0.011559720151126385, + 0.011196346022188663, + 0.010986356064677238, + 0.010869316756725311, + 0.010591063648462296, + 0.01035912986844778, + 0.010249403305351734, + 0.010015510022640228, + 0.010098502971231937, + 0.010486401617527008, + 0.010132399387657642, + 0.010006209835410118, + 0.017435988411307335, + 0.017399724572896957, + 0.013613355346024036, + 0.01278015412390232, + 0.012173213995993137, + 0.011733791790902615, + 0.011255122721195221, + 0.010943155735731125, + 0.01070241816341877, + 0.010416466742753983, + 0.01040695607662201, + 0.010443135164678097, + 0.009866783395409584, + 0.009707099758088589, + 0.009678167290985584, + 0.009503369219601154, + 0.00959171261638403, + 0.009591631591320038, + 0.009851268492639065, + 0.00967353954911232, + 0.009600138291716576, + 0.03234858438372612, + 0.020381923764944077, + 0.017983023077249527, + 0.01669255830347538, + 0.015920067206025124, + 0.016266971826553345, + 0.014160492457449436, + 0.013295168988406658, + 0.012696294113993645, + 0.01237404253333807, + 0.01230611838400364, + 0.012054872699081898, + 0.011769881471991539, 
+ 0.011479108594357967, + 0.0114207211881876, + 0.011093338951468468, + 0.010895829647779465, + 0.010720976628363132, + 0.011077550239861012, + 0.010639906860888004, + 0.010437337681651115, + 0.010433749295771122, + 0.010163159109652042, + 0.010187039151787758, + 0.009886456653475761, + 0.009900178760290146, + 0.009881138801574707, + 0.009746776893734932, + 0.009791776537895203, + 0.010411535389721394, + 0.010251682251691818, + 0.00973040796816349, + 0.014108603820204735, + 0.012864779680967331, + 0.011442308314144611, + 0.011878883466124535, + 0.01187659241259098, + 0.011125029996037483, + 0.010407320223748684, + 0.010011264123022556, + 0.019328519701957703, + 0.021474624052643776, + 0.01828955113887787, + 0.015556896105408669, + 0.014325638301670551, + 0.013566575944423676, + 0.012937955558300018, + 0.012533580884337425, + 0.012252584099769592, + 0.012734328396618366, + 0.012158636003732681, + 0.013013201765716076, + 0.011682095937430859, + 0.02510238066315651, + 0.018895352259278297, + 0.015548864379525185, + 0.014414222911000252, + 0.013655620627105236, + 0.013227966614067554, + 0.012928091920912266, + 0.01250703725963831, + 0.012381622567772865, + 0.01197823416441679, + 0.011751786805689335, + 0.01158289797604084, + 0.011270741000771523, + 0.011268547736108303, + 0.011189980432391167, + 0.011638487689197063, + 0.011096851900219917, + 0.010878192260861397, + 0.010670908726751804, + 0.010411555878818035, + 0.010200803168118, + 0.010321483016014099, + 0.010585369542241096, + 0.010327773168683052, + 0.010099920444190502, + 0.00985438097268343, + 0.009732011705636978, + 0.009680816903710365, + 0.009557006880640984, + 0.009599557146430016, + 0.009605410508811474, + 0.009399175643920898, + 0.009391609579324722, + 0.009294605813920498, + 0.00959872081875801, + 0.009268030524253845, + 0.009913635440170765, + 0.010886119678616524, + 0.010026099160313606, + 0.010056180879473686, + 0.009460094384849072, + 0.009112786501646042, + 0.024201123043894768, + 
0.01934361457824707, + 0.013264663517475128, + 0.011855002492666245, + 0.011098966002464294, + 0.010679544880986214, + 0.012001180090010166, + 0.014273331500589848, + 0.013267090544104576, + 0.0121291633695364, + 0.011487667448818684, + 0.011008239351212978, + 0.010625632479786873, + 0.010453431867063046, + 0.010083537548780441, + 0.009872745722532272, + 0.009798270650207996, + 0.009803017601370811, + 0.009521342813968658, + 0.010055589489638805, + 0.009551863186061382, + 0.009455163963139057 + ], + "dice_coef": [ + 0.8271756768226624, + 0.8616277575492859, + 0.8727266192436218, + 0.8806064128875732, + 0.8851526379585266, + 0.8895143270492554, + 0.8929189443588257, + 0.8961383700370789, + 0.8982627987861633, + 0.9007869362831116, + 0.9030752778053284, + 0.9055670499801636, + 0.9069859385490417, + 0.9094352722167969, + 0.9106618165969849, + 0.9123764634132385, + 0.9143651723861694, + 0.9116659164428711, + 0.9102592468261719, + 0.914546549320221, + 0.9164059162139893, + 0.918033242225647, + 0.9189513921737671, + 0.9205445051193237, + 0.9218775033950806, + 0.9237352609634399, + 0.9249120354652405, + 0.9262439608573914, + 0.9274923205375671, + 0.9286374449729919, + 0.9185131788253784, + 0.9172464609146118, + 0.921724796295166, + 0.9248894453048706, + 0.9270980954170227, + 0.9278378486633301, + 0.9289146661758423, + 0.9303903579711914, + 0.9316245317459106, + 0.9325858950614929, + 0.9317297339439392, + 0.930961012840271, + 0.9337752461433411, + 0.935239851474762, + 0.9366147518157959, + 0.9373242259025574, + 0.9382657408714294, + 0.9389757513999939, + 0.9396278858184814, + 0.9405460953712463, + 0.9411396980285645, + 0.9279007315635681, + 0.9338310360908508, + 0.9359795451164246, + 0.9377782940864563, + 0.9386602640151978, + 0.9395512342453003, + 0.9404653310775757, + 0.9410924315452576, + 0.9420122504234314, + 0.9424967169761658, + 0.9431310296058655, + 0.9438883662223816, + 0.9442669749259949, + 0.9446057081222534, + 0.9453078508377075, + 0.946151852607727, + 
0.9468348622322083, + 0.9474738240242004, + 0.9479974508285522, + 0.9481825232505798, + 0.9486918449401855, + 0.9491387009620667, + 0.9495266079902649, + 0.9503999948501587, + 0.9512287378311157, + 0.9517524838447571, + 0.9522303938865662, + 0.9523171186447144, + 0.9529317617416382, + 0.9539192914962769, + 0.954371988773346, + 0.9549437165260315, + 0.9555351138114929, + 0.9559608697891235, + 0.9559311866760254, + 0.9569058418273926, + 0.9574870467185974, + 0.9579500555992126, + 0.9585270285606384, + 0.9570472240447998, + 0.9570488929748535, + 0.9575705528259277, + 0.9589021801948547, + 0.9596315026283264, + 0.9579778909683228, + 0.9597327709197998, + 0.9607507586479187, + 0.9615135788917542, + 0.9620363116264343, + 0.9527720212936401, + 0.9445807933807373, + 0.9498860239982605, + 0.9536462426185608, + 0.9545617699623108, + 0.9558765888214111, + 0.9553714990615845, + 0.956228494644165, + 0.9569122791290283, + 0.9576302170753479, + 0.9567261338233948, + 0.957889974117279, + 0.9587068557739258, + 0.9591506719589233, + 0.9596977233886719, + 0.9597998261451721, + 0.9601463079452515, + 0.9599983096122742, + 0.960551917552948, + 0.9609813094139099, + 0.961475670337677, + 0.9618078470230103, + 0.9621520042419434, + 0.9621374607086182, + 0.962533712387085, + 0.9629316329956055, + 0.963316023349762, + 0.9636129140853882, + 0.9639310836791992, + 0.9640538692474365, + 0.9640945792198181, + 0.9645101428031921, + 0.9647799730300903, + 0.965129554271698, + 0.9654141664505005, + 0.9656355381011963, + 0.9649770855903625, + 0.9656959176063538, + 0.9662606716156006, + 0.966301441192627, + 0.9666382074356079, + 0.9668326377868652, + 0.9668439030647278, + 0.9674142599105835, + 0.9679332375526428, + 0.9681134819984436, + 0.9686338901519775, + 0.9687585234642029, + 0.968749463558197, + 0.9687435626983643, + 0.9692564606666565, + 0.9698631763458252, + 0.9701429009437561, + 0.9703571200370789, + 0.9706502556800842, + 0.9709163904190063, + 0.9707404375076294, + 0.9700429439544678, + 
0.97032630443573, + 0.9712384343147278, + 0.9708642959594727, + 0.9709155559539795, + 0.9712903499603271, + 0.9723801016807556, + 0.972836971282959, + 0.9731794595718384, + 0.9727415442466736, + 0.9710344076156616, + 0.9692134261131287, + 0.9714216589927673, + 0.9724838137626648, + 0.9720187187194824, + 0.9717462658882141, + 0.972689688205719, + 0.9736983776092529, + 0.9719834923744202, + 0.9729598164558411, + 0.9742265343666077, + 0.9694671034812927, + 0.9642630815505981, + 0.9688175916671753, + 0.9700709581375122, + 0.9708917140960693, + 0.970836341381073, + 0.9719638228416443, + 0.9728350043296814, + 0.9731382727622986, + 0.9736285209655762, + 0.9742369651794434, + 0.9747639894485474, + 0.9749345779418945, + 0.9753822088241577, + 0.9758213758468628, + 0.9761431217193604, + 0.976209282875061, + 0.9762759804725647, + 0.9761236310005188, + 0.9762524366378784, + 0.973239004611969, + 0.9702832698822021, + 0.9724173545837402, + 0.9745905995368958, + 0.9755451083183289, + 0.9759634733200073, + 0.9766412973403931, + 0.976182222366333, + 0.9765947461128235, + 0.977360188961029, + 0.977577269077301, + 0.9766439199447632, + 0.976081132888794, + 0.9756816625595093, + 0.9761340022087097, + 0.9771096110343933, + 0.9764305353164673, + 0.9725992679595947, + 0.9646629691123962, + 0.9655504822731018, + 0.9679058790206909, + 0.9697939157485962, + 0.9714179039001465, + 0.9719633460044861, + 0.9701056480407715, + 0.959028959274292, + 0.9631887674331665, + 0.9672734141349792, + 0.9687532782554626, + 0.9695519208908081, + 0.9703459739685059, + 0.970940887928009, + 0.9715447425842285, + 0.9717562198638916, + 0.9725369811058044, + 0.9730168581008911, + 0.9734640717506409, + 0.9738755226135254, + 0.9740811586380005, + 0.9746923446655273, + 0.9748530983924866, + 0.9743151068687439, + 0.9750608205795288, + 0.9753448963165283, + 0.9757680892944336, + 0.9758749604225159, + 0.9755315184593201, + 0.9764375686645508, + 0.9760685563087463, + 0.9738132953643799, + 0.9744699001312256, + 
0.9759675860404968, + 0.9765239953994751, + 0.9767493009567261, + 0.9770404100418091, + 0.9775190949440002, + 0.9777635335922241, + 0.9770656824111938, + 0.9773379564285278, + 0.9780363440513611, + 0.9783704280853271, + 0.9783034324645996, + 0.9781065583229065, + 0.9786196351051331, + 0.9776173233985901, + 0.968602180480957, + 0.9737133383750916, + 0.975898265838623, + 0.9767024517059326, + 0.9774192571640015, + 0.9777238368988037, + 0.9777422547340393, + 0.977294385433197, + 0.9767982363700867, + 0.9779852032661438, + 0.9782487154006958, + 0.9789919257164001, + 0.9749283194541931, + 0.9774854779243469, + 0.957599401473999, + 0.9600788354873657, + 0.9646042585372925, + 0.9667049646377563, + 0.9683660864830017, + 0.9693924784660339, + 0.9704417586326599, + 0.9712101221084595, + 0.9723083972930908, + 0.9729383587837219, + 0.9738287925720215, + 0.9742386341094971, + 0.9735625386238098, + 0.9744496941566467, + 0.9751554131507874, + 0.9753341674804688, + 0.9757531881332397, + 0.9760571718215942, + 0.9765751361846924, + 0.9760973453521729, + 0.9768602848052979, + 0.9770902991294861, + 0.9773600697517395, + 0.9775993227958679, + 0.9781467318534851, + 0.9784204959869385, + 0.9787243008613586, + 0.9789025783538818, + 0.9783837199211121, + 0.9773847460746765, + 0.9759023785591125, + 0.9755476713180542, + 0.977429986000061, + 0.9732048511505127, + 0.9762001037597656, + 0.9772343039512634, + 0.9779145121574402, + 0.9783735275268555, + 0.9787331819534302, + 0.9782876968383789, + 0.9788635969161987, + 0.9793126583099365, + 0.9796106815338135, + 0.9797661304473877, + 0.9799916744232178, + 0.9803319573402405, + 0.9804670214653015, + 0.9807540774345398, + 0.980849027633667, + 0.9807034730911255, + 0.9808654189109802, + 0.9810879826545715, + 0.9812144637107849, + 0.9812115430831909, + 0.9814638495445251, + 0.981597363948822, + 0.9785327315330505, + 0.9670518636703491, + 0.9723511338233948, + 0.973235011100769, + 0.9736719727516174, + 0.9750344157218933, + 0.9761777520179749, + 
0.9768736958503723, + 0.9774056077003479, + 0.9779927730560303, + 0.9782809019088745, + 0.9787391424179077, + 0.9790400266647339, + 0.9794601798057556, + 0.9793281555175781, + 0.9797995090484619, + 0.9799940586090088, + 0.9790902137756348, + 0.9799678921699524, + 0.9799731373786926, + 0.9702537655830383, + 0.9701027274131775, + 0.9748572111129761, + 0.9759320020675659, + 0.976621687412262, + 0.9774580597877502, + 0.977945864200592, + 0.9785280823707581, + 0.978829562664032, + 0.9791204333305359, + 0.9791039228439331, + 0.9786739945411682, + 0.9796222448348999, + 0.9800112247467041, + 0.9802066683769226, + 0.9802090525627136, + 0.9800991415977478, + 0.980035126209259, + 0.9800930619239807, + 0.9805180430412292, + 0.9804971814155579, + 0.9669134020805359, + 0.9731391668319702, + 0.9752886295318604, + 0.9763974547386169, + 0.9770158529281616, + 0.9739793539047241, + 0.97658771276474, + 0.9770677089691162, + 0.9779804944992065, + 0.9784687161445618, + 0.9784383773803711, + 0.9788373708724976, + 0.9792552590370178, + 0.979379415512085, + 0.979070782661438, + 0.9796465635299683, + 0.9798751473426819, + 0.9802157878875732, + 0.9796494245529175, + 0.980462372303009, + 0.980598509311676, + 0.9807146787643433, + 0.9805933833122253, + 0.9807178378105164, + 0.9807981252670288, + 0.9807450175285339, + 0.9807879328727722, + 0.9811874628067017, + 0.9808743596076965, + 0.9802898168563843, + 0.9803329706192017, + 0.9811297059059143, + 0.976192057132721, + 0.9774067997932434, + 0.9790591597557068, + 0.9782618880271912, + 0.978374719619751, + 0.9797973036766052, + 0.98039710521698, + 0.9809061288833618, + 0.9706665277481079, + 0.9683108925819397, + 0.9707096219062805, + 0.9731748104095459, + 0.9748567342758179, + 0.9759149551391602, + 0.9767023324966431, + 0.9775484204292297, + 0.9776878952980042, + 0.9763371348381042, + 0.9780442118644714, + 0.9775347113609314, + 0.9785792231559753, + 0.9640551209449768, + 0.9695783853530884, + 0.9730770587921143, + 0.9742680788040161, + 
0.9752517938613892, + 0.9759684205055237, + 0.976382315158844, + 0.9770227074623108, + 0.9772208333015442, + 0.9777687191963196, + 0.9781318306922913, + 0.9782090187072754, + 0.9786636233329773, + 0.9788519144058228, + 0.9788020253181458, + 0.9782240390777588, + 0.9789761900901794, + 0.9791821241378784, + 0.9795763492584229, + 0.979709267616272, + 0.9801487326622009, + 0.9801005125045776, + 0.9800381660461426, + 0.9801545143127441, + 0.9805911779403687, + 0.980867862701416, + 0.981088399887085, + 0.9812235832214355, + 0.9811208844184875, + 0.981157124042511, + 0.9814913272857666, + 0.9817042350769043, + 0.9816209077835083, + 0.9817667007446289, + 0.9815982580184937, + 0.982085645198822, + 0.9803016185760498, + 0.9796533584594727, + 0.9806585907936096, + 0.9811287522315979, + 0.9816910624504089, + 0.9821954369544983, + 0.9700067639350891, + 0.9723506569862366, + 0.9768332839012146, + 0.9787658452987671, + 0.9797025918960571, + 0.9802789688110352, + 0.9781936407089233, + 0.9750186800956726, + 0.9768596291542053, + 0.9784777164459229, + 0.9793055057525635, + 0.9798623919487, + 0.9803575277328491, + 0.9806535840034485, + 0.981094479560852, + 0.9813281893730164, + 0.9809693098068237, + 0.9815014600753784, + 0.9817039966583252, + 0.9812902212142944, + 0.9819175601005554, + 0.9820971488952637 + ], + "precision": [ + 0.8992838859558105, + 0.8935195803642273, + 0.9231774806976318, + 0.9065283536911011, + 0.9022323489189148, + 0.9007483124732971, + 0.9058020114898682, + 0.9094322323799133, + 0.9126219749450684, + 0.9185943007469177, + 0.9219158291816711, + 0.9222309589385986, + 0.9274000525474548, + 0.9257815480232239, + 0.9265021085739136, + 0.9240222573280334, + 0.9306662678718567, + 0.9135439395904541, + 0.9231010675430298, + 0.9295929670333862, + 0.9293330907821655, + 0.9312353134155273, + 0.9329625964164734, + 0.9316965341567993, + 0.9333266615867615, + 0.9345052242279053, + 0.9335157871246338, + 0.9346805214881897, + 0.9368788003921509, + 0.9368543028831482, + 
0.9517249464988708, + 0.95307457447052, + 0.9403975009918213, + 0.9448124766349792, + 0.9445593953132629, + 0.9379482865333557, + 0.9392950534820557, + 0.9407076239585876, + 0.9420888423919678, + 0.9438381791114807, + 0.9501326680183411, + 0.9543547630310059, + 0.949072003364563, + 0.9473968148231506, + 0.9440956711769104, + 0.9430425763130188, + 0.9459739327430725, + 0.9455015659332275, + 0.9466095566749573, + 0.9462694525718689, + 0.9438596367835999, + 0.9619185924530029, + 0.9508318901062012, + 0.9483709335327148, + 0.9470229744911194, + 0.946711003780365, + 0.948207437992096, + 0.9481415748596191, + 0.9484500288963318, + 0.9477959275245667, + 0.9477677345275879, + 0.9469801187515259, + 0.9493097066879272, + 0.9443789720535278, + 0.9479587078094482, + 0.9503016471862793, + 0.9498686194419861, + 0.9504828453063965, + 0.9518103003501892, + 0.9503151774406433, + 0.9508251547813416, + 0.9512123465538025, + 0.9516437649726868, + 0.9551058411598206, + 0.9545681476593018, + 0.9550942778587341, + 0.9535008668899536, + 0.9537196159362793, + 0.9563971161842346, + 0.9570322632789612, + 0.9570552110671997, + 0.9566587209701538, + 0.9578317999839783, + 0.959659993648529, + 0.9594024419784546, + 0.9616206884384155, + 0.960929274559021, + 0.9598709344863892, + 0.9609968066215515, + 0.96136873960495, + 0.9600357413291931, + 0.9573716521263123, + 0.9603555202484131, + 0.9617785215377808, + 0.9621409773826599, + 0.9566047191619873, + 0.9613975882530212, + 0.9627871513366699, + 0.9628319144248962, + 0.9638941287994385, + 0.9425920248031616, + 0.9338929653167725, + 0.9540751576423645, + 0.9525136351585388, + 0.9575161337852478, + 0.9593470692634583, + 0.9632865786552429, + 0.9643735289573669, + 0.9636847972869873, + 0.9605234861373901, + 0.9628630876541138, + 0.9579678177833557, + 0.9624496102333069, + 0.9615869522094727, + 0.9628499150276184, + 0.9642650485038757, + 0.9617869853973389, + 0.9646812081336975, + 0.9648491740226746, + 0.962059736251831, + 0.9631932973861694, + 
0.9624291062355042, + 0.9625061750411987, + 0.9661656618118286, + 0.9649047255516052, + 0.9643188714981079, + 0.9657419919967651, + 0.9640787839889526, + 0.9652308225631714, + 0.9639834761619568, + 0.9667556285858154, + 0.96665358543396, + 0.9679520130157471, + 0.9679667353630066, + 0.9676694273948669, + 0.9669484496116638, + 0.9660540223121643, + 0.9672194719314575, + 0.9675654172897339, + 0.9686472415924072, + 0.9677742719650269, + 0.9648299217224121, + 0.9687601327896118, + 0.9699457883834839, + 0.9697255492210388, + 0.9699035286903381, + 0.9702296853065491, + 0.971717894077301, + 0.9693904519081116, + 0.9700163006782532, + 0.9713046550750732, + 0.9716519713401794, + 0.9721776247024536, + 0.9729169011116028, + 0.9725967049598694, + 0.9724439382553101, + 0.9721809029579163, + 0.968743622303009, + 0.9726381897926331, + 0.9733038544654846, + 0.975723385810852, + 0.9754957556724548, + 0.9738695621490479, + 0.973743200302124, + 0.9743053317070007, + 0.9736654162406921, + 0.9750998616218567, + 0.9715243577957153, + 0.9721106290817261, + 0.972126305103302, + 0.9748666286468506, + 0.9704681634902954, + 0.9725722074508667, + 0.9727873802185059, + 0.9750034213066101, + 0.9773173332214355, + 0.9747646450996399, + 0.9758877158164978, + 0.9750096797943115, + 0.9615429043769836, + 0.9669661521911621, + 0.9760112166404724, + 0.9738235473632812, + 0.9678298830986023, + 0.9739218950271606, + 0.9718801379203796, + 0.9750884175300598, + 0.9742626547813416, + 0.9749855399131775, + 0.9757564067840576, + 0.9747680425643921, + 0.9766591191291809, + 0.9760065674781799, + 0.9765374064445496, + 0.978253960609436, + 0.9755191206932068, + 0.9758433103561401, + 0.9768528938293457, + 0.9819539189338684, + 0.9725168347358704, + 0.9777063131332397, + 0.9763947129249573, + 0.9767307043075562, + 0.9769740700721741, + 0.9780659675598145, + 0.9795251488685608, + 0.9774302244186401, + 0.9785761833190918, + 0.9788815379142761, + 0.9764156937599182, + 0.9729133248329163, + 0.9722380042076111, + 
0.9785850644111633, + 0.9786012768745422, + 0.9787552356719971, + 0.9730265736579895, + 0.9583626985549927, + 0.9634031653404236, + 0.9660843014717102, + 0.9704118371009827, + 0.9707321524620056, + 0.9709881544113159, + 0.9760051369667053, + 0.946617066860199, + 0.9593662619590759, + 0.9677178859710693, + 0.9701368808746338, + 0.9732000231742859, + 0.9716952443122864, + 0.9724082946777344, + 0.9722424149513245, + 0.9704092741012573, + 0.9729465842247009, + 0.9730241298675537, + 0.9735832214355469, + 0.9737854599952698, + 0.9745733737945557, + 0.9750858545303345, + 0.9747254252433777, + 0.9785636067390442, + 0.9759612083435059, + 0.9760167598724365, + 0.9758936762809753, + 0.9770054221153259, + 0.9762231707572937, + 0.9768044948577881, + 0.9787408113479614, + 0.9720087647438049, + 0.9759089946746826, + 0.976744532585144, + 0.9771106243133545, + 0.975960373878479, + 0.9786209464073181, + 0.9778503179550171, + 0.978689968585968, + 0.9750562310218811, + 0.9783962965011597, + 0.9785009026527405, + 0.9785699248313904, + 0.9773379564285278, + 0.9788073897361755, + 0.9803115129470825, + 0.9762491583824158, + 0.9689900875091553, + 0.9725742936134338, + 0.9754883050918579, + 0.9774740934371948, + 0.9774919748306274, + 0.9788755774497986, + 0.9765322804450989, + 0.9788328409194946, + 0.9757962822914124, + 0.9774616956710815, + 0.9787426590919495, + 0.9791474938392639, + 0.9792768359184265, + 0.9798372983932495, + 0.9480541944503784, + 0.9694900512695312, + 0.9659545421600342, + 0.9701856970787048, + 0.9682232737541199, + 0.9694339036941528, + 0.9717342853546143, + 0.9711283445358276, + 0.9724385142326355, + 0.9730356931686401, + 0.9737069010734558, + 0.9744369387626648, + 0.9692203998565674, + 0.9734945893287659, + 0.974539577960968, + 0.9749677181243896, + 0.9785441756248474, + 0.9759414792060852, + 0.9767823219299316, + 0.9748192429542542, + 0.9773659110069275, + 0.9795023798942566, + 0.9782940149307251, + 0.9800353646278381, + 0.978857696056366, + 0.9792131185531616, + 
0.9789946675300598, + 0.9792091846466064, + 0.9785779118537903, + 0.9779298305511475, + 0.9727939367294312, + 0.9799074530601501, + 0.9778013825416565, + 0.9810697436332703, + 0.9760221242904663, + 0.9775710701942444, + 0.9781808257102966, + 0.9784017205238342, + 0.9790849685668945, + 0.975799024105072, + 0.9790262579917908, + 0.9796527028083801, + 0.9795625805854797, + 0.9797856211662292, + 0.9802121520042419, + 0.9808707237243652, + 0.9800721406936646, + 0.9813472628593445, + 0.9816632866859436, + 0.979587972164154, + 0.9817426204681396, + 0.9810745716094971, + 0.9811837673187256, + 0.9822655916213989, + 0.9821571707725525, + 0.981706440448761, + 0.9849334359169006, + 0.9618869423866272, + 0.972809910774231, + 0.9725103378295898, + 0.9719271659851074, + 0.9765286445617676, + 0.9767768979072571, + 0.9780049324035645, + 0.9781366586685181, + 0.9787758588790894, + 0.979574978351593, + 0.9788941144943237, + 0.9799464344978333, + 0.9802626371383667, + 0.9812312722206116, + 0.9805123805999756, + 0.9803711175918579, + 0.9846120476722717, + 0.9806753396987915, + 0.9805695414543152, + 0.9805625677108765, + 0.9703459739685059, + 0.9786406755447388, + 0.9758301377296448, + 0.9795531034469604, + 0.9787594676017761, + 0.9793882966041565, + 0.9798562526702881, + 0.9799196124076843, + 0.9809191823005676, + 0.9783082604408264, + 0.9760771989822388, + 0.9804887771606445, + 0.9809110760688782, + 0.9809779524803162, + 0.9790912866592407, + 0.979170024394989, + 0.9827489852905273, + 0.9835047125816345, + 0.9825785160064697, + 0.9837331771850586, + 0.9735047221183777, + 0.972905695438385, + 0.977183997631073, + 0.978089451789856, + 0.9775056838989258, + 0.9682573080062866, + 0.976538360118866, + 0.9770686030387878, + 0.9791834354400635, + 0.9798524975776672, + 0.9826685786247253, + 0.9796350002288818, + 0.9802651405334473, + 0.9800410270690918, + 0.9790262579917908, + 0.9812719821929932, + 0.9796724319458008, + 0.9810829162597656, + 0.9772170186042786, + 0.9808533191680908, + 
0.9819177389144897, + 0.981721818447113, + 0.9786110520362854, + 0.9818451404571533, + 0.9833516478538513, + 0.978785514831543, + 0.9799115657806396, + 0.9822924733161926, + 0.9824753403663635, + 0.9776352047920227, + 0.9810264706611633, + 0.9814620614051819, + 0.9740288257598877, + 0.9787504076957703, + 0.9795571565628052, + 0.9814469814300537, + 0.9769492149353027, + 0.9803171157836914, + 0.9813071489334106, + 0.9816375970840454, + 0.9820874929428101, + 0.9740938544273376, + 0.9657480716705322, + 0.9762830138206482, + 0.9772912263870239, + 0.9772639274597168, + 0.9782102704048157, + 0.978047251701355, + 0.9793410897254944, + 0.9829915165901184, + 0.977618396282196, + 0.9794162511825562, + 0.9780856370925903, + 0.9457417726516724, + 0.9645279049873352, + 0.9745407104492188, + 0.9753058552742004, + 0.9752920866012573, + 0.975785493850708, + 0.9767656922340393, + 0.9786217212677002, + 0.9787272214889526, + 0.9782143831253052, + 0.978548526763916, + 0.9797430634498596, + 0.9783796668052673, + 0.9788793921470642, + 0.9774158000946045, + 0.977510929107666, + 0.9794684052467346, + 0.9799519181251526, + 0.9808859825134277, + 0.980858564376831, + 0.9806353449821472, + 0.9798208475112915, + 0.9801185131072998, + 0.9811465740203857, + 0.980260968208313, + 0.9816108345985413, + 0.9817655682563782, + 0.9815741777420044, + 0.9835969805717468, + 0.9819722175598145, + 0.982279360294342, + 0.9823079109191895, + 0.9823514819145203, + 0.9826764464378357, + 0.9827845096588135, + 0.9827767610549927, + 0.9799264669418335, + 0.9798104763031006, + 0.9804279208183289, + 0.9802714586257935, + 0.9820721745491028, + 0.9825060963630676, + 0.9860026836395264, + 0.9734170436859131, + 0.9770139455795288, + 0.9802051782608032, + 0.9802567958831787, + 0.9813519716262817, + 0.9743576645851135, + 0.9799386858940125, + 0.9786694049835205, + 0.9799676537513733, + 0.9810048937797546, + 0.980069637298584, + 0.9816738963127136, + 0.9833455681800842, + 0.981890082359314, + 0.9814645051956177, + 
0.9829424619674683, + 0.9820736050605774, + 0.9824276566505432, + 0.9807500243186951, + 0.9814831018447876, + 0.9816710948944092 + ], + "recall": [ + 0.791861355304718, + 0.8451449275016785, + 0.8360539078712463, + 0.8644647002220154, + 0.8764495253562927, + 0.8856896758079529, + 0.886893630027771, + 0.8890292644500732, + 0.8897004127502441, + 0.8884459137916565, + 0.8893483281135559, + 0.8935720324516296, + 0.891123354434967, + 0.897046685218811, + 0.8985677361488342, + 0.904116690158844, + 0.9012035131454468, + 0.9133350253105164, + 0.9016810655593872, + 0.9030669331550598, + 0.9066722393035889, + 0.9079222679138184, + 0.9079533815383911, + 0.9120744466781616, + 0.9129979014396667, + 0.9152318835258484, + 0.9182695746421814, + 0.9195649027824402, + 0.9198611378669739, + 0.9220082759857178, + 0.8910340070724487, + 0.885751485824585, + 0.905456006526947, + 0.9070963263511658, + 0.9114834666252136, + 0.9191790819168091, + 0.9199527502059937, + 0.921348512172699, + 0.9224293231964111, + 0.9226288795471191, + 0.9153871536254883, + 0.9096200466156006, + 0.9198392629623413, + 0.9242357611656189, + 0.9300641417503357, + 0.9324789643287659, + 0.9314627051353455, + 0.9332801699638367, + 0.9335008263587952, + 0.9355972409248352, + 0.9391127824783325, + 0.8976917862892151, + 0.9181881546974182, + 0.9246246218681335, + 0.9293662905693054, + 0.9313563108444214, + 0.9316253662109375, + 0.9334909319877625, + 0.9344278573989868, + 0.9367935061454773, + 0.9377925395965576, + 0.9398112297058105, + 0.939002513885498, + 0.9445978999137878, + 0.9417274594306946, + 0.9408044219017029, + 0.9428611397743225, + 0.9436261057853699, + 0.9435579776763916, + 0.9460351467132568, + 0.9459415674209595, + 0.9465610384941101, + 0.947054386138916, + 0.944445788860321, + 0.9465948343276978, + 0.9477043747901917, + 0.9503015875816345, + 0.9510202407836914, + 0.9485911726951599, + 0.9491636753082275, + 0.9510773420333862, + 0.9523375630378723, + 0.9523199200630188, + 0.9516726732254028, + 
0.9527785778045654, + 0.9506001472473145, + 0.9531168341636658, + 0.9553090929985046, + 0.9551175832748413, + 0.9558740854263306, + 0.954255223274231, + 0.9570053219795227, + 0.9550737738609314, + 0.9562565684318542, + 0.9573289752006531, + 0.9596896767616272, + 0.9582734704017639, + 0.9588923454284668, + 0.9603589773178101, + 0.9603280425071716, + 0.9656360745429993, + 0.9568237662315369, + 0.9465411305427551, + 0.9552591443061829, + 0.9520728588104248, + 0.9527888894081116, + 0.9479497075080872, + 0.9485549330711365, + 0.950512170791626, + 0.9550092220306396, + 0.9510374665260315, + 0.9581052660942078, + 0.9552298188209534, + 0.9569394588470459, + 0.9567636847496033, + 0.9556072354316711, + 0.9587507247924805, + 0.9555803537368774, + 0.9565052390098572, + 0.9600802659988403, + 0.9599339365959167, + 0.961341381072998, + 0.961948812007904, + 0.9583158493041992, + 0.9603102803230286, + 0.961682140827179, + 0.9610201716423035, + 0.9632683396339417, + 0.9627686738967896, + 0.9642438292503357, + 0.9615676999092102, + 0.9624653458595276, + 0.9617192149162292, + 0.9623971581459045, + 0.9632579684257507, + 0.9644113183021545, + 0.9640793204307556, + 0.9643028378486633, + 0.965067982673645, + 0.9640759229660034, + 0.9656127095222473, + 0.9689431190490723, + 0.9650319218635559, + 0.9649828672409058, + 0.9662241339683533, + 0.9663980603218079, + 0.9671016931533813, + 0.9658815860748291, + 0.9681800603866577, + 0.967562198638916, + 0.9672921895980835, + 0.9681373238563538, + 0.9681686758995056, + 0.967867910861969, + 0.9687545299530029, + 0.9694454669952393, + 0.9693561792373657, + 0.9714493751525879, + 0.9680829048156738, + 0.9692250490188599, + 0.9661319851875305, + 0.9665122628211975, + 0.9688233137130737, + 0.9710770845413208, + 0.9714246392250061, + 0.9727378487586975, + 0.9704562425613403, + 0.9707130789756775, + 0.9669188857078552, + 0.9709378480911255, + 0.9702185988426208, + 0.9736852645874023, + 0.9710268974304199, + 0.9726718664169312, + 0.9724496603012085, + 
0.9668371081352234, + 0.9712908267974854, + 0.9726231694221497, + 0.9642851948738098, + 0.9676161408424377, + 0.9709209203720093, + 0.9643153548240662, + 0.9681193232536316, + 0.9739947319030762, + 0.9700707793235779, + 0.9738551378250122, + 0.971240222454071, + 0.9730367660522461, + 0.9735251665115356, + 0.973804771900177, + 0.9751293659210205, + 0.9741354584693909, + 0.9756639003753662, + 0.9757745265960693, + 0.974200963973999, + 0.9770787954330444, + 0.9764552116394043, + 0.9756947755813599, + 0.9650297164916992, + 0.9684521555900574, + 0.9673844575881958, + 0.972891092300415, + 0.9744333624839783, + 0.9750074148178101, + 0.9752647280693054, + 0.9729223251342773, + 0.9758098721504211, + 0.9761912822723389, + 0.9763118028640747, + 0.9769499897956848, + 0.9794024229049683, + 0.9793176054954529, + 0.9737824201583862, + 0.9756699800491333, + 0.9742213487625122, + 0.9722992181777954, + 0.9717274904251099, + 0.9682989716529846, + 0.9701378345489502, + 0.9694706201553345, + 0.9723451733589172, + 0.9731424450874329, + 0.9645086526870728, + 0.9728704690933228, + 0.9672322869300842, + 0.9669365882873535, + 0.9674783945083618, + 0.9659968018531799, + 0.9690768122673035, + 0.9695566892623901, + 0.9709073305130005, + 0.9731854796409607, + 0.9721854329109192, + 0.9730652570724487, + 0.9733920693397522, + 0.9740149974822998, + 0.9736261367797852, + 0.9743325114250183, + 0.9750107526779175, + 0.9701613783836365, + 0.9741973876953125, + 0.974717915058136, + 0.9756714701652527, + 0.9747783541679382, + 0.9748955368995667, + 0.9760963916778564, + 0.9734497666358948, + 0.975731372833252, + 0.97313392162323, + 0.9752202033996582, + 0.9759657979011536, + 0.9775606393814087, + 0.975492537021637, + 0.9772050976753235, + 0.9768574237823486, + 0.9791225790977478, + 0.9763096570968628, + 0.9775936603546143, + 0.9781891107559204, + 0.9793063402175903, + 0.9774335622787476, + 0.9769541621208191, + 0.9790760278701782, + 0.9690549373626709, + 0.9751819968223572, + 0.9764674305915833, + 
0.9760426878929138, + 0.9774177670478821, + 0.9766298532485962, + 0.9790302515029907, + 0.9758355021476746, + 0.9779002666473389, + 0.9785527586936951, + 0.9777949452400208, + 0.9788627028465271, + 0.9708422422409058, + 0.9752458930015564, + 0.9690341353416443, + 0.9516607522964478, + 0.9637950658798218, + 0.9636418223381042, + 0.9688300490379333, + 0.9696590900421143, + 0.9694294929504395, + 0.971558690071106, + 0.9723370671272278, + 0.9729266166687012, + 0.9740192890167236, + 0.9741011261940002, + 0.9779934287071228, + 0.9755014181137085, + 0.9758363366127014, + 0.9757565259933472, + 0.9730213284492493, + 0.976201057434082, + 0.9763962030410767, + 0.9774160981178284, + 0.9763848185539246, + 0.9747228026390076, + 0.9764521718025208, + 0.9752063155174255, + 0.9774546027183533, + 0.9776479005813599, + 0.9784721732139587, + 0.9786161780357361, + 0.978244960308075, + 0.9769057035446167, + 0.9791543483734131, + 0.9713692665100098, + 0.9771057367324829, + 0.9656568765640259, + 0.976421594619751, + 0.9769189953804016, + 0.9776700139045715, + 0.9783620238304138, + 0.9783968925476074, + 0.9808138608932495, + 0.9787189960479736, + 0.9789877533912659, + 0.9796712398529053, + 0.9797583818435669, + 0.9797834753990173, + 0.9798057675361633, + 0.9808721542358398, + 0.9801729321479797, + 0.9800459146499634, + 0.9818342328071594, + 0.9800011515617371, + 0.9811092019081116, + 0.9812566637992859, + 0.9801689386367798, + 0.9807889461517334, + 0.9815011620521545, + 0.9724181890487671, + 0.9728742241859436, + 0.9720596671104431, + 0.9741119146347046, + 0.9755633473396301, + 0.9736440181732178, + 0.9756592512130737, + 0.9758150577545166, + 0.9767378568649292, + 0.9772686958312988, + 0.9770384430885315, + 0.9786340594291687, + 0.9781789183616638, + 0.9786956310272217, + 0.9774749279022217, + 0.9791203737258911, + 0.979654848575592, + 0.9737077355384827, + 0.979299783706665, + 0.979436993598938, + 0.9604448676109314, + 0.9702677726745605, + 0.9712108969688416, + 0.9761053323745728, + 
0.9737813472747803, + 0.9762065410614014, + 0.9765533208847046, + 0.9772399663925171, + 0.9777761101722717, + 0.9773626327514648, + 0.9799520969390869, + 0.9813448190689087, + 0.9787777066230774, + 0.9791331887245178, + 0.9794602394104004, + 0.9813597798347473, + 0.9810588359832764, + 0.9773743152618408, + 0.976746678352356, + 0.9784935116767883, + 0.9773215055465698, + 0.9610670208930969, + 0.973602294921875, + 0.973552942276001, + 0.9748286008834839, + 0.9766324162483215, + 0.9799429178237915, + 0.9767052531242371, + 0.9771280884742737, + 0.9768232703208923, + 0.977131187915802, + 0.9742915034294128, + 0.9781007766723633, + 0.9782840609550476, + 0.9787495732307434, + 0.9791613817214966, + 0.9780554175376892, + 0.9801104068756104, + 0.979373574256897, + 0.9821428060531616, + 0.9800913333892822, + 0.9793018698692322, + 0.9797273874282837, + 0.9826022982597351, + 0.9796127080917358, + 0.9782849550247192, + 0.9827302694320679, + 0.9816903471946716, + 0.9800970554351807, + 0.9793035387992859, + 0.9830003380775452, + 0.9796827435493469, + 0.9808118939399719, + 0.9785438776016235, + 0.9761611223220825, + 0.9786279201507568, + 0.9751942157745361, + 0.9798658490180969, + 0.9793126583099365, + 0.9795118570327759, + 0.9801921248435974, + 0.9600467085838318, + 0.9629004001617432, + 0.9759656190872192, + 0.9701969623565674, + 0.972531259059906, + 0.9746580123901367, + 0.975263237953186, + 0.9770999550819397, + 0.9760912656784058, + 0.9698919057846069, + 0.9785098433494568, + 0.9758208394050598, + 0.9791128635406494, + 0.9835783243179321, + 0.9748319387435913, + 0.9716911315917969, + 0.9732897877693176, + 0.9752651453018188, + 0.9762004613876343, + 0.9760403037071228, + 0.9754682183265686, + 0.9757618308067322, + 0.9773507118225098, + 0.977741003036499, + 0.9767034649848938, + 0.9789731502532959, + 0.9788464903831482, + 0.9802247285842896, + 0.9789708852767944, + 0.9785082936286926, + 0.9784370064735413, + 0.9782934784889221, + 0.9785864949226379, + 0.9796760082244873, + 
0.9803999066352844, + 0.9799817204475403, + 0.9791824221611023, + 0.9809419512748718, + 0.9801437258720398, + 0.9804273247718811, + 0.9808862805366516, + 0.9786785244941711, + 0.9803816080093384, + 0.9807310104370117, + 0.9811263680458069, + 0.9809145331382751, + 0.9808802008628845, + 0.9804469347000122, + 0.981417179107666, + 0.9807339310646057, + 0.9795430898666382, + 0.9809128642082214, + 0.982026219367981, + 0.9813324809074402, + 0.9819055795669556, + 0.9553102850914001, + 0.9714292883872986, + 0.9767655730247498, + 0.9774090051651001, + 0.9792118668556213, + 0.9792583584785461, + 0.9822831153869629, + 0.9703272581100464, + 0.9751678705215454, + 0.9770734310150146, + 0.9776707887649536, + 0.9797150492668152, + 0.9790826439857483, + 0.9780145287513733, + 0.9803314805030823, + 0.9812318682670593, + 0.9790366888046265, + 0.9809558987617493, + 0.9810067415237427, + 0.981852650642395, + 0.9823798537254333, + 0.9825533628463745 + ], + "val_loss": [ + 0.28723815083503723, + 0.2713486850261688, + 0.23316213488578796, + 0.21659114956855774, + 0.20152230560779572, + 0.19967979192733765, + 0.19528080523014069, + 0.1881239116191864, + 0.1821192353963852, + 0.17346297204494476, + 0.1740366667509079, + 0.17089757323265076, + 0.16789646446704865, + 0.1660037338733673, + 0.16604383289813995, + 0.17595873773097992, + 0.1708299219608307, + 0.18093723058700562, + 0.16202183067798615, + 0.16549760103225708, + 0.16674426198005676, + 0.1612197458744049, + 0.16146968305110931, + 0.1589117795228958, + 0.16076944768428802, + 0.1649247258901596, + 0.16818247735500336, + 0.17020373046398163, + 0.17034289240837097, + 0.172881081700325, + 0.17018266022205353, + 0.17107126116752625, + 0.17065449059009552, + 0.17375759780406952, + 0.18439385294914246, + 0.28123483061790466, + 0.18936650454998016, + 0.18584774434566498, + 0.191576287150383, + 0.19643668830394745, + 0.6989830136299133, + 0.7043436169624329, + 0.39692726731300354, + 0.21512266993522644, + 0.18475396931171417, + 
0.18470022082328796, + 0.1896016001701355, + 0.1995842456817627, + 0.20438623428344727, + 0.2045293003320694, + 0.20099927484989166, + 0.16808313131332397, + 0.1720547080039978, + 0.17455346882343292, + 0.17807002365589142, + 0.1794194132089615, + 0.1817426085472107, + 0.18469466269016266, + 0.18639053404331207, + 0.18814800679683685, + 0.1905999630689621, + 0.19293445348739624, + 0.19491058588027954, + 0.19337712228298187, + 0.19332274794578552, + 0.1962611973285675, + 0.19995097815990448, + 0.20085245370864868, + 0.20314878225326538, + 0.19838808476924896, + 0.19687509536743164, + 0.204046830534935, + 0.20114660263061523, + 0.19768579304218292, + 0.20237891376018524, + 0.2069508284330368, + 0.19951966404914856, + 0.19902941584587097, + 0.20256203413009644, + 0.20888705551624298, + 0.21395546197891235, + 0.2150646299123764, + 0.21527113020420074, + 0.21789665520191193, + 0.21774950623512268, + 0.21902601420879364, + 0.22202743589878082, + 0.22601786255836487, + 0.22464175522327423, + 0.22582165896892548, + 0.21727927029132843, + 0.20705226063728333, + 0.21868649125099182, + 0.22494469583034515, + 0.23110167682170868, + 0.21682409942150116, + 0.22849085927009583, + 0.2313670963048935, + 0.23593372106552124, + 0.2373199164867401, + 0.4533359110355377, + 0.20986557006835938, + 0.19883373379707336, + 0.1990339756011963, + 0.20048026740550995, + 0.2039482146501541, + 0.21253350377082825, + 0.21681621670722961, + 0.22275234758853912, + 0.20872315764427185, + 0.2050584852695465, + 0.21987168490886688, + 0.23898328840732574, + 0.29479342699050903, + 0.3554593622684479, + 0.2502932548522949, + 0.2700035274028778, + 0.33491018414497375, + 0.2824503183364868, + 0.29658064246177673, + 0.31978747248649597, + 0.3138459026813507, + 0.349152535200119, + 0.29236090183258057, + 0.31074059009552, + 0.267335444688797, + 0.281875878572464, + 0.26835721731185913, + 0.2765218913555145, + 0.2490549087524414, + 0.270651638507843, + 0.2674177289009094, + 0.2566213607788086, + 
0.25731098651885986, + 0.2570015788078308, + 0.25581344962120056, + 0.2553556263446808, + 0.2602108418941498, + 0.256712406873703, + 0.26190245151519775, + 0.26316896080970764, + 0.2578590214252472, + 0.26275983452796936, + 0.2636365592479706, + 0.26617926359176636, + 0.2679619789123535, + 0.2680908441543579, + 0.26747429370880127, + 0.2679744064807892, + 0.269931823015213, + 0.2670978009700775, + 0.2718101739883423, + 0.27347344160079956, + 0.27368682622909546, + 0.27443116903305054, + 0.2800472378730774, + 0.2824190855026245, + 0.2883802056312561, + 0.2995620369911194, + 0.3005017042160034, + 0.3022603988647461, + 0.30273953080177307, + 0.2985783517360687, + 0.2965547442436218, + 0.2972632646560669, + 0.2994842529296875, + 0.2960124611854553, + 0.2643485963344574, + 0.27351850271224976, + 0.25582224130630493, + 0.2798594832420349, + 0.27532312273979187, + 0.26906818151474, + 0.27401596307754517, + 0.2865924537181854, + 0.26777151226997375, + 0.25301873683929443, + 0.2687220871448517, + 0.33844470977783203, + 0.2465686947107315, + 0.23854368925094604, + 0.23741523921489716, + 0.2444889098405838, + 0.2561025321483612, + 0.2590726613998413, + 0.2630075216293335, + 0.269187867641449, + 0.2795838713645935, + 0.2845974564552307, + 0.28783342242240906, + 0.28345707058906555, + 0.2835446298122406, + 0.28521379828453064, + 0.2847643494606018, + 0.28264346718788147, + 0.2821420133113861, + 0.2802213728427887, + 0.28383877873420715, + 0.3033846616744995, + 0.27275991439819336, + 0.27391302585601807, + 0.2827366888523102, + 0.2799161374568939, + 0.28558021783828735, + 0.2886905074119568, + 0.2805515229701996, + 0.2876611351966858, + 0.2848869860172272, + 0.2906288206577301, + 0.27950969338417053, + 0.254498690366745, + 0.24704022705554962, + 0.2561323940753937, + 0.26700064539909363, + 0.24114003777503967, + 0.2473776936531067, + 0.2354305535554886, + 0.23682336509227753, + 0.21703208982944489, + 0.21673893928527832, + 0.2144777625799179, + 0.19983965158462524, + 
0.18079981207847595, + 0.21088142693042755, + 0.19281452894210815, + 0.20026083290576935, + 0.21254859864711761, + 0.22024133801460266, + 0.2252091020345688, + 0.2187010645866394, + 0.2269616574048996, + 0.22800034284591675, + 0.2363997995853424, + 0.23668085038661957, + 0.23937562108039856, + 0.23971028625965118, + 0.24257995188236237, + 0.24286648631095886, + 0.24594220519065857, + 0.24318280816078186, + 0.23669882118701935, + 0.24397698044776917, + 0.24627374112606049, + 0.2619628608226776, + 0.25442594289779663, + 0.2580946385860443, + 0.28162869811058044, + 0.2419835925102234, + 0.24677637219429016, + 0.24248021841049194, + 0.24594305455684662, + 0.25479814410209656, + 0.25116685032844543, + 0.2514200508594513, + 0.2554737627506256, + 0.24422712624073029, + 0.24993255734443665, + 0.24904866516590118, + 0.25754067301750183, + 0.2438945472240448, + 0.2556094229221344, + 0.256924569606781, + 0.24302861094474792, + 0.19861286878585815, + 0.19839613139629364, + 0.2059561163187027, + 0.2107318490743637, + 0.2172657996416092, + 0.22223681211471558, + 0.22834354639053345, + 0.24637369811534882, + 0.22940300405025482, + 0.23817867040634155, + 0.2395147681236267, + 0.24258115887641907, + 0.2501608729362488, + 0.25339969992637634, + 0.21451711654663086, + 0.17311611771583557, + 0.16781897842884064, + 0.1619681566953659, + 0.1635843962430954, + 0.16744934022426605, + 0.17401142418384552, + 0.18347276747226715, + 0.19007667899131775, + 0.1960650086402893, + 0.2002505362033844, + 0.20441405475139618, + 0.20215874910354614, + 0.20224016904830933, + 0.20779606699943542, + 0.21010547876358032, + 0.21457569301128387, + 0.2174983024597168, + 0.2231486290693283, + 0.22089047729969025, + 0.22513629496097565, + 0.22668702900409698, + 0.22853055596351624, + 0.23308581113815308, + 0.23511795699596405, + 0.23914246261119843, + 0.24176160991191864, + 0.2509233355522156, + 0.244533509016037, + 0.2494301050901413, + 0.2527102530002594, + 0.2509970963001251, + 0.23927585780620575, + 
0.2555169463157654, + 0.24370647966861725, + 0.24824310839176178, + 0.24955306947231293, + 0.25017765164375305, + 0.24821235239505768, + 0.2448243349790573, + 0.2462622970342636, + 0.24949535727500916, + 0.25326254963874817, + 0.25667038559913635, + 0.25821465253829956, + 0.26118817925453186, + 0.2644621431827545, + 0.2672300338745117, + 0.2668937146663666, + 0.2716364860534668, + 0.2697041928768158, + 0.27032187581062317, + 0.27397218346595764, + 0.2758488059043884, + 0.28225117921829224, + 0.2826656997203827, + 0.3523707389831543, + 0.35616227984428406, + 0.3610374331474304, + 0.34478408098220825, + 0.3169664740562439, + 0.325067400932312, + 0.32318148016929626, + 0.3198719620704651, + 0.31806430220603943, + 0.3153974711894989, + 0.3156879246234894, + 0.31431904435157776, + 0.3098563551902771, + 0.3104737102985382, + 0.31270310282707214, + 0.3181253969669342, + 0.3163810670375824, + 0.3176153600215912, + 0.3160388171672821, + 0.30148062109947205, + 0.39275458455085754, + 0.3649538457393646, + 0.3534722626209259, + 0.3403719663619995, + 0.33192750811576843, + 0.326800674200058, + 0.3222004771232605, + 0.3207232356071472, + 0.32102692127227783, + 0.3194082975387573, + 0.33628493547439575, + 0.3176187574863434, + 0.3192712068557739, + 0.31775015592575073, + 0.31960445642471313, + 0.3243445158004761, + 0.31762775778770447, + 0.3086387515068054, + 0.30278441309928894, + 0.3129657804965973, + 0.3011714220046997, + 0.32860052585601807, + 0.3342500627040863, + 0.3335627615451813, + 0.32909467816352844, + 0.3580184280872345, + 0.3430590331554413, + 0.32331663370132446, + 0.3198379874229431, + 0.31436392664909363, + 0.3113187551498413, + 0.3089807331562042, + 0.30715039372444153, + 0.3089626133441925, + 0.3062823414802551, + 0.30327120423316956, + 0.3016895353794098, + 0.3001919388771057, + 0.30333200097084045, + 0.2917383909225464, + 0.2956298589706421, + 0.29643189907073975, + 0.30130383372306824, + 0.28969016671180725, + 0.299941748380661, + 0.29527848958969116, + 
0.30097830295562744, + 0.3061491549015045, + 0.29706084728240967, + 0.31153690814971924, + 0.2997742295265198, + 0.2966189980506897, + 0.29177623987197876, + 0.3229901194572449, + 0.3082697093486786, + 0.30902624130249023, + 0.3128173053264618, + 0.28913018107414246, + 0.29392433166503906, + 0.29433149099349976, + 0.2970201075077057, + 0.24357981979846954, + 0.2757555842399597, + 0.25999417901039124, + 0.2673036456108093, + 0.27445095777511597, + 0.2781122028827667, + 0.2836487889289856, + 0.28633299469947815, + 0.2873220145702362, + 0.2814802825450897, + 0.25818365812301636, + 0.28027209639549255, + 0.2846464514732361, + 0.22045518457889557, + 0.2228807657957077, + 0.22831334173679352, + 0.23165997862815857, + 0.23505640029907227, + 0.23972679674625397, + 0.24111683666706085, + 0.2465180903673172, + 0.25158917903900146, + 0.2559452950954437, + 0.25975653529167175, + 0.2616535425186157, + 0.26229920983314514, + 0.26474112272262573, + 0.2715058922767639, + 0.2710859477519989, + 0.276827335357666, + 0.28341811895370483, + 0.28683561086654663, + 0.2913525402545929, + 0.2918616831302643, + 0.29444780945777893, + 0.29850977659225464, + 0.3032839596271515, + 0.3009886145591736, + 0.30314281582832336, + 0.3034873604774475, + 0.3089466094970703, + 0.3075098693370819, + 0.31456753611564636, + 0.3177722096443176, + 0.31999513506889343, + 0.3137664198875427, + 0.32393699884414673, + 0.31554022431373596, + 0.31903356313705444, + 0.35382118821144104, + 0.3430420756340027, + 0.32786762714385986, + 0.32548537850379944, + 0.32288435101509094, + 0.3248937427997589, + 0.3364331126213074, + 0.32254448533058167, + 0.3199025094509125, + 0.31521788239479065, + 0.3154824376106262, + 0.31081122159957886, + 0.3064108192920685, + 0.29099735617637634, + 0.29672113060951233, + 0.2918401062488556, + 0.29029548168182373, + 0.2933703660964966, + 0.2857241630554199, + 0.28566399216651917, + 0.2848254442214966, + 0.28084006905555725, + 0.27759119868278503, + 0.28184911608695984, + 
0.28439953923225403, + 0.2827478349208832, + 0.2888249158859253, + 0.28556081652641296 + ], + "val_dice_coef": [ + 0.7766007781028748, + 0.7362868189811707, + 0.7365289330482483, + 0.736765444278717, + 0.7391619086265564, + 0.7354344725608826, + 0.7399694323539734, + 0.7447896599769592, + 0.7554518580436707, + 0.7736577391624451, + 0.7764828205108643, + 0.7869715094566345, + 0.7908278703689575, + 0.8058087825775146, + 0.8143413662910461, + 0.8216170072555542, + 0.8258571028709412, + 0.8172253966331482, + 0.8246889114379883, + 0.8262502551078796, + 0.8275225162506104, + 0.8288997411727905, + 0.8298549056053162, + 0.8314856886863708, + 0.8321593999862671, + 0.8323621153831482, + 0.833370566368103, + 0.834065854549408, + 0.8344595432281494, + 0.8343856334686279, + 0.8009603023529053, + 0.8194039463996887, + 0.8238095641136169, + 0.8264465928077698, + 0.8254995346069336, + 0.6870414614677429, + 0.8023579716682434, + 0.81545090675354, + 0.8230668306350708, + 0.8244038820266724, + 0.2739965617656708, + 0.2708752453327179, + 0.5725690722465515, + 0.7653201818466187, + 0.8189557194709778, + 0.820512592792511, + 0.821704089641571, + 0.8296846151351929, + 0.8308324217796326, + 0.8331705927848816, + 0.8332914710044861, + 0.8199408054351807, + 0.8255106210708618, + 0.8291324377059937, + 0.8306165933609009, + 0.8324189782142639, + 0.8333595395088196, + 0.8329381346702576, + 0.8340929746627808, + 0.8348809480667114, + 0.8351688385009766, + 0.8362818360328674, + 0.8363726735115051, + 0.8385686278343201, + 0.839104413986206, + 0.8386164903640747, + 0.8377727270126343, + 0.8374584913253784, + 0.8369852304458618, + 0.8386901617050171, + 0.8384276032447815, + 0.8365789651870728, + 0.8369478583335876, + 0.8361538648605347, + 0.8366899490356445, + 0.836506187915802, + 0.8378056287765503, + 0.8368849158287048, + 0.8375480771064758, + 0.8367794156074524, + 0.8366787433624268, + 0.8369461894035339, + 0.8369993567466736, + 0.8365828394889832, + 0.836992621421814, + 0.8358801007270813, + 
0.8362458944320679, + 0.8364722728729248, + 0.8367875218391418, + 0.8361663222312927, + 0.837709367275238, + 0.8383733034133911, + 0.837929904460907, + 0.8377741575241089, + 0.8372476100921631, + 0.8400737047195435, + 0.8390684127807617, + 0.83942711353302, + 0.8388166427612305, + 0.8388614654541016, + 0.5899927616119385, + 0.8362930417060852, + 0.8417453169822693, + 0.8424978852272034, + 0.8421810865402222, + 0.8427466154098511, + 0.8386762142181396, + 0.8371509909629822, + 0.833957314491272, + 0.8364364504814148, + 0.8421381115913391, + 0.8345608115196228, + 0.8007285594940186, + 0.7260335087776184, + 0.6692654490470886, + 0.7947330474853516, + 0.7512447237968445, + 0.6925077438354492, + 0.745521605014801, + 0.7343494296073914, + 0.7150691151618958, + 0.7198730111122131, + 0.6902240514755249, + 0.7437922954559326, + 0.7274830937385559, + 0.7680156826972961, + 0.7545132040977478, + 0.768580436706543, + 0.7604827284812927, + 0.7898576855659485, + 0.7694604992866516, + 0.7757206559181213, + 0.7929953932762146, + 0.7918580174446106, + 0.7945917248725891, + 0.803205132484436, + 0.8119467496871948, + 0.8164100646972656, + 0.8216281533241272, + 0.8251609802246094, + 0.8262266516685486, + 0.8401311039924622, + 0.8315710425376892, + 0.8277207016944885, + 0.8256680369377136, + 0.8304272890090942, + 0.8297643065452576, + 0.8305029273033142, + 0.833788275718689, + 0.8272045850753784, + 0.8259381651878357, + 0.8239844441413879, + 0.8249673843383789, + 0.823042094707489, + 0.8268076777458191, + 0.8339473605155945, + 0.8356831073760986, + 0.8316064476966858, + 0.8334366083145142, + 0.8356269598007202, + 0.8349707722663879, + 0.8357625603675842, + 0.836354672908783, + 0.8350753784179688, + 0.8341637253761292, + 0.835790753364563, + 0.8393711447715759, + 0.8449829816818237, + 0.8419803977012634, + 0.8453216552734375, + 0.8424800634384155, + 0.8430929780006409, + 0.8414286375045776, + 0.825545072555542, + 0.8366214036941528, + 0.8449114561080933, + 0.8470876216888428, + 
0.8456646203994751, + 0.8375294208526611, + 0.8552544116973877, + 0.8545020222663879, + 0.8554497361183167, + 0.8545476794242859, + 0.8512993454933167, + 0.8510252237319946, + 0.8507089018821716, + 0.8489977121353149, + 0.8457096219062805, + 0.8413915634155273, + 0.8370864987373352, + 0.8341518044471741, + 0.8336361646652222, + 0.8339350819587708, + 0.8373669385910034, + 0.8392994403839111, + 0.839483380317688, + 0.8332099914550781, + 0.8350664973258972, + 0.8205387592315674, + 0.8416454195976257, + 0.8428033590316772, + 0.8411937952041626, + 0.8427980542182922, + 0.8404627442359924, + 0.842330276966095, + 0.8395564556121826, + 0.8414638638496399, + 0.8424361944198608, + 0.8446895480155945, + 0.846848726272583, + 0.8489895462989807, + 0.8482503294944763, + 0.8459113240242004, + 0.844142735004425, + 0.8484106659889221, + 0.8471189737319946, + 0.845054566860199, + 0.8444048166275024, + 0.8466100096702576, + 0.8477064967155457, + 0.8483533263206482, + 0.8503214716911316, + 0.8509876728057861, + 0.8316774964332581, + 0.8370324373245239, + 0.839336097240448, + 0.8408005237579346, + 0.8428125977516174, + 0.8437373638153076, + 0.8455855846405029, + 0.8456382155418396, + 0.8462637662887573, + 0.8460706472396851, + 0.8462703227996826, + 0.8461295366287231, + 0.8465862274169922, + 0.8466053605079651, + 0.846900999546051, + 0.8466494083404541, + 0.845362663269043, + 0.8469746112823486, + 0.846789538860321, + 0.8473795056343079, + 0.8461935520172119, + 0.8466159105300903, + 0.8463333249092102, + 0.8458000421524048, + 0.8505790829658508, + 0.8503738641738892, + 0.8476802110671997, + 0.845830500125885, + 0.8487781882286072, + 0.8485366106033325, + 0.8473630547523499, + 0.8498547077178955, + 0.8499366641044617, + 0.8500738739967346, + 0.8502128720283508, + 0.8495468497276306, + 0.8469513654708862, + 0.8481175899505615, + 0.849429726600647, + 0.849066972732544, + 0.8548311591148376, + 0.8550606369972229, + 0.8563703894615173, + 0.8567954897880554, + 0.856242835521698, + 
0.8552171587944031, + 0.8527343273162842, + 0.8532533049583435, + 0.8556985259056091, + 0.8543787598609924, + 0.8551021218299866, + 0.8552600741386414, + 0.8515934348106384, + 0.8519724011421204, + 0.8135335445404053, + 0.8389540910720825, + 0.8461066484451294, + 0.8526746034622192, + 0.8554375171661377, + 0.8559949994087219, + 0.8570294976234436, + 0.857523500919342, + 0.8580809831619263, + 0.8586397171020508, + 0.8586057424545288, + 0.858701765537262, + 0.8587692975997925, + 0.859142541885376, + 0.8580973148345947, + 0.8573670983314514, + 0.8562242388725281, + 0.8564314246177673, + 0.8561164736747742, + 0.8571146130561829, + 0.8570607304573059, + 0.8561781048774719, + 0.8559383749961853, + 0.8557175397872925, + 0.8561850786209106, + 0.8556377291679382, + 0.8554741740226746, + 0.8555220365524292, + 0.8559049963951111, + 0.8547510504722595, + 0.8522539734840393, + 0.8539879322052002, + 0.852405846118927, + 0.849246621131897, + 0.8530959486961365, + 0.8528076410293579, + 0.852859616279602, + 0.8532447814941406, + 0.8533909916877747, + 0.8527594804763794, + 0.8533093333244324, + 0.8533316850662231, + 0.853492796421051, + 0.8531711101531982, + 0.8533493280410767, + 0.8533123135566711, + 0.8528207540512085, + 0.852786660194397, + 0.8533142805099487, + 0.8515560626983643, + 0.8524121046066284, + 0.8523905277252197, + 0.8523167371749878, + 0.8519017100334167, + 0.8507035970687866, + 0.8511326909065247, + 0.839489221572876, + 0.8289598822593689, + 0.8272483348846436, + 0.8299774527549744, + 0.8335668444633484, + 0.8329368829727173, + 0.8338814377784729, + 0.8352221846580505, + 0.8364477753639221, + 0.8375290036201477, + 0.8377435803413391, + 0.8386745452880859, + 0.8395848870277405, + 0.8398758769035339, + 0.8401604890823364, + 0.8399484753608704, + 0.8398151397705078, + 0.8405882716178894, + 0.8409283757209778, + 0.8175678253173828, + 0.8111531138420105, + 0.822675347328186, + 0.8253350853919983, + 0.8251060843467712, + 0.8288128972053528, + 0.8320625424385071, + 
0.8333511352539062, + 0.8346990346908569, + 0.8357964754104614, + 0.8368164300918579, + 0.8353255391120911, + 0.8381621241569519, + 0.838522732257843, + 0.8384892344474792, + 0.8390708565711975, + 0.838234543800354, + 0.8389724493026733, + 0.8399551510810852, + 0.8399912714958191, + 0.8401727080345154, + 0.8411318063735962, + 0.8280410766601562, + 0.8258520364761353, + 0.8292118906974792, + 0.8323183059692383, + 0.8278461694717407, + 0.8286098837852478, + 0.8372151255607605, + 0.8382918834686279, + 0.8395122289657593, + 0.840423047542572, + 0.8408432006835938, + 0.841245710849762, + 0.8415670990943909, + 0.8418585658073425, + 0.8411929607391357, + 0.8421667218208313, + 0.8424487709999084, + 0.8429855704307556, + 0.8435900211334229, + 0.8436681032180786, + 0.8439942002296448, + 0.8434712290763855, + 0.8450051546096802, + 0.8439319133758545, + 0.8445190787315369, + 0.8430773615837097, + 0.8428373336791992, + 0.8436524271965027, + 0.8429920673370361, + 0.8443013429641724, + 0.8443319797515869, + 0.8449892401695251, + 0.8376885056495667, + 0.841406524181366, + 0.8415302038192749, + 0.8408886194229126, + 0.8442723155021667, + 0.8440332412719727, + 0.8442908525466919, + 0.8443634510040283, + 0.8437660336494446, + 0.8392053842544556, + 0.8405532836914062, + 0.8407477736473083, + 0.8406219482421875, + 0.8408179879188538, + 0.8407355546951294, + 0.8411082625389099, + 0.8407227396965027, + 0.8417288661003113, + 0.8435289859771729, + 0.8406943082809448, + 0.8406220078468323, + 0.8430836200714111, + 0.8414878249168396, + 0.8407925367355347, + 0.8421801328659058, + 0.8422234654426575, + 0.8420581221580505, + 0.8420503735542297, + 0.8411982655525208, + 0.8417600989341736, + 0.8418623805046082, + 0.8419300317764282, + 0.8423025608062744, + 0.8427544236183167, + 0.8427135348320007, + 0.844169020652771, + 0.8436357378959656, + 0.8427347540855408, + 0.8422463536262512, + 0.8415619134902954, + 0.8414194583892822, + 0.8415847420692444, + 0.8405125737190247, + 0.8403295874595642, + 
0.8391857147216797, + 0.839878261089325, + 0.8403366208076477, + 0.8406816124916077, + 0.8406360149383545, + 0.8400084376335144, + 0.8400857448577881, + 0.839959979057312, + 0.8401797413825989, + 0.8409728407859802, + 0.8402512669563293, + 0.8409667611122131, + 0.841234564781189, + 0.8381590247154236, + 0.8394162058830261, + 0.8405060172080994, + 0.8413972854614258, + 0.8416813015937805, + 0.8416779041290283, + 0.8404427170753479, + 0.8398978114128113, + 0.8398051857948303, + 0.8407351970672607, + 0.8412854671478271, + 0.8423576951026917, + 0.8464295268058777, + 0.8451360464096069, + 0.8451173305511475, + 0.8459922671318054, + 0.8468573093414307, + 0.846176266670227, + 0.8468344211578369, + 0.8463738560676575, + 0.8468202352523804, + 0.846664547920227, + 0.8477717638015747, + 0.8471964001655579, + 0.8456940054893494, + 0.8466700315475464, + 0.8462275266647339, + 0.8469456434249878 + ], + "val_precision": [ + 0.8711534738540649, + 0.9099778532981873, + 0.9012740850448608, + 0.8973432183265686, + 0.8958495259284973, + 0.8980125188827515, + 0.8952105045318604, + 0.8935657143592834, + 0.8896180391311646, + 0.8830705285072327, + 0.881388783454895, + 0.8785831332206726, + 0.8784195780754089, + 0.8662593960762024, + 0.8595232367515564, + 0.8465721011161804, + 0.8446353673934937, + 0.8311983942985535, + 0.8418576717376709, + 0.8404845595359802, + 0.8405564427375793, + 0.8373354077339172, + 0.8398115634918213, + 0.8381937742233276, + 0.8403121829032898, + 0.8403711915016174, + 0.8373908400535583, + 0.8363755941390991, + 0.8387280702590942, + 0.8359437584877014, + 0.8607511520385742, + 0.8508727550506592, + 0.8516408205032349, + 0.850869357585907, + 0.8463905453681946, + 0.9256880879402161, + 0.8811120986938477, + 0.8642243146896362, + 0.8511949777603149, + 0.8438611626625061, + 0.897867739200592, + 0.8976171612739563, + 0.9503747820854187, + 0.9013785123825073, + 0.8685796856880188, + 0.8691869378089905, + 0.8672569990158081, + 0.8460566997528076, + 0.8413045406341553, + 
0.8369789123535156, + 0.8355416059494019, + 0.8512107133865356, + 0.846632182598114, + 0.8438904285430908, + 0.8412598967552185, + 0.8414592742919922, + 0.8418748378753662, + 0.8401152491569519, + 0.8390417695045471, + 0.8385841846466064, + 0.839232325553894, + 0.8377540707588196, + 0.834185779094696, + 0.8351815342903137, + 0.83882075548172, + 0.8391847014427185, + 0.8374903798103333, + 0.8391371369361877, + 0.8375192284584045, + 0.8376700282096863, + 0.8398131728172302, + 0.8305045366287231, + 0.8372330069541931, + 0.8399054408073425, + 0.8377699255943298, + 0.8375679850578308, + 0.835006594657898, + 0.8356190919876099, + 0.8395451903343201, + 0.8383533954620361, + 0.838898777961731, + 0.8382733464241028, + 0.8402997255325317, + 0.8408578038215637, + 0.8415656685829163, + 0.8414214253425598, + 0.840507447719574, + 0.8395826816558838, + 0.8433931469917297, + 0.8396790027618408, + 0.8400417566299438, + 0.8480983972549438, + 0.8443701267242432, + 0.8446289896965027, + 0.8388392925262451, + 0.8489418029785156, + 0.8449345827102661, + 0.8443923592567444, + 0.845302402973175, + 0.8450705409049988, + 0.9710460305213928, + 0.8663370609283447, + 0.8593899011611938, + 0.8566167950630188, + 0.859527051448822, + 0.8560407757759094, + 0.8620664477348328, + 0.8610326647758484, + 0.8568584322929382, + 0.8632041811943054, + 0.8479426503181458, + 0.8565797805786133, + 0.881535530090332, + 0.9172544479370117, + 0.9386345148086548, + 0.8827859163284302, + 0.9101639986038208, + 0.9348036646842957, + 0.9130666255950928, + 0.9183593392372131, + 0.9276469349861145, + 0.9261051416397095, + 0.9390121102333069, + 0.917735755443573, + 0.9261561036109924, + 0.9070261716842651, + 0.9143952131271362, + 0.9071436524391174, + 0.9126418232917786, + 0.8976956009864807, + 0.9085080623626709, + 0.905343770980835, + 0.8929307460784912, + 0.8945400714874268, + 0.8930835723876953, + 0.884345293045044, + 0.8809970617294312, + 0.8748594522476196, + 0.873525857925415, + 0.8670804500579834, + 
0.8665019869804382, + 0.8446500301361084, + 0.8590936064720154, + 0.8657666444778442, + 0.8678518533706665, + 0.8619884848594666, + 0.8635320663452148, + 0.8665153384208679, + 0.8587769269943237, + 0.865226149559021, + 0.8698509931564331, + 0.871242344379425, + 0.8705009818077087, + 0.8737422823905945, + 0.8686668872833252, + 0.8597418069839478, + 0.8551826477050781, + 0.8492261171340942, + 0.8479729294776917, + 0.8436432480812073, + 0.8426222205162048, + 0.8409585952758789, + 0.8462572693824768, + 0.8505966067314148, + 0.8527435660362244, + 0.8455612063407898, + 0.8418802618980408, + 0.8454927802085876, + 0.8337236046791077, + 0.8402150869369507, + 0.8331714868545532, + 0.8320943713188171, + 0.8520482778549194, + 0.8709573745727539, + 0.8536549210548401, + 0.8313764333724976, + 0.8369816541671753, + 0.8376103043556213, + 0.8010546565055847, + 0.8396194577217102, + 0.8443663716316223, + 0.8486455082893372, + 0.8437392711639404, + 0.8403691053390503, + 0.8407142758369446, + 0.8399880528450012, + 0.839185893535614, + 0.8460862636566162, + 0.8545673489570618, + 0.860598623752594, + 0.8658726811408997, + 0.8679269552230835, + 0.866565465927124, + 0.8626867532730103, + 0.861561119556427, + 0.8615671396255493, + 0.8700774908065796, + 0.8644501566886902, + 0.8783872723579407, + 0.8526102900505066, + 0.8481480479240417, + 0.8520728349685669, + 0.850262463092804, + 0.8534269332885742, + 0.850530743598938, + 0.8601478934288025, + 0.8564969301223755, + 0.8558398485183716, + 0.8488183617591858, + 0.8448949456214905, + 0.8506718873977661, + 0.8446462154388428, + 0.8446815013885498, + 0.8393603563308716, + 0.8363129496574402, + 0.8319042921066284, + 0.8438207507133484, + 0.8403454422950745, + 0.8494036197662354, + 0.8482406139373779, + 0.8493703603744507, + 0.8521566390991211, + 0.8635560870170593, + 0.835563600063324, + 0.8510769009590149, + 0.8496324419975281, + 0.8478512167930603, + 0.847159206867218, + 0.8458375930786133, + 0.8478499054908752, + 0.8420311212539673, + 
0.8471897840499878, + 0.8455916047096252, + 0.8447195887565613, + 0.8431907296180725, + 0.8468424081802368, + 0.8455039858818054, + 0.8472043871879578, + 0.8467671871185303, + 0.8406612873077393, + 0.8456061482429504, + 0.8399146199226379, + 0.8432310819625854, + 0.839999258518219, + 0.8430536985397339, + 0.841952919960022, + 0.8331800699234009, + 0.8475290536880493, + 0.8477332592010498, + 0.8592418432235718, + 0.8608124256134033, + 0.8471086025238037, + 0.8507931232452393, + 0.8570448160171509, + 0.8441699147224426, + 0.8461606502532959, + 0.8479183912277222, + 0.8481042981147766, + 0.8468204140663147, + 0.8429912328720093, + 0.8440841436386108, + 0.8464818596839905, + 0.8488946557044983, + 0.8782042264938354, + 0.8803180456161499, + 0.8736661672592163, + 0.8709073662757874, + 0.8666850328445435, + 0.8635985255241394, + 0.8619860410690308, + 0.8458244204521179, + 0.8539053797721863, + 0.8490270376205444, + 0.8503031134605408, + 0.8508244156837463, + 0.8424339890480042, + 0.8424623608589172, + 0.9407716393470764, + 0.9200730919837952, + 0.912413477897644, + 0.901079535484314, + 0.8914381861686707, + 0.8887494802474976, + 0.8808781504631042, + 0.8749341368675232, + 0.8733342289924622, + 0.8700673580169678, + 0.8675072193145752, + 0.8664731383323669, + 0.8562585115432739, + 0.8603118658065796, + 0.8580091595649719, + 0.8594563603401184, + 0.8582649230957031, + 0.8568836450576782, + 0.8514372706413269, + 0.8594683408737183, + 0.8582966923713684, + 0.8588802218437195, + 0.8594077825546265, + 0.856478750705719, + 0.8584680557250977, + 0.8564247488975525, + 0.8565948605537415, + 0.8511720895767212, + 0.8616523742675781, + 0.8520718216896057, + 0.8509661555290222, + 0.8458645939826965, + 0.8619470596313477, + 0.8367296457290649, + 0.8450754284858704, + 0.8439561724662781, + 0.8469855189323425, + 0.8472172617912292, + 0.8471127152442932, + 0.8471678495407104, + 0.8489171862602234, + 0.8484451770782471, + 0.8490914702415466, + 0.849514901638031, + 0.8497565984725952, + 
0.848747193813324, + 0.847416341304779, + 0.8483163118362427, + 0.8474992513656616, + 0.8459259867668152, + 0.8463056683540344, + 0.8468024134635925, + 0.8468325138092041, + 0.8456897735595703, + 0.8450791239738464, + 0.8440936803817749, + 0.830335795879364, + 0.7993900179862976, + 0.793070912361145, + 0.8034380078315735, + 0.8194118738174438, + 0.8142642378807068, + 0.8157716393470764, + 0.8192340135574341, + 0.8195520043373108, + 0.8228065371513367, + 0.8221330642700195, + 0.8225405812263489, + 0.8252206444740295, + 0.8260633945465088, + 0.825140118598938, + 0.8256429433822632, + 0.8282857537269592, + 0.8261280655860901, + 0.8279370665550232, + 0.8756867051124573, + 0.8302189111709595, + 0.8211290836334229, + 0.8243001103401184, + 0.8352564573287964, + 0.8340696096420288, + 0.8324414491653442, + 0.8342196345329285, + 0.8347622752189636, + 0.8345479965209961, + 0.8326995968818665, + 0.8201993107795715, + 0.8334702849388123, + 0.8323670625686646, + 0.8328129649162292, + 0.8324663639068604, + 0.8275765776634216, + 0.8328525424003601, + 0.8338319063186646, + 0.8406358957290649, + 0.8342118859291077, + 0.8341279625892639, + 0.844491720199585, + 0.8484975099563599, + 0.8435754179954529, + 0.8398926854133606, + 0.8345503211021423, + 0.839335560798645, + 0.8277166485786438, + 0.8274762630462646, + 0.8282976746559143, + 0.8294353485107422, + 0.8292824625968933, + 0.8328487873077393, + 0.8292620778083801, + 0.8316678404808044, + 0.836163341999054, + 0.8351377844810486, + 0.8346888422966003, + 0.8311458230018616, + 0.838955283164978, + 0.8381076455116272, + 0.8383089900016785, + 0.8354285359382629, + 0.838118851184845, + 0.8365051746368408, + 0.8356513977050781, + 0.8323842287063599, + 0.8316612243652344, + 0.8370093107223511, + 0.8296104073524475, + 0.8354123830795288, + 0.8300120234489441, + 0.8340489864349365, + 0.8298304080963135, + 0.8299878239631653, + 0.8266295790672302, + 0.8280067443847656, + 0.8353079557418823, + 0.8343426585197449, + 0.8352524638175964, + 
0.835371196269989, + 0.8479501605033875, + 0.8313549160957336, + 0.8351329565048218, + 0.836721658706665, + 0.8332098126411438, + 0.8351552486419678, + 0.83438640832901, + 0.8342813849449158, + 0.8330809473991394, + 0.8357232809066772, + 0.8361101746559143, + 0.8338901996612549, + 0.8371801376342773, + 0.8381467461585999, + 0.8457373380661011, + 0.8443560004234314, + 0.8427362442016602, + 0.843136191368103, + 0.8403492569923401, + 0.8423163890838623, + 0.8420248627662659, + 0.840101957321167, + 0.8397964835166931, + 0.8396996259689331, + 0.8402324914932251, + 0.8376874923706055, + 0.8397992849349976, + 0.8312510848045349, + 0.8357895016670227, + 0.8348694443702698, + 0.8337839841842651, + 0.8347054123878479, + 0.832122802734375, + 0.8326663374900818, + 0.8303632140159607, + 0.8294076919555664, + 0.8289780020713806, + 0.8320909142494202, + 0.8328109383583069, + 0.8335252404212952, + 0.832923173904419, + 0.8333304524421692, + 0.8312778472900391, + 0.8306232690811157, + 0.8294838666915894, + 0.8314119577407837, + 0.8328531980514526, + 0.836654543876648, + 0.8344799876213074, + 0.8263624906539917, + 0.8262391686439514, + 0.8288564085960388, + 0.8312889337539673, + 0.832801103591919, + 0.8319774270057678, + 0.8286105990409851, + 0.8272514939308167, + 0.8288088440895081, + 0.8281652927398682, + 0.828852653503418, + 0.830596387386322, + 0.8328118920326233, + 0.8277364373207092, + 0.8315495848655701, + 0.8332411050796509, + 0.8352165222167969, + 0.8314898014068604, + 0.8340672850608826, + 0.8345930576324463, + 0.835718035697937, + 0.837015688419342, + 0.8390412926673889, + 0.8377916216850281, + 0.8331015110015869, + 0.8361600041389465, + 0.8335365653038025, + 0.8368273377418518 + ], + "val_recall": [ + 0.7138949036598206, + 0.6286648511886597, + 0.6385065317153931, + 0.6445983648300171, + 0.6512171030044556, + 0.6463722586631775, + 0.6573383212089539, + 0.666047215461731, + 0.6832208037376404, + 0.7122219800949097, + 0.7173653841018677, + 0.7347973585128784, + 
0.7411508560180664, + 0.7763452529907227, + 0.7962082624435425, + 0.8206493258476257, + 0.8289700150489807, + 0.8276956081390381, + 0.8290759325027466, + 0.8334759473800659, + 0.8357791900634766, + 0.8409918546676636, + 0.8399480581283569, + 0.844188928604126, + 0.8431087136268616, + 0.843708872795105, + 0.848979651927948, + 0.851351797580719, + 0.8493645191192627, + 0.8524963855743408, + 0.7727329730987549, + 0.8133203983306885, + 0.8196693062782288, + 0.8249475955963135, + 0.8286963701248169, + 0.5666696429252625, + 0.7584494352340698, + 0.7947556972503662, + 0.8203129172325134, + 0.8304421901702881, + 0.1745501309633255, + 0.17224669456481934, + 0.42887505888938904, + 0.6848285794258118, + 0.7948252558708191, + 0.7959211468696594, + 0.799809992313385, + 0.8348971009254456, + 0.842459499835968, + 0.8518373966217041, + 0.8531920313835144, + 0.8106057047843933, + 0.8251631259918213, + 0.8344990015029907, + 0.8401424884796143, + 0.8430490493774414, + 0.844461977481842, + 0.8462569117546082, + 0.8492566347122192, + 0.8510982394218445, + 0.851049542427063, + 0.8556559681892395, + 0.8597413897514343, + 0.8624208569526672, + 0.8590958118438721, + 0.857639491558075, + 0.8579022884368896, + 0.8555306196212769, + 0.8565317988395691, + 0.8590701818466187, + 0.8552781939506531, + 0.863028347492218, + 0.8566315174102783, + 0.8506564497947693, + 0.8543311953544617, + 0.8545897006988525, + 0.8585560321807861, + 0.8554390072822571, + 0.853297233581543, + 0.853861927986145, + 0.8533856272697449, + 0.8548802733421326, + 0.8526723384857178, + 0.8513028621673584, + 0.8510808348655701, + 0.8495150804519653, + 0.8507934808731079, + 0.8524266481399536, + 0.8483285307884216, + 0.8525064587593079, + 0.8531935811042786, + 0.8452684879302979, + 0.8489734530448914, + 0.8487175703048706, + 0.8545060753822327, + 0.8466358780860901, + 0.8503495454788208, + 0.8516244292259216, + 0.8498554229736328, + 0.850494384765625, + 0.44720861315727234, + 0.8235508799552917, + 0.8393977284431458, + 
0.8446410894393921, + 0.8418731093406677, + 0.8467342257499695, + 0.8340064883232117, + 0.832314670085907, + 0.8308895826339722, + 0.8281630873680115, + 0.854369044303894, + 0.8318706750869751, + 0.7502104043960571, + 0.6147958040237427, + 0.5343802571296692, + 0.7403198480606079, + 0.6522589325904846, + 0.5624599456787109, + 0.6424561738967896, + 0.6245296597480774, + 0.5936372876167297, + 0.60028475522995, + 0.557341992855072, + 0.6365642547607422, + 0.6096900701522827, + 0.6775719523429871, + 0.6531724333763123, + 0.6783187985420227, + 0.6626167297363281, + 0.7168574929237366, + 0.6787083745002747, + 0.6903501749038696, + 0.7266693711280823, + 0.7233291864395142, + 0.7288182973861694, + 0.7511191964149475, + 0.7684029340744019, + 0.7821173071861267, + 0.7917649745941162, + 0.8049054741859436, + 0.8070986270904541, + 0.8550939559936523, + 0.8245589733123779, + 0.8108031749725342, + 0.805087149143219, + 0.8195995688438416, + 0.8167506456375122, + 0.8159922957420349, + 0.8290852904319763, + 0.8105812072753906, + 0.8037106394767761, + 0.7992058992385864, + 0.8014134168624878, + 0.7948223352432251, + 0.8062489628791809, + 0.828700065612793, + 0.8370940089225769, + 0.8341236114501953, + 0.8411271572113037, + 0.8495692610740662, + 0.8506975769996643, + 0.8540267944335938, + 0.848400354385376, + 0.8412774205207825, + 0.8369691371917725, + 0.8473140597343445, + 0.8577393889427185, + 0.8619421124458313, + 0.8698633313179016, + 0.8683803677558899, + 0.8725437521934509, + 0.8747128844261169, + 0.8498877882957458, + 0.8027667999267578, + 0.8402888774871826, + 0.8779536485671997, + 0.8747881054878235, + 0.8732026219367981, + 0.9028149843215942, + 0.888365626335144, + 0.882239043712616, + 0.8798195719718933, + 0.8843492269515991, + 0.8823739886283875, + 0.8820407390594482, + 0.8825317621231079, + 0.8807196021080017, + 0.8673355579376221, + 0.8505632877349854, + 0.8365170955657959, + 0.8258256316184998, + 0.8226954340934753, + 0.8244332075119019, + 0.8335296511650085, + 
0.8384708166122437, + 0.837757408618927, + 0.8183615803718567, + 0.8272135257720947, + 0.7957266569137573, + 0.8523391485214233, + 0.8591209053993225, + 0.8521117568016052, + 0.8563898205757141, + 0.8488904237747192, + 0.8548893332481384, + 0.8391090631484985, + 0.8465986847877502, + 0.8486119508743286, + 0.860098659992218, + 0.8688785433769226, + 0.8623787760734558, + 0.8664694428443909, + 0.864578366279602, + 0.8683936595916748, + 0.8762837648391724, + 0.8789510130882263, + 0.8645862340927124, + 0.8669252991676331, + 0.8612207174301147, + 0.8644996285438538, + 0.8641279339790344, + 0.8632484674453735, + 0.850382387638092, + 0.84476238489151, + 0.838936984539032, + 0.8458986878395081, + 0.8510852456092834, + 0.8555907011032104, + 0.858672022819519, + 0.8590560555458069, + 0.8654761910438538, + 0.8610984683036804, + 0.8630610108375549, + 0.8639771938323975, + 0.8652169108390808, + 0.8618441224098206, + 0.8632146716117859, + 0.861068069934845, + 0.8609952330589294, + 0.8652485609054565, + 0.862664520740509, + 0.86861652135849, + 0.8662468194961548, + 0.8682829737663269, + 0.8652265667915344, + 0.8663906455039978, + 0.8764907121658325, + 0.8683980703353882, + 0.8681570887565613, + 0.8505733609199524, + 0.8458050489425659, + 0.8658043742179871, + 0.8610346913337708, + 0.8521830439567566, + 0.8704118728637695, + 0.8671225905418396, + 0.866253674030304, + 0.8651615381240845, + 0.8661726713180542, + 0.8627281188964844, + 0.8668228387832642, + 0.8665737509727478, + 0.8614661693572998, + 0.8398611545562744, + 0.8382641077041626, + 0.8480514287948608, + 0.8521321415901184, + 0.8562690019607544, + 0.8573816418647766, + 0.8554061055183411, + 0.8768815398216248, + 0.8713105320930481, + 0.8744873404502869, + 0.8742911219596863, + 0.8740506768226624, + 0.8786142468452454, + 0.8788845539093018, + 0.7281445264816284, + 0.7792198657989502, + 0.7952057123184204, + 0.8151026964187622, + 0.8281133770942688, + 0.8318414092063904, + 0.8414968848228455, + 0.8494554758071899, + 
0.8523086905479431, + 0.8567559123039246, + 0.8594571948051453, + 0.8608883619308472, + 0.8712697625160217, + 0.8679503202438354, + 0.8685165047645569, + 0.8658415675163269, + 0.8650411367416382, + 0.8667349219322205, + 0.8720259070396423, + 0.8650246858596802, + 0.8663941025733948, + 0.8645951151847839, + 0.8638908863067627, + 0.8668110966682434, + 0.865586519241333, + 0.866898775100708, + 0.8665620684623718, + 0.8728160262107849, + 0.8631929755210876, + 0.8721675872802734, + 0.8678584694862366, + 0.8771296143531799, + 0.8559105396270752, + 0.878553569316864, + 0.8760054707527161, + 0.8771207928657532, + 0.8741464614868164, + 0.8744163513183594, + 0.8740748763084412, + 0.8720303177833557, + 0.8714028596878052, + 0.8722689747810364, + 0.872020959854126, + 0.8710580468177795, + 0.8712208271026611, + 0.8726319670677185, + 0.8732669353485107, + 0.872332751750946, + 0.8741260170936584, + 0.8732951879501343, + 0.8740338683128357, + 0.8736048340797424, + 0.8737559914588928, + 0.8744267821311951, + 0.8729435205459595, + 0.8749172687530518, + 0.8717195987701416, + 0.8865574598312378, + 0.89252108335495, + 0.8859567642211914, + 0.8748524785041809, + 0.8796765804290771, + 0.8795775175094604, + 0.8781686425209045, + 0.8796904683113098, + 0.8776493072509766, + 0.8786180019378662, + 0.8797085285186768, + 0.8780279159545898, + 0.8777487874031067, + 0.8793208003044128, + 0.8785813450813293, + 0.8751891851425171, + 0.8794978260993958, + 0.8776888251304626, + 0.7911058068275452, + 0.8311010003089905, + 0.8563839793205261, + 0.8567081689834595, + 0.8443841934204102, + 0.8506578803062439, + 0.8573545217514038, + 0.8574348092079163, + 0.8590723276138306, + 0.8612195253372192, + 0.86473548412323, + 0.8760099411010742, + 0.8660531640052795, + 0.8679808378219604, + 0.8673582673072815, + 0.8687211871147156, + 0.871544599533081, + 0.8671865463256836, + 0.8673295378684998, + 0.8607675433158875, + 0.8682910799980164, + 0.8698164820671082, + 0.8410481810569763, + 0.8346937298774719, + 
0.845123291015625, + 0.8533225655555725, + 0.8528366088867188, + 0.8489379286766052, + 0.8733007311820984, + 0.8751024007797241, + 0.8758694529533386, + 0.8759905695915222, + 0.8767417669296265, + 0.8735772371292114, + 0.8776817917823792, + 0.8754622340202332, + 0.8693927526473999, + 0.8716742396354675, + 0.8722932934761047, + 0.8772057890892029, + 0.8697801232337952, + 0.8707911968231201, + 0.8711106777191162, + 0.8733454942703247, + 0.872536838054657, + 0.872806191444397, + 0.8746081590652466, + 0.8756340742111206, + 0.8766515254974365, + 0.8718135356903076, + 0.8793970346450806, + 0.8753090500831604, + 0.8807417750358582, + 0.8770509958267212, + 0.8700686693191528, + 0.8759216666221619, + 0.879711925983429, + 0.8776720762252808, + 0.8750447034835815, + 0.8758551478385925, + 0.8751562833786011, + 0.8752334713935852, + 0.8557966351509094, + 0.8664177060127258, + 0.8659477829933167, + 0.8655065298080444, + 0.8695750832557678, + 0.8680756688117981, + 0.8689611554145813, + 0.8700358867645264, + 0.8705236315727234, + 0.8693567514419556, + 0.8715224862098694, + 0.8695716857910156, + 0.8662028908729553, + 0.8606469631195068, + 0.8523414134979248, + 0.8541622757911682, + 0.8591538667678833, + 0.8598201870918274, + 0.8630911707878113, + 0.8613525629043579, + 0.8603708148002625, + 0.8637220859527588, + 0.8644904494285583, + 0.8649963736534119, + 0.8652148842811584, + 0.868453860282898, + 0.8662175536155701, + 0.8789054751396179, + 0.8731442093849182, + 0.8728610277175903, + 0.8733997344970703, + 0.871367335319519, + 0.8739941716194153, + 0.873519241809845, + 0.8734442591667175, + 0.8747082948684692, + 0.8735887408256531, + 0.8713658452033997, + 0.8713515400886536, + 0.8709768652915955, + 0.8718175888061523, + 0.8706104159355164, + 0.8728729486465454, + 0.8735764622688293, + 0.8750870227813721, + 0.8737728595733643, + 0.8718305230140686, + 0.8687755465507507, + 0.8713995218276978, + 0.8761162161827087, + 0.8786718249320984, + 0.8769229650497437, + 0.8760003447532654, + 
import keras.backend as K
import numpy as np  # FIX: np was used throughout but never imported (NameError at runtime)
import tensorflow as tf
from keras.layers import *


def resize_images_bilinear(X, height_factor=1, width_factor=1, target_height=None, target_width=None, data_format='default'):
    '''Resizes the images contained in a 4D tensor of shape
    - [batch, channels, height, width] (for 'channels_first' data_format)
    - [batch, height, width, channels] (for 'channels_last' data_format)
    by a factor of (height_factor, width_factor). Both factors should be
    positive integers. If target_height and target_width are given they
    take precedence over the scale factors.
    '''
    if data_format == 'default':
        data_format = K.image_data_format()
    if data_format == 'channels_first':
        original_shape = K.int_shape(X)
        if target_height and target_width:
            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
        else:
            new_shape = tf.shape(X)[2:]
            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        # The TF resize op expects channels_last layout, so permute around it.
        X = K.permute_dimensions(X, [0, 2, 3, 1])
        # NOTE(review): tf.image.resize_bilinear is TF1-only; under TF2 this
        # needs tf.compat.v1.image.resize_bilinear or
        # tf.image.resize(..., method='bilinear') — confirm the TF version used.
        X = tf.image.resize_bilinear(X, new_shape)
        X = K.permute_dimensions(X, [0, 3, 1, 2])
        if target_height and target_width:
            X.set_shape((None, None, target_height, target_width))
        else:
            X.set_shape((None, None, original_shape[2] * height_factor, original_shape[3] * width_factor))
        return X
    elif data_format == 'channels_last':
        original_shape = K.int_shape(X)
        if target_height and target_width:
            new_shape = tf.constant(np.array((target_height, target_width)).astype('int32'))
        else:
            new_shape = tf.shape(X)[1:3]
            new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('int32'))
        X = tf.image.resize_bilinear(X, new_shape)
        if target_height and target_width:
            X.set_shape((None, target_height, target_width, None))
        else:
            X.set_shape((None, original_shape[1] * height_factor, original_shape[2] * width_factor, None))
        return X
    else:
        raise Exception('Invalid data_format: ' + data_format)


class BilinearUpSampling2D(Layer):
    """Keras layer wrapping resize_images_bilinear.

    Upsamples a 4D tensor by the integer factors in `size`, or resizes
    directly to `target_size` (height, width) when that is given.
    """

    def __init__(self, size=(1, 1), target_size=None, data_format='default', **kwargs):
        if data_format == 'default':
            data_format = K.image_data_format()
        self.size = tuple(size)
        if target_size is not None:
            self.target_size = tuple(target_size)
        else:
            self.target_size = None
        assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {tf, th}'
        self.data_format = data_format
        self.input_spec = [InputSpec(ndim=4)]
        super(BilinearUpSampling2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            # FIX: the original wrapped the whole conditional expression in
            # int(), so an unknown (None) spatial dimension raised
            # TypeError: int(None). Only the known-dimension branch is cast now.
            width = int(self.size[0] * input_shape[2]) if input_shape[2] is not None else None
            height = int(self.size[1] * input_shape[3]) if input_shape[3] is not None else None
            if self.target_size is not None:
                width = self.target_size[0]
                height = self.target_size[1]
            return (input_shape[0],
                    input_shape[1],
                    width,
                    height)
        elif self.data_format == 'channels_last':
            width = int(self.size[0] * input_shape[1]) if input_shape[1] is not None else None
            height = int(self.size[1] * input_shape[2]) if input_shape[2] is not None else None
            if self.target_size is not None:
                width = self.target_size[0]
                height = self.target_size[1]
            return (input_shape[0],
                    width,
                    height,
                    input_shape[3])
        else:
            raise Exception('Invalid data_format: ' + self.data_format)

    def call(self, x, mask=None):
        # target_size (absolute) wins over size (relative factors).
        if self.target_size is not None:
            return resize_images_bilinear(x, target_height=self.target_size[0], target_width=self.target_size[1], data_format=self.data_format)
        else:
            return resize_images_bilinear(x, height_factor=self.size[0], width_factor=self.size[1], data_format=self.data_format)

    def get_config(self):
        # Serialize the custom constructor args so the layer can be reloaded
        # from a saved model.
        config = {'size': self.size, 'target_size': self.target_size}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
import Augmentor

# Offline augmentation script: expands the wound-segmentation training set.
# Augmentor applies every transform to the image and its ground-truth mask
# together, so segmentation labels stay pixel-aligned with the images.
p = Augmentor.Pipeline("./data/train/images")
# Pair each training image with its label mask from the labels directory.
p.ground_truth("./data/train/labels")
# Geometric transforms, each applied with its own probability per sample.
p.rotate(probability=0.7, max_left_rotation=25, max_right_rotation=25)
p.flip_left_right(probability=0.5)
p.zoom_random(probability=0.5, percentage_area=0.8)
p.flip_top_bottom(probability=0.5)
# 'auto' keeps each source image's own file format for the generated output.
p.set_save_format(save_format='auto')
# Generate 10,000 augmented image/label pairs.
# NOTE(review): multi_threaded=False — presumably to keep image/ground-truth
# pairs in sync during sampling; confirm before enabling threading.
p.sample(10000, multi_threaded=False)
# ------------------------------------------------------------ #
#
# file : utils/config/memory.py
# author : ZFTurbo
# Calculate the memory needed to run the model
#
# ------------------------------------------------------------ #

def get_model_memory_usage(batch_size, model):
    """Estimate the memory, in GiB, needed to run `model` at `batch_size`.

    Sums per-layer activation element counts plus trainable/non-trainable
    parameter counts, scaled by the byte width of the backend float type.
    """
    import numpy as np
    from keras import backend as K

    shapes_mem_count = 0
    for l in model.layers:
        single_layer_mem = 1
        for s in l.output_shape:
            if s is None:  # unknown (batch) dimension is accounted for via batch_size
                continue
            single_layer_mem *= s
        shapes_mem_count += single_layer_mem

    trainable_count = np.sum([K.count_params(p) for p in set(model.trainable_weights)])
    non_trainable_count = np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])

    # Bytes per number, from the backend float type (default float32 -> 4 bytes).
    number_size = 4.0
    if K.floatx() == 'float16':
        number_size = 2.0
    if K.floatx() == 'float64':
        number_size = 8.0

    total_memory = number_size * (batch_size * shapes_mem_count + trainable_count + non_trainable_count)
    gbytes = np.round(total_memory / (1024.0 ** 3), 3)
    return gbytes


# ------------------------------------------------------------ #
#
# file : utils/config/read.py
# author : CM
# Read the configuration
#
# ------------------------------------------------------------ #

import configparser


def readConfig(filename):
    """Read a training configuration file and return it as a flat dict.

    Expects a [dataset] section (in_path, gd_path, train, valid, test) and a
    [train] section (patch sizes, batch_size, steps_per_epoch, epochs,
    logs_path). Integer options are parsed with getint(), which raises
    ValueError on malformed values exactly like int(config.get(...)) did.
    """
    config = configparser.RawConfigParser()
    # FIX: the original passed open(filename) straight into read_file(),
    # leaking the file handle; the with-block closes it deterministically.
    with open(filename) as f:
        config.read_file(f)

    dataset_in_path = config.get("dataset", "in_path")
    dataset_gd_path = config.get("dataset", "gd_path")

    dataset_train = config.getint("dataset", "train")
    dataset_valid = config.getint("dataset", "valid")
    dataset_test = config.getint("dataset", "test")

    train_patch_size_x = config.getint("train", "patch_size_x")
    train_patch_size_y = config.getint("train", "patch_size_y")
    train_patch_size_z = config.getint("train", "patch_size_z")

    train_batch_size = config.getint("train", "batch_size")
    train_steps_per_epoch = config.getint("train", "steps_per_epoch")
    train_epochs = config.getint("train", "epochs")

    logs_path = config.get("train", "logs_path")

    return {"dataset_in_path": dataset_in_path,
            "dataset_gd_path": dataset_gd_path,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "dataset_test": dataset_test,
            "train_patch_size_x": train_patch_size_x,
            "train_patch_size_y": train_patch_size_y,
            "train_patch_size_z": train_patch_size_z,
            "train_batch_size": train_batch_size,
            "train_steps_per_epoch": train_steps_per_epoch,
            "train_epochs": train_epochs,
            "logs_path": logs_path
            }


# Old version will be deleted soon
def readConfig_OLD(filename):
    """Legacy configuration reader (separate per-split sections); kept only
    for backward compatibility with old configuration files."""
    config = configparser.RawConfigParser()
    # FIX: close the file handle (the original used a bare open() inline).
    with open(filename) as f:
        config.read_file(f)

    dataset_train_size = config.getint("dataset", "train_size")
    dataset_train_gd_path = config.get("dataset", "train_gd_path")
    dataset_train_mra_path = config.get("dataset", "train_mra_path")

    dataset_valid_size = config.getint("dataset", "valid_size")
    dataset_valid_gd_path = config.get("dataset", "valid_gd_path")
    dataset_valid_mra_path = config.get("dataset", "valid_mra_path")

    dataset_test_size = config.getint("dataset", "test_size")
    dataset_test_gd_path = config.get("dataset", "test_gd_path")
    dataset_test_mra_path = config.get("dataset", "test_mra_path")

    image_size_x = config.getint("data", "image_size_x")
    image_size_y = config.getint("data", "image_size_y")
    image_size_z = config.getint("data", "image_size_z")

    patch_size_x = config.getint("patchs", "patch_size_x")
    patch_size_y = config.getint("patchs", "patch_size_y")
    patch_size_z = config.getint("patchs", "patch_size_z")

    batch_size = config.getint("train", "batch_size")
    steps_per_epoch = config.getint("train", "steps_per_epoch")
    epochs = config.getint("train", "epochs")

    logs_folder = config.get("logs", "folder")

    return {"dataset_train_size": dataset_train_size,
            "dataset_train_gd_path": dataset_train_gd_path,
            "dataset_train_mra_path": dataset_train_mra_path,
            "dataset_valid_size": dataset_valid_size,
            "dataset_valid_gd_path": dataset_valid_gd_path,
            "dataset_valid_mra_path": dataset_valid_mra_path,
            "dataset_test_size": dataset_test_size,
            "dataset_test_gd_path": dataset_test_gd_path,
            "dataset_test_mra_path": dataset_test_mra_path,
            "image_size_x": image_size_x,
            "image_size_y": image_size_y,
            "image_size_z": image_size_z,
            "patch_size_x": patch_size_x,
            "patch_size_y": patch_size_y,
            "patch_size_z": patch_size_z,
            "batch_size": batch_size,
            "steps_per_epoch": steps_per_epoch,
            "epochs": epochs,
            "logs_folder": logs_folder
            }
b/utils/io/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a30f6120b4a96d1fdf7d2affeedbcc6385031f3 Binary files /dev/null and b/utils/io/__pycache__/__init__.cpython-37.pyc differ diff --git a/utils/io/__pycache__/__init__.cpython-39.pyc b/utils/io/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f00aefe73934dfde844e2616b61a3720014f41ea Binary files /dev/null and b/utils/io/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/io/__pycache__/data.cpython-310.pyc b/utils/io/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1741d3fd554a4ec3084e9f1e4a46fce5863dc3b1 Binary files /dev/null and b/utils/io/__pycache__/data.cpython-310.pyc differ diff --git a/utils/io/__pycache__/data.cpython-313.pyc b/utils/io/__pycache__/data.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48748fb92c1c6039ce9522f8d761d5e3c031caca Binary files /dev/null and b/utils/io/__pycache__/data.cpython-313.pyc differ diff --git a/utils/io/__pycache__/data.cpython-37.pyc b/utils/io/__pycache__/data.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acc8ed5077d79a1646c9937a52e8ce60f1fe302b Binary files /dev/null and b/utils/io/__pycache__/data.cpython-37.pyc differ diff --git a/utils/io/__pycache__/data.cpython-39.pyc b/utils/io/__pycache__/data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d6ec35cca0c9521ae45c74ad4683a0dcdd857e6 Binary files /dev/null and b/utils/io/__pycache__/data.cpython-39.pyc differ diff --git a/utils/io/data.py b/utils/io/data.py new file mode 100644 index 0000000000000000000000000000000000000000..7845d12a8ceb0826dc01ac8c34b6460cc803cb3e --- /dev/null +++ b/utils/io/data.py @@ -0,0 +1,238 @@ +import os +import cv2 +import json +import random +import datetime +import numpy as np +import 
class DataGen:
    """Streams normalized (image, label) batches for segmentation training.

    Expects the directory layout under `path`:
        train/images, train/labels, test/images, test/labels
    Training files are shuffled (images and labels kept in lockstep) and
    split into train/validation lists by `split_ratio`. `x` and `y` are the
    expected frame height/width; frames of any other size are skipped with a
    console warning.
    """

    def __init__(self, path, split_ratio, x, y, color_space='rgb'):
        self.x = x
        self.y = y
        self.path = path
        self.color_space = color_space
        self.path_train_images = path + "train/images/"
        self.path_train_labels = path + "train/labels/"
        self.path_test_images = path + "test/images/"
        self.path_test_labels = path + "test/labels/"
        self.image_file_list = get_png_filename_list(self.path_train_images)
        self.label_file_list = get_png_filename_list(self.path_train_labels)
        # Shuffle images and labels with the same permutation so pairs stay aligned.
        self.image_file_list[:], self.label_file_list[:] = self.shuffle_image_label_lists_together()
        self.split_index = int(split_ratio * len(self.image_file_list))
        self.x_train_file_list = self.image_file_list[self.split_index:]
        self.y_train_file_list = self.label_file_list[self.split_index:]
        self.x_val_file_list = self.image_file_list[:self.split_index]
        self.y_val_file_list = self.label_file_list[:self.split_index]
        self.x_test_file_list = get_png_filename_list(self.path_test_images)
        self.y_test_file_list = get_png_filename_list(self.path_test_labels)

    def generate_data(self, batch_size, train=False, val=False, test=False):
        """Replaces Keras' native ImageDataGenerator: yields normalized
        float32 (image_batch, label_batch) pairs forever.

        Exactly one of train/val/test must be True; raises ValueError otherwise.
        """
        if train:
            image_file_list = self.x_train_file_list
            label_file_list = self.y_train_file_list
        elif val:
            image_file_list = self.x_val_file_list
            label_file_list = self.y_val_file_list
        elif test:
            image_file_list = self.x_test_file_list
            label_file_list = self.y_test_file_list
        else:
            # FIX: the original wrapped the if/elif chain in try/except
            # ValueError, which could never fire — an unset mode fell through
            # to a NameError below. Fail fast and explicitly instead.
            raise ValueError('one of train or val or test need to be True')

        i = 0
        while True:
            image_batch = []
            label_batch = []
            for _ in range(batch_size):
                # FIX: the original reset the index against
                # len(self.x_train_file_list) even in val/test mode; when the
                # active list was shorter, `i` got stuck past its end, nothing
                # was ever appended and the generator looped forever. Wrap on
                # the active list instead.
                if i == len(image_file_list):
                    i = 0
                if i < len(image_file_list):
                    sample_image_filename = image_file_list[i]
                    sample_label_filename = label_file_list[i]
                    # Images are read as 3-channel BGR, labels as grayscale.
                    if train or val:
                        image = cv2.imread(self.path_train_images + sample_image_filename, 1)
                        label = cv2.imread(self.path_train_labels + sample_label_filename, 0)
                    elif test is True:
                        image = cv2.imread(self.path_test_images + sample_image_filename, 1)
                        label = cv2.imread(self.path_test_labels + sample_label_filename, 0)
                    # Give the single-channel label an explicit channel axis.
                    label = np.expand_dims(label, axis=2)
                    if image.shape[0] == self.x and image.shape[1] == self.y:
                        image_batch.append(image.astype("float32"))
                    else:
                        print('the input image shape is not {}x{}'.format(self.x, self.y))
                    if label.shape[0] == self.x and label.shape[1] == self.y:
                        label_batch.append(label.astype("float32"))
                    else:
                        print('the input label shape is not {}x{}'.format(self.x, self.y))
                    i += 1
            if image_batch and label_batch:
                image_batch = normalize(np.array(image_batch))
                label_batch = normalize(np.array(label_batch))
                yield (image_batch, label_batch)

    def get_num_data_points(self, train=False, val=False):
        # Size of the training split when train=True (and val is not set),
        # otherwise the validation split. (The original's try/except
        # ValueError around this expression was unreachable and is removed.)
        file_list = self.x_train_file_list if train and not val else self.x_val_file_list
        return len(file_list)

    def shuffle_image_label_lists_together(self):
        """Shuffle the image and label filename lists with one permutation."""
        combined = list(zip(self.image_file_list, self.label_file_list))
        random.shuffle(combined)
        return zip(*combined)

    @staticmethod
    def change_color_space(image, label, color_space):
        # Optional BGR -> HSV/LAB conversion; currently not invoked by
        # generate_data (the call site is commented out upstream).
        color_space = color_space.lower()
        if color_space == 'hsi' or color_space == 'hsv':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
            label = cv2.cvtColor(label, cv2.COLOR_BGR2HSV)
        elif color_space == 'lab':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
            label = cv2.cvtColor(label, cv2.COLOR_BGR2LAB)
        return image, label


def normalize(arr):
    """Scale `arr` by its value range (max - min).

    A constant array (range 0) is divided by 255 to avoid division by zero.
    Note the array is only scaled, not shifted to start at 0.
    """
    diff = np.amax(arr) - np.amin(arr)
    diff = 255 if diff == 0 else diff
    arr = arr / np.absolute(diff)
    return arr


def get_png_filename_list(path):
    """Return the sorted .png filenames in the top level of `path`.

    FIX: the original re-walked the directory 500 times (once per candidate
    filename length, 0..499) and then sorted — the result is identical to a
    single walk keeping names shorter than 500 characters.
    """
    file_list = []
    for dir_name, subdir_list, filenames in os.walk(path):
        file_list = [f for f in filenames if ".png" in f.lower() and len(f) < 500]
        break  # top-level directory only, as before
    file_list.sort()
    return file_list


def get_jpg_filename_list(path):
    """Return the sorted .jpg filenames in the top level of `path`.

    Same single-walk simplification as get_png_filename_list.
    """
    file_list = []
    for dir_name, subdir_list, filenames in os.walk(path):
        file_list = [f for f in filenames if ".jpg" in f.lower() and len(f) < 500]
        break  # top-level directory only
    file_list.sort()
    return file_list


def load_jpg_images(path):
    """Load every .jpg in `path` as float32 BGR. Returns (images, filenames)."""
    file_list = get_jpg_filename_list(path)
    temp_list = []
    for filename in file_list:
        img = cv2.imread(path + filename, 1)
        # NOTE(review): cv2.imread returns None for unreadable files, which
        # would raise here — assumes every listed file is a valid image.
        temp_list.append(img.astype("float32"))

    temp_list = np.array(temp_list)
    return temp_list, file_list


def load_png_images(path):
    """Load every .png in `path` as float32 BGR. Returns (images, filenames)."""
    temp_list = []
    file_list = get_png_filename_list(path)
    for filename in file_list:
        img = cv2.imread(path + filename, 1)
        temp_list.append(img.astype("float32"))

    temp_list = np.array(temp_list)
    return temp_list, file_list


def load_data(path):
    """Load and range-normalize the train/test image and label sets.

    Returns (x_train, y_train, x_test, y_test, test_label_filenames_list).
    """
    path_train_images = path + "train/images/"
    path_train_labels = path + "train/labels/"
    path_test_images = path + "test/images/"
    path_test_labels = path + "test/labels/"
    x_train, train_image_filenames_list = load_png_images(path_train_images)
    y_train, train_label_filenames_list = load_png_images(path_train_labels)
    x_test, test_image_filenames_list = load_png_images(path_test_images)
    y_test, test_label_filenames_list = load_png_images(path_test_labels)
    x_train = normalize(x_train)
    y_train = normalize(y_train)
    x_test = normalize(x_test)
    y_test = normalize(y_test)
    return x_train, y_train, x_test, y_test, test_label_filenames_list


def load_test_images(path):
    """Load and range-normalize only the test images under `path`."""
    path_test_images = path + "test/images/"
    x_test, test_image_filenames_list = load_png_images(path_test_images)
    x_test = normalize(x_test)
    return x_test, test_image_filenames_list


def save_results(np_array, color_space, outpath, test_label_filenames_list):
    """Write each prediction in `np_array` (values in [0, 1]) to `outpath`
    under its matching test filename, rescaled to 0-255.

    NOTE(review): `color_space` is accepted but unused (the original carried
    dead, commented-out conversion code); kept for interface compatibility.
    """
    for i, filename in enumerate(test_label_filenames_list):
        cv2.imwrite(outpath + filename, np_array[i] * 255.)


def save_rgb_results(np_array, outpath, test_label_filenames_list):
    """Write RGB predictions (values in [0, 1]) to `outpath`, rescaled to 0-255."""
    for i, filename in enumerate(test_label_filenames_list):
        cv2.imwrite(outpath + filename, np_array[i] * 255.)
+ i += 1 + + +def save_history(model, model_name, training_history, dataset, n_filters, epoch, learning_rate, loss, + color_space, path=None, temp_name=None): + save_weight_filename = temp_name if temp_name else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + model.save('{}{}.hdf5'.format(path, save_weight_filename)) + with open('{}{}.json'.format(path, save_weight_filename), 'w') as f: + json.dump(training_history.history, f, indent=2) + + json_list = ['{}{}.json'.format(path, save_weight_filename)] + for json_filename in json_list: + with open(json_filename) as f: + # convert the loss json object to a python dict + loss_dict = json.load(f) + print_list = ['loss', 'val_loss', 'dice_coef', 'val_dice_coef'] + for item in print_list: + item_list = [] + if item in loss_dict: + item_list.extend(loss_dict.get(item)) + plt.plot(item_list) + plt.title('model:{} lr:{} epoch:{} #filtr:{} Colorspaces:{}'.format(model_name, learning_rate, + epoch, n_filters, color_space)) + plt.ylabel('loss') + plt.xlabel('epoch') + plt.legend(['train_loss', 'test_loss', 'train_dice', 'test_dice'], loc='upper left') + plt.savefig('{}{}.png'.format(path, save_weight_filename)) + plt.show() + plt.clf() + + + diff --git a/utils/io/read.py b/utils/io/read.py new file mode 100644 index 0000000000000000000000000000000000000000..86f41724e6e8ed5c7e44dcdaeb5d37d081cd5fe8 --- /dev/null +++ b/utils/io/read.py @@ -0,0 +1,173 @@ +# ------------------------------------------------------------ # +# +# file : utils/io/read.py +# author : CM +# Function to read dataset +# +# ------------------------------------------------------------ # + +import os +import sys + +import nibabel as nib +import numpy as np + +# read nii file and load it into a numpy 3d array +def niiToNp(filename): + data = nib.load(filename).get_data().astype('float16') + return data/data.max() + +# read a dataset and load it into a numpy 4d array +def readDataset(folder, size, size_x, size_y, size_z): + dataset = np.empty((size, 
size_x, size_y, size_z), dtype='float16') + i = 0 + files = os.listdir(folder) + files.sort() + for filename in files: + if(i>=size): + break + print(filename) + dataset[i, :, :, :] = niiToNp(os.path.join(folder, filename)) + i = i+1 + + return dataset + +# return dataset affine +def getAffine_subdir(folder): + subdir = os.listdir(folder) + subdir.sort() + files = os.listdir(folder+subdir[0]) + path = folder + subdir[0] + image = nib.load(os.path.join(path, files[0])) + return image.affine + +def getAffine(folder): + files = os.listdir(folder) + files.sort() + image = nib.load(os.path.join(folder, files[0])) + return image.affine + + + +# reshape the dataset to match keras input shape (add channel dimension) +def reshapeDataset(d): + return d.reshape(d.shape[0], d.shape[1], d.shape[2], d.shape[3], 1) + +# read a dataset and load it into a numpy 3d array as raw data (no normalisation) +def readRawDataset(folder, size, size_x, size_y, size_z, dtype): + files = os.listdir(folder) + files.sort() + + if(len(files) < size): + sys.exit(2) + + count = 0 + # astype depend on your dataset type. 
+ dataset = np.empty((size, size_x, size_y, size_z)).astype(dtype) + + for filename in files: + if(count>=size): + break + dataset[count, :, :, :] = nib.load(os.path.join(folder, filename)).get_data() + count += 1 + print(count, '/', size, os.path.join(folder, filename)) + + return dataset + +def readTrainValid(config): + print("Loading training dataset") + + train_gd_dataset = readRawDataset(config["dataset_train_gd_path"], + config["dataset_train_size"], + config["image_size_x"], + config["image_size_y"], + config["image_size_z"], + 'uint16') + + print("Training ground truth dataset shape", train_gd_dataset.shape) + print("Training ground truth dataset dtype", train_gd_dataset.dtype) + + train_in_dataset = readRawDataset(config["dataset_train_mra_path"], + config["dataset_train_size"], + config["image_size_x"], + config["image_size_y"], + config["image_size_z"], + 'uint16') + + print("Training input image dataset shape", train_in_dataset.shape) + print("Training input image dataset dtype", train_in_dataset.dtype) + + print("Loading validation dataset") + + valid_gd_dataset = readRawDataset(config["dataset_valid_gd_path"], + config["dataset_valid_size"], + config["image_size_x"], + config["image_size_y"], + config["image_size_z"], + 'uint16') + + print("Validation ground truth dataset shape", valid_gd_dataset.shape) + print("Validation ground truth dataset dtype", valid_gd_dataset.dtype) + + valid_in_dataset = readRawDataset(config["dataset_valid_mra_path"], + config["dataset_valid_size"], + config["image_size_x"], + config["image_size_y"], + config["image_size_z"], + 'uint16') + + print("Validation input image dataset shape", valid_in_dataset.shape) + print("Validation input image dataset dtype", valid_in_dataset.dtype) + return train_gd_dataset, train_in_dataset, valid_gd_dataset, valid_in_dataset + +# read a dataset and load it into a numpy 3d without any preprocessing +def getDataset(folder, size, type=None): + files = os.listdir(folder) + files.sort() + + 
if(len(files) < size): + sys.exit(0x2001) + + image = nib.load(os.path.join(folder, files[0])) + + if type==None: + dtype = image.get_data_dtype() + else: + dtype = type + + dataset = np.empty((size, image.shape[0], image.shape[1], image.shape[2])).astype(dtype) + del image + + count = 0 + for filename in files: + dataset[count, :, :, :] = nib.load(os.path.join(folder, filename)).get_data() + count += 1 + if(count>=size): + break + + return dataset + +# read a dataset and load it into a numpy 3d without any preprocessing with "start" index and number of files +def readDatasetPart(folder, start, size, type=None): + files = os.listdir(folder) + files.sort() + + if(len(files) < start + size): + sys.exit("readDatasetPart : len(files) < start + size") + + image = nib.load(os.path.join(folder, files[0])) + + if type==None: + dtype = image.get_data_dtype() + else: + dtype = type + + dataset = np.empty(((size), image.shape[0], image.shape[1], image.shape[2])).astype(dtype) + del image + + count = 0 + for i in range(start, start + size): + dataset[count, :, :, :] = nib.load(os.path.join(folder, files[i])).get_data() + count += 1 + + return dataset diff --git a/utils/io/write.py b/utils/io/write.py new file mode 100644 index 0000000000000000000000000000000000000000..88cf8386a62618aba4c6b2dfdfb5fe7bb5f852ec --- /dev/null +++ b/utils/io/write.py @@ -0,0 +1,23 @@ +# ------------------------------------------------------------ # +# +# file : utils/io/write.py +# author : CM +# Function to write results +# +# ------------------------------------------------------------ # + +import nibabel as nib +import numpy as np + +# write nii file from a numpy 3d array +def npToNii(data, filename): + axes = np.eye(4) + axes[0][0] = -1 + axes[1][1] = -1 + image = nib.Nifti1Image(data, axes) + nib.save(image, filename) + +# write nii file from a numpy 3d array with affine configuration +def npToNiiAffine(data, affine, filename): + image = nib.Nifti1Image(data, affine) + nib.save(image, 
filename) \ No newline at end of file diff --git a/utils/learning/__pycache__/losses.cpython-37.pyc b/utils/learning/__pycache__/losses.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..383ef9c84b97d53194604da9da2701b7ed5bb941 Binary files /dev/null and b/utils/learning/__pycache__/losses.cpython-37.pyc differ diff --git a/utils/learning/__pycache__/losses.cpython-39.pyc b/utils/learning/__pycache__/losses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39ce00e24787467c070845a5c5d27a8188338831 Binary files /dev/null and b/utils/learning/__pycache__/losses.cpython-39.pyc differ diff --git a/utils/learning/__pycache__/metrics.cpython-310.pyc b/utils/learning/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47e8e183fe484542b41ba863d6a4a91e6ea1361c Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-310.pyc differ diff --git a/utils/learning/__pycache__/metrics.cpython-313.pyc b/utils/learning/__pycache__/metrics.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbb302cc28e50af6d1ec38630dd3608290f0b30a Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-313.pyc differ diff --git a/utils/learning/__pycache__/metrics.cpython-37.pyc b/utils/learning/__pycache__/metrics.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5050e934bcd9313efbaa2e1f736307a43650e0f6 Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-37.pyc differ diff --git a/utils/learning/__pycache__/metrics.cpython-39.pyc b/utils/learning/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a71265dbb00586a0cbdefa64882c085c0b81ae3 Binary files /dev/null and b/utils/learning/__pycache__/metrics.cpython-39.pyc differ diff --git a/utils/learning/callbacks.py b/utils/learning/callbacks.py new file mode 100644 
# ------------------------------------------------------------ #
# file : utils/learning/callbacks.py  (author: CM) -- custom callbacks
# ------------------------------------------------------------ #
import numpy as np

try:
    from keras.callbacks import LearningRateScheduler
except ImportError:  # NOTE(review): keras kept optional at import time
    LearningRateScheduler = None


def learningRateSchedule(initialLr=1e-4, decayFactor=0.99, stepSize=1):
    """Build a LearningRateScheduler decaying the LR every ``stepSize`` epochs.

    lr(epoch) = initialLr * decayFactor ** floor(epoch / stepSize)
    """
    def schedule(epoch):
        lr = initialLr * (decayFactor ** np.floor(epoch / stepSize))
        print("Learning rate : ", lr)
        return lr
    return LearningRateScheduler(schedule)


# ------------------------------------------------------------ #
# file : utils/learning/losses.py  (author: CM) -- loss functions
# ------------------------------------------------------------ #
try:
    import keras.backend as K
except ImportError:  # NOTE(review): keras kept optional at import time
    K = None


def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient over the flattened masks."""
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Negative Dice: minimising this maximises the Dice coefficient."""
    return -dice_coef(y_true, y_pred)


def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """Smoothed Jaccard distance, useful for unbalanced datasets.

    Jaccard = |X & Y| / (|X| + |Y| - |X & Y|); shifted so it converges on 0
    and smoothed to avoid exploding/vanishing gradients.
    Ref: https://en.wikipedia.org/wiki/Jaccard_index
    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth


def dice_coef_(y_true, y_pred, smooth=1):
    """Squared-denominator Dice.

    Dice = 2*|X & Y| / (|X| + |Y|) with sum-of-squares denominators;
    ref: https://arxiv.org/pdf/1606.04797v1.pdf
    """
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    return (2. * intersection + smooth) / (K.sum(K.square(y_true), -1) +
                                           K.sum(K.square(y_pred), -1) + smooth)


def dice_coef_loss_(y_true, y_pred):
    """Loss form of the squared-denominator Dice."""
    return 1 - dice_coef_(y_true, y_pred)


def dice_loss(y_true, y_pred):
    """Soft Dice loss (1 - Dice), the 'deeplab' variant."""
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = y_true_f * y_pred_f
    score = (2. * K.sum(intersection) + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return 1. - score


def tversky_loss(y_true, y_pred, alpha=0.3, beta=0.7, smooth=1e-10):
    """Negative Tversky index.

    ``alpha`` weighs false positives, ``beta`` false negatives; ``smooth``
    avoids division by zero.
    """
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    truepos = K.sum(y_true * y_pred)
    fp_and_fn = alpha * K.sum(y_pred * (1 - y_true)) + beta * K.sum((1 - y_pred) * y_true)
    return -(truepos + smooth) / ((truepos + smooth) + fp_and_fn)


def jaccard_coef_logloss(y_true, y_pred, smooth=1e-10):
    """Negative logarithm of the (smoothed) Jaccard coefficient."""
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    truepos = K.sum(y_true * y_pred)
    falsepos = K.sum(y_pred) - truepos
    falseneg = K.sum(y_true) - truepos
    jaccard = (truepos + smooth) / (smooth + truepos + falseneg + falsepos)
    return -K.log(jaccard + smooth)


# ------------------------------------------------------------ #
# file : utils/learning/metrics.py  (author: CM) -- evaluation metrics
# (K is imported above; originally `from keras import backend as K`)
# ------------------------------------------------------------ #
def dice_coef(y_true, y_pred):  # noqa: F811 -- separate file originally
    """Hard Dice: predictions thresholded at 0.5 before the overlap
    ('deeplab' evaluation variant).

    NOTE(review): shadows the soft `dice_coef` above only because the two
    modules are concatenated in this hunk; they live in different files.
    """
    smooth = 0.00001
    y_true_f = K.flatten(y_true)
    y_pred = K.cast(y_pred, 'float32')
    y_pred_f = K.cast(K.greater(K.flatten(y_pred), 0.5), 'float32')
    intersection = y_true_f * y_pred_f
    return (2. * K.sum(intersection) + smooth) / ((K.sum(y_true_f) + K.sum(y_pred_f)) + smooth)


def recall(truth, prediction):
    """Recall (true positive rate), batch-wise."""
    TP = K.sum(K.round(K.clip(truth * prediction, 0, 1)))
    P = K.sum(K.round(K.clip(truth, 0, 1)))
    return TP / (P + K.epsilon())


def specificity(truth, prediction):
    """Specificity (true negative rate), batch-wise."""
    TN = K.sum(K.round(K.clip((1 - truth) * (1 - prediction), 0, 1)))
    N = K.sum(K.round(K.clip(1 - truth, 0, 1)))
    return TN / (N + K.epsilon())


def precision(truth, prediction):
    """Precision (positive predictive value), batch-wise."""
    TP = K.sum(K.round(K.clip(truth * prediction, 0, 1)))
    FP = K.sum(K.round(K.clip((1 - truth) * prediction, 0, 1)))
    return TP / (TP + FP + K.epsilon())


def f1(y_true, y_pred):
    """Batch-wise F1 score.

    The original nested private copies of precision/recall; they computed
    exactly the module-level formulas above, so those are reused.
    """
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    return 2 * ((p * r) / (p + r + K.epsilon()))


# ------------------------------------------------------------ #
# file : utils/learning/patch/extraction.py  (author: CM)
# Patch extraction from input volumes
# ------------------------------------------------------------ #
import sys
from random import randint
from scipy.ndimage import zoom, rotate  # zoom kept for the "todo : scale" plan

try:
    from keras.utils import to_categorical
except ImportError:  # NOTE(review): only generatorRandomPatchsDolz needs it
    to_categorical = None


# ----- Patch Extraction -----
def extractPatch(d, patch_size_x, patch_size_y, patch_size_z, x, y, z):
    """Slice a (patch_size_x, patch_size_y, patch_size_z) patch of ``d``
    anchored at (x, y, z)."""
    return d[x:x + patch_size_x, y:y + patch_size_y, z:z + patch_size_z]


def extractPatchOut(d, patch_size_x, patch_size_y, patch_size_z, x_, y_, z_):
    """Extract a patch that may extend outside ``d`` (zero padded).

    Vectorised replacement for the original per-voxel triple Python loop:
    the overlapping region is copied with one slice assignment.
    """
    patch = np.zeros((patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    # clip the patch window to the volume bounds
    sx, sy, sz = max(x_, 0), max(y_, 0), max(z_, 0)
    ex = min(x_ + patch_size_x, d.shape[0])
    ey = min(y_ + patch_size_y, d.shape[1])
    ez = min(z_ + patch_size_z, d.shape[2])
    if sx < ex and sy < ey and sz < ez:
        patch[sx - x_:ex - x_, sy - y_:ey - y_, sz - z_:ez - z_] = d[sx:ex, sy:ey, sz:ez]
    return patch


def generateRandomPatch(d, patch_size_x, patch_size_y, patch_size_z):
    """Extract one patch at a uniformly random in-bounds position."""
    x = randint(0, d.shape[0] - patch_size_x)
    y = randint(0, d.shape[1] - patch_size_y)
    z = randint(0, d.shape[2] - patch_size_z)
    return extractPatch(d, patch_size_x, patch_size_y, patch_size_z, x, y, z)


def generateRandomPatchs(d, patch_size_x, patch_size_y, patch_size_z, patch_number):
    """Stack ``patch_number`` random patches into one float16 array."""
    data = np.empty((patch_number, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    for i in range(patch_number):
        data[i] = generateRandomPatch(d, patch_size_x, patch_size_y, patch_size_z)
    return data


def generateFullPatchs(d, patch_size_x, patch_size_y, patch_size_z):
    """Tile the whole volume into non-overlapping patches.

    NOTE(review): trailing voxels are lost when a dimension is not a
    multiple of the patch size (original TODO).
    """
    patch_nb = int((d.shape[0] / patch_size_x) * (d.shape[1] / patch_size_y) *
                   (d.shape[2] / patch_size_z))
    data = np.empty((patch_nb, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    i = 0
    for x in range(0, d.shape[0], patch_size_x):
        for y in range(0, d.shape[1], patch_size_y):
            for z in range(0, d.shape[2], patch_size_z):
                data[i] = extractPatch(d, patch_size_x, patch_size_y, patch_size_z, x, y, z)
                i += 1
    return data


def generateFullPatchsPlus(d, patch_size_x, patch_size_y, patch_size_z, dx, dy, dz):
    """Tile the volume with stride (dx, dy, dz), yielding overlapping patches."""
    patch_nb = int((d.shape[0] / dx) * (d.shape[1] / dy) * (d.shape[2] / dz))
    data = np.empty((patch_nb, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    i = 0
    for x in range(0, d.shape[0] - dx, dx):
        for y in range(0, d.shape[1] - dy, dy):
            for z in range(0, d.shape[2] - dz, dz):
                data[i] = extractPatch(d, patch_size_x, patch_size_y, patch_size_z, x, y, z)
                i += 1
    return data


def noNeg(x):
    """Clamp ``x`` to be non-negative."""
    return x if x > 0 else 0


def generateFullPatchsCentered(d, patch_size_x, patch_size_y, patch_size_z):
    """Tile with half-patch stride, starting a quarter patch outside the
    volume; out-of-bounds regions are zero padded via extractPatchOut."""
    patch_nb = int(2 * (d.shape[0] / patch_size_x) * 2 * (d.shape[1] / patch_size_y) *
                   2 * (d.shape[2] / patch_size_z))
    data = np.zeros((patch_nb, patch_size_x, patch_size_y, patch_size_z), dtype='float16')
    i = 0
    psx = int(patch_size_x / 2)
    psy = int(patch_size_y / 2)
    psz = int(patch_size_z / 2)
    for x in range(-int(patch_size_x / 4), d.shape[0] - 3 * int(patch_size_x / 4) + 1, psx):
        for y in range(-int(patch_size_y / 4), d.shape[1] - 3 * int(patch_size_y / 4) + 1, psy):
            for z in range(-int(patch_size_z / 4), d.shape[2] - 3 * int(patch_size_z / 4) + 1, psz):
                data[i] = extractPatchOut(d, patch_size_x, patch_size_y, patch_size_z, x, y, z)
                i += 1
    return data


# ----- Patch Extraction Generators -----
def generatorRandomPatchs(features, labels, batch_size, patch_size_x, patch_size_y, patch_size_z):
    """Endlessly yield batches of random, co-located feature/label patches."""
    batch_features = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z,
                               features.shape[4]), dtype='float16')
    batch_labels = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z,
                             labels.shape[4]), dtype='float16')
    while True:
        for i in range(batch_size):
            vol = randint(0, features.shape[0] - 1)
            x = randint(0, features.shape[1] - patch_size_x)
            y = randint(0, features.shape[2] - patch_size_y)
            z = randint(0, features.shape[3] - patch_size_z)
            batch_features[i] = extractPatch(features[vol], patch_size_x, patch_size_y,
                                             patch_size_z, x, y, z)
            batch_labels[i] = extractPatch(labels[vol], patch_size_x, patch_size_y,
                                           patch_size_z, x, y, z)
        yield batch_features, batch_labels


def generatorRandomPatchs3216(features, labels, batch_size):
    """Yield 32^3 feature patches with 16^3 label patches.

    NOTE(review): the label offset is x+16 (patch start at the feature
    patch centre), unlike generatorRandomPatchsLabelCentered which uses a
    quarter-patch offset to centre the label -- confirm this is intended.
    """
    batch_features = np.zeros((batch_size, 32, 32, 32, features.shape[4]), dtype='float16')
    batch_labels = np.zeros((batch_size, 16, 16, 16, labels.shape[4]), dtype='float16')
    while True:
        for i in range(batch_size):
            vol = randint(0, features.shape[0] - 1)
            x = randint(0, features.shape[1] - 32)
            y = randint(0, features.shape[2] - 32)
            z = randint(0, features.shape[3] - 32)
            batch_features[i] = extractPatch(features[vol], 32, 32, 32, x, y, z)
            batch_labels[i] = extractPatch(labels[vol], 16, 16, 16, x + 16, y + 16, z + 16)
        yield batch_features, batch_labels


def generatorRandomPatchsLabelCentered(features, labels, batch_size,
                                       patch_size_x, patch_size_y, patch_size_z):
    """Yield random feature patches with half-size labels centred inside them."""
    pcx = int(patch_size_x / 2)
    pcy = int(patch_size_y / 2)
    pcz = int(patch_size_z / 2)
    batch_features = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z,
                               features.shape[4]), dtype=features.dtype)
    batch_labels = np.zeros((batch_size, pcx, pcy, pcz, labels.shape[4]), dtype=labels.dtype)
    while True:
        for i in range(batch_size):
            vol = randint(0, features.shape[0] - 1)
            x = randint(0, features.shape[1] - patch_size_x)
            y = randint(0, features.shape[2] - patch_size_y)
            z = randint(0, features.shape[3] - patch_size_z)
            batch_features[i] = extractPatch(features[vol], patch_size_x, patch_size_y,
                                             patch_size_z, x, y, z)
            batch_labels[i] = extractPatch(labels[vol], pcx, pcy, pcz,
                                           int(x + patch_size_x / 4),
                                           int(y + patch_size_y / 4),
                                           int(z + patch_size_z / 4))
        yield batch_features, batch_labels


def generatorRandomPatchsDolz(features, labels, batch_size,
                              patch_size_x, patch_size_y, patch_size_z):
    """Yield feature patches with flattened one-hot (2 class) centred labels,
    as expected by the Dolz-style network head."""
    batch_features = np.zeros((batch_size, patch_size_x, patch_size_y, patch_size_z,
                               features.shape[4]), dtype=features.dtype)
    batch_labels = np.zeros((batch_size,
                             int(patch_size_x / 2) * int(patch_size_y / 2) * int(patch_size_z / 2),
                             2), dtype=labels.dtype)
    while True:
        for i in range(batch_size):
            vol = randint(0, features.shape[0] - 1)
            x = randint(0, features.shape[1] - patch_size_x)
            y = randint(0, features.shape[2] - patch_size_y)
            z = randint(0, features.shape[3] - patch_size_z)
            batch_features[i] = extractPatch(features[vol], patch_size_x, patch_size_y,
                                             patch_size_z, x, y, z)
            tmpPatch = extractPatch(labels[vol],
                                    int(patch_size_x / 2), int(patch_size_y / 2),
                                    int(patch_size_z / 2),
                                    int(x + patch_size_x / 4), int(y + patch_size_y / 4),
                                    int(z + patch_size_z / 4))
            batch_labels[i] = to_categorical(tmpPatch.flatten(), 2)
        yield batch_features, batch_labels


def generatorRandomPatchsLinear(features, labels, patch_features_x, patch_features_y,
                                patch_features_z, patch_labels_x, patch_labels_y,
                                patch_labels_z):
    """Yield single feature/label patch pairs after a random 3-D rotation.

    Preconditions: feature patch sizes are multiples of (and not smaller
    than) label patch sizes; exits with the original magic codes otherwise.
    NOTE(review): only volume index 0 is rotated/used, and scaling (zoom)
    is still a TODO -- kept as in the original.
    """
    patch_features = np.zeros((1, patch_features_x, patch_features_y, patch_features_z,
                               features.shape[4]), dtype=features.dtype)
    patch_labels = np.zeros((1, patch_labels_x, patch_labels_y, patch_labels_z,
                             labels.shape[4]), dtype=labels.dtype)

    if (patch_features_x % patch_labels_x != 0 or patch_features_y % patch_labels_y != 0
            or patch_features_z % patch_labels_z != 0):
        sys.exit(0x00F0)
    if (patch_features_x < patch_labels_x or patch_features_y < patch_labels_y
            or patch_features_z < patch_labels_z):
        sys.exit(0x00F1)

    # middle of the feature patch / half label patch size
    mx, my, mz = int(patch_features_x / 2), int(patch_features_y / 2), int(patch_features_z / 2)
    sx, sy, sz = int(patch_labels_x / 2), int(patch_labels_y / 2), int(patch_labels_z / 2)

    while True:
        vol = randint(0, features.shape[0] - 1)
        x = randint(0, features.shape[1] - patch_features_x)
        y = randint(0, features.shape[2] - patch_features_y)
        z = randint(0, features.shape[3] - patch_features_z)

        # random rotation in [-180, 180] around each axis (original TODO:
        # check time consumption of rotating the complete image)
        r0 = randint(0, 360) - 180
        r1 = randint(0, 360) - 180
        r2 = randint(0, 360) - 180
        rot_features = rotate(input=features[0], angle=r0, axes=(0, 1), reshape=False)
        rot_features = rotate(input=rot_features, angle=r1, axes=(1, 2), reshape=False)
        rot_features = rotate(input=rot_features, angle=r2, axes=(2, 0), reshape=False)
        rot_labels = rotate(input=labels[0], angle=r0, axes=(0, 1), reshape=False)
        rot_labels = rotate(input=rot_labels, angle=r1, axes=(1, 2), reshape=False)
        rot_labels = rotate(input=rot_labels, angle=r2, axes=(2, 0), reshape=False)

        patch_features[0] = extractPatch(rot_features, patch_features_x, patch_features_y,
                                         patch_features_z, x, y, z)
        patch_labels[0] = extractPatch(rot_labels, patch_labels_x, patch_labels_y,
                                       patch_labels_z, x + mx - sx, y + my - sy, z + mz - sz)
        yield patch_features, patch_labels


def randomPatchsAugmented(in_dataset, gd_dataset, patch_number, patch_in_size, patch_gd_size):
    """Sample ``patch_number`` co-located input/ground-truth patches and apply
    the same random 90-degree rotations to both.

    Returns the two stacks with a trailing channel axis added.
    """
    patchs_in = np.zeros((patch_number, patch_in_size[0], patch_in_size[1], patch_in_size[2]),
                         dtype=in_dataset.dtype)
    patchs_gd = np.zeros((patch_number, patch_gd_size[0], patch_gd_size[1], patch_gd_size[2]),
                         dtype=gd_dataset.dtype)

    if (patch_in_size[0] % patch_gd_size[0] != 0 or patch_in_size[1] % patch_gd_size[1] != 0
            or patch_in_size[2] % patch_gd_size[2] != 0):
        sys.exit("ERROR : randomPatchsAugmented patchs size error 1")
    if (patch_in_size[0] < patch_gd_size[0] or patch_in_size[1] < patch_gd_size[1]
            or patch_in_size[2] < patch_gd_size[2]):
        sys.exit("ERROR : randomPatchsAugmented patchs size error 2")

    # middle of the input patch / half ground-truth patch size
    mx, my, mz = int(patch_in_size[0] / 2), int(patch_in_size[1] / 2), int(patch_in_size[2] / 2)
    sx, sy, sz = int(patch_gd_size[0] / 2), int(patch_gd_size[1] / 2), int(patch_gd_size[2] / 2)

    for count in range(patch_number):
        vol = randint(0, in_dataset.shape[0] - 1)
        x = randint(0, in_dataset.shape[1] - patch_in_size[0])
        y = randint(0, in_dataset.shape[2] - patch_in_size[1])
        z = randint(0, in_dataset.shape[3] - patch_in_size[2])

        r0 = randint(0, 3)
        r1 = randint(0, 3)
        r2 = randint(0, 3)

        patchs_in[count] = extractPatch(in_dataset[vol], patch_in_size[0], patch_in_size[1],
                                        patch_in_size[2], x, y, z)
        patchs_gd[count] = extractPatch(gd_dataset[vol], patch_gd_size[0], patch_gd_size[1],
                                        patch_gd_size[2], x + mx - sx, y + my - sy, z + mz - sz)

        patchs_in[count] = np.rot90(patchs_in[count], r0, (0, 1))
        patchs_in[count] = np.rot90(patchs_in[count], r1, (1, 2))
        patchs_in[count] = np.rot90(patchs_in[count], r2, (2, 0))
        patchs_gd[count] = np.rot90(patchs_gd[count], r0, (0, 1))
        patchs_gd[count] = np.rot90(patchs_gd[count], r1, (1, 2))
        patchs_gd[count] = np.rot90(patchs_gd[count], r2, (2, 0))

    return (patchs_in.reshape(patchs_in.shape[0], patchs_in.shape[1], patchs_in.shape[2],
                              patchs_in.shape[3], 1),
            patchs_gd.reshape(patchs_gd.shape[0], patchs_gd.shape[1], patchs_gd.shape[2],
                              patchs_gd.shape[3], 1))


def generatorRandomPatchsAugmented(in_dataset, gd_dataset, patch_number,
                                   patch_in_size, patch_gd_size):
    """Endless generator wrapper around randomPatchsAugmented."""
    while True:
        yield randomPatchsAugmented(in_dataset, gd_dataset, patch_number,
                                    patch_in_size, patch_gd_size)
# ------------------------------------------------------------ #
#
# file : utils/learning/patch/reconstruction.py
# author : CM
# Rebuild full volumes from predicted patches
#
# ------------------------------------------------------------ #

import numpy as np


# ----- Image Reconstruction -----
def fullPatchsToImage(image, patchs):
    """Tile non-overlapping patches back into ``image`` (in place).

    ``patchs`` has shape (n, px, py, pz, 1); patches are consumed in
    x-major, then y, then z order.  Returns the filled image.
    """
    px, py, pz = patchs.shape[1], patchs.shape[2], patchs.shape[3]
    idx = 0
    for x0 in range(0, image.shape[0], px):
        for y0 in range(0, image.shape[1], py):
            for z0 in range(0, image.shape[2], pz):
                image[x0:x0 + px, y0:y0 + py, z0:z0 + pz] = patchs[idx, :, :, :, 0]
                idx += 1
    return image


def fullPatchsPlusToImage(image, patchs, dx, dy, dz):
    """Blend overlapping patches (stride dx, dy, dz) by averaging.

    A coverage counter records how many patches touched each voxel and the
    written image is divided by it; voxels never covered divide by zero,
    exactly as in the original.
    """
    coverage = np.zeros(image.shape)
    block = np.ones((patchs.shape[1], patchs.shape[2], patchs.shape[3]))
    idx = 0
    for x0 in range(0, image.shape[0] - dx, dx):
        for y0 in range(0, image.shape[1] - dy, dy):
            for z0 in range(0, image.shape[2] - dz, dz):
                sl = (slice(x0, x0 + patchs.shape[1]),
                      slice(y0, y0 + patchs.shape[2]),
                      slice(z0, z0 + patchs.shape[3]))
                coverage[sl] += block
                image[sl] = patchs[idx, :, :, :, 0]
                idx += 1
    return image / coverage


def dolzReconstruction(image, patchs):
    """Rebuild a volume from Dolz-style per-class softmax patches.

    Each patch row is argmax-ed over its 2 classes and reshaped to
    16x16x16 (patch size still hard-coded, original TODO).
    """
    output = np.copy(image)
    print("image shape", image.shape)
    print("patchs shape", patchs.shape)
    n = 0
    for x0 in range(0, image.shape[0], 16):
        for y0 in range(0, image.shape[1], 16):
            for z0 in range(0, image.shape[2], 16):
                block = np.argmax(patchs[n], axis=1).reshape(16, 16, 16)
                output[x0:x0 + 16, y0:y0 + 16, z0:z0 + 16] = block
                n += 1
    return output
import os

import numpy as np
from scipy.ndimage import label  # scipy.ndimage.measurements was removed in modern SciPy


# ----- Dataset padding (utils/padding.py) ----- #

def _pad_images(path, extension, mode, size=(560, 560)):
    """Pad every `extension` image directly under `path` to `size`.

    Images are pasted at the top-left of a black canvas and written to
    `path`/padded/ (directory is now created if missing).  The original
    length-then-sort collection was equivalent to a plain alphabetical
    sort of top-level files with names shorter than 100 chars.
    """
    from PIL import Image  # local import keeps the module importable without Pillow

    filenames = next(os.walk(path))[2]  # top-level directory only, as before
    file_list = sorted(f for f in filenames
                       if extension in f.lower() and len(f) < 100)
    print(file_list)

    os.makedirs(os.path.join(path, 'padded'), exist_ok=True)
    for filename in file_list:
        image = Image.open(path + filename)
        canvas = Image.new(mode, size)
        canvas.paste(image, (0, 0))
        canvas.save(path + 'padded/' + filename)


def paddingjpg(path):
    """Pad the .jpg images under `path` to 560x560 RGB."""
    _pad_images(path, ".jpg", "RGB")


def paddingpng(path):
    """Pad the .png label images under `path` to 560x560 grayscale."""
    _pad_images(path, ".png", "L")


# ----- Post-processing: hole filling / small-noise removal ----- #

def _zero_small_components(labeled, ncomponents, rate):
    """Zero every component whose channel-0 pixel count is below `rate`
    of the total.

    The original code counted pixels with a per-component Python double
    loop over labeled[x][y][0] (O(pixels * components)); np.bincount does
    the same counting in one vectorised pass.
    """
    counts = np.bincount(labeled[:, :, 0].ravel(), minlength=ncomponents + 1)[1:]
    total = counts.sum()
    if total:  # counts[i]/total < rate  <=>  counts[i] < rate*total
        small_ids = np.flatnonzero(counts < rate * total) + 1
        labeled[np.isin(labeled[:, :, 0], small_ids)] = 0
    return labeled


def fill_holes(img, threshold, rate):
    """Fill small background holes of an (H, W, 3) mask; returns a 0/255 array.

    Pixels <= `threshold` form the (reversed) background; background
    components covering less than `rate` of the background are merged
    into the foreground.  Fixes the removed `np.int` alias (crash on
    NumPy >= 1.24).
    """
    binary_img = np.where(img > threshold, 0, 1)  # reversed image
    structure = np.ones((3, 3, 3), dtype=int)     # full 3D connectivity
    labeled, ncomponents = label(binary_img, structure)
    labeled = _zero_small_components(labeled, ncomponents, rate)
    return np.where(labeled < 1, 1, 0) * 255


def remove_small_areas(img, threshold, rate):
    """Drop foreground components smaller than `rate` of the foreground.

    Returns a 0/255 array.  NOTE(review): `threshold` is kept for
    signature compatibility but was never used by the original
    implementation either.
    """
    structure = np.ones((3, 3, 3), dtype=int)  # np.int removed in NumPy >= 1.24
    labeled, ncomponents = label(img, structure)
    labeled = _zero_small_components(labeled, ncomponents, rate)
    return np.where(labeled < 1, 0, 1) * 255


if __name__ == "__main__":
    # Previously these ran unconditionally at import time.
    paddingjpg('../data/train/images/')
    paddingpng('../data/train/labels/')
    paddingjpg('../data/test/images/')
    paddingpng('../data/test/labels/')
# ------------------------------------------------------------ #
#
# file : preprocessing/normalisation.py
# author : CM
# Intensity normalisation helpers for image datasets.
#
# ------------------------------------------------------------ #
import numpy as np


# Rescaling (min-max normalization with an implicit minimum of 0)
def linear_intensity_normalization(loaded_dataset):
    """Divide intensities by the dataset maximum.

    Returns the input unchanged when the maximum is 0, instead of the
    previous 0/0 division producing NaNs.
    """
    peak = loaded_dataset.max()
    if peak == 0:
        return loaded_dataset
    return loaded_dataset / peak


# Preprocess dataset with intensity normalisation
# (zero mean and unit variance)
def standardization_intensity_normalization(dataset, dtype):
    """Zero-mean / unit-variance standardisation, cast to `dtype`.

    Constant input (std == 0) now yields zeros instead of dividing by zero.
    """
    mean = dataset.mean()
    std = dataset.std()
    if std == 0:
        return np.zeros_like(dataset, dtype=dtype)
    return ((dataset - mean) / std).astype(dtype)


# Intensities normalized to the range [0, 1]
def intensityNormalisationFeatureScaling(dataset, dtype):
    """Min-max feature scaling to [0, 1], cast to `dtype`.

    Locals renamed so the `min`/`max` builtins are no longer shadowed;
    a zero-range (constant) input now yields zeros instead of NaNs.
    """
    lo = dataset.min()
    hi = dataset.max()
    if hi == lo:
        return np.zeros_like(dataset, dtype=dtype)
    return ((dataset - lo) / (hi - lo)).astype(dtype)


# Intensity max clipping with c "max value"
def intensityMaxClipping(dataset, c, dtype):
    """Clip intensities to the range [0, c], cast to `dtype`."""
    return np.clip(a=dataset, a_min=0, a_max=c).astype(dtype)


# Intensity projection
def intensityProjection(dataset, p, dtype):
    """Raise intensities to the power `p` (gamma-style projection), cast to `dtype`."""
    return (dataset ** p).astype(dtype)
# ------------------------------------------------------------ #
#
# file : preprocessing/threshold.py
# author : CM
# Preprocess function for Bullitt dataset
#
# ------------------------------------------------------------ #
import os
import sys

import numpy as np

# Number of training volumes used to estimate the threshold
# (was a magic 30 repeated through the script).
TRAIN_IMAGE_COUNT = 30


def getThreshold(dataset_mra, dataset_gd):
    """Return the smallest MRA intensity at any ground-truth voxel (gd == 1).

    Vectorised replacement for the original O(images * x * y * z) Python
    loops: a single masked minimum.  When no voxel is labelled 1, falls
    back to the dataset maximum -- the original loop's starting value.
    """
    vessel_mask = dataset_gd == 1
    if not vessel_mask.any():
        return dataset_mra.max()
    return dataset_mra[vessel_mask].min()


def thresholding(data, threshold):
    """Keep intensities strictly above `threshold`, zero the rest.

    Vectorised equivalent of the original voxel-by-voxel loop; returns a
    new array and leaves `data` untouched.
    """
    return np.where(data > threshold, data, 0)


def _load_volumes(folder, shape):
    """Load the first shape[0] NIfTI volumes (alphabetical) from `folder`."""
    import nibabel as nib  # deferred: only needed when run as a script

    volumes = np.empty(shape)
    for i, filename in enumerate(sorted(os.listdir(folder))[:shape[0]]):
        print(filename)
        # get_fdata() replaces the deprecated get_data()
        volumes[i, :, :, :] = nib.load(os.path.join(folder, filename)).get_fdata()
    return volumes


def _preprocess_folder(src_folder, dst_folder, threshold):
    """Threshold every volume in `src_folder` and write it as pre_<name>."""
    import nibabel as nib  # deferred: only needed when run as a script
    from utils.io.write import npToNii

    for filename in sorted(os.listdir(src_folder)):
        print(filename)
        data = nib.load(os.path.join(src_folder, filename)).get_fdata()
        print(np.average(data))
        preprocessed = thresholding(data, threshold)
        print(np.average(preprocessed))
        npToNii(preprocessed, os.path.join(dst_folder, 'pre_' + filename))


def main():
    """Script entry point: estimate the threshold on the training set and
    apply it to both the test and train MRA images."""
    from utils.config.read import readConfig

    config_filename = sys.argv[1]
    if not os.path.isfile(config_filename):
        sys.exit(1)
    config = readConfig(config_filename)

    output_folder = sys.argv[2]
    if not os.path.isdir(output_folder):
        sys.exit(1)

    shape = (TRAIN_IMAGE_COUNT,
             config["image_size_x"], config["image_size_y"], config["image_size_z"])

    print("Loading training dataset")
    train_mra_dataset = _load_volumes(config["dataset_train_mra_path"], shape)
    train_gd_dataset = _load_volumes(config["dataset_train_gd_path"], shape)

    print("Compute threshold")
    threshold = getThreshold(train_mra_dataset, train_gd_dataset)

    # release the training volumes before streaming the full datasets
    train_mra_dataset = None
    train_gd_dataset = None

    print("Apply preprocessing to test image")
    _preprocess_folder(config["dataset_test_mra_path"],
                       output_folder + '/test_Images', threshold)

    print("Apply threshold to train image : ", threshold)
    _preprocess_folder(config["dataset_train_mra_path"],
                       output_folder + '/train_Images', threshold)


if __name__ == "__main__":
    # Previously the whole pipeline ran unconditionally at import time.
    main()