# NOTE: "Spaces: / Runtime error" lines here were extraction artifacts from the
# hosting page, not part of the source; kept as a comment so the file parses.
| import argparse | |
| import csv | |
| import io | |
| import os | |
| import zipfile | |
| from pathlib import Path | |
| from typing import Tuple, Dict | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| import timm | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Polygon | |
| try: | |
| from src.data_collection import DataCollectionManager, classify_from_percentages_simple | |
| except ImportError: | |
| import sys | |
| sys.path.insert(0, str(Path(__file__).resolve().parent / "src")) | |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) | |
| from data_collection import DataCollectionManager, classify_from_percentages_simple | |
| # ============================================================================ | |
| # MODEL ARCHITECTURE (Embedded) | |
| # ============================================================================ | |
class IdentityAttention(nn.Module):
    """Pass-through stub used when attention is disabled.

    Keeps attention call sites uniform: callers always invoke an attention
    module without branching on configuration.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input tensor unchanged."""
        return x
class SEFeatureAttention(nn.Module):
    """Squeeze-and-Excitation style gating over flat feature vectors.

    A small bottleneck MLP produces per-feature weights in (0, 1) which
    rescale the input multiplicatively.
    """

    def __init__(self, feature_dim: int, reduction: int = 16):
        super().__init__()
        # Bottleneck width, floored at 8 so tiny feature dims still work.
        squeezed = max(8, feature_dim // reduction)
        layers = [
            nn.Linear(feature_dim, squeezed),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, feature_dim),
            nn.Sigmoid(),
        ]
        # Name kept as `fc` (Sequential) so checkpoint state_dict keys match.
        self.fc = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale each feature by its learned sigmoid gate."""
        gates = self.fc(x)
        return x * gates
class CBAMFeatureAttention(nn.Module):
    """CBAM-inspired channel gating adapted to flat feature vectors.

    Two descriptors share one bottleneck MLP: the raw features and a
    broadcast per-sample maximum. Their sum is squashed to (0, 1) and used
    as a multiplicative gate on the input.
    """

    def __init__(self, feature_dim: int, reduction: int = 16):
        super().__init__()
        bottleneck = max(8, feature_dim // reduction)
        # Names kept as `mlp`/`gate` so checkpoint state_dict keys match.
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, feature_dim),
        )
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate *x* by a sigmoid of the combined raw/max descriptors."""
        # Per-sample max, broadcast back to the full feature width.
        per_sample_max = x.max(dim=1, keepdim=True).values.expand_as(x)
        combined = self.mlp(x) + self.mlp(per_sample_max)
        return x * self.gate(combined)
def build_attention_block(attention_type: str, feature_dim: int, reduction: int = 16) -> nn.Module:
    """Instantiate the attention module named by *attention_type*.

    Args:
        attention_type: "none", "se", or "cbam" (case-insensitive;
            None/empty is treated as "none").
        feature_dim: Width of the feature vectors the block will see.
        reduction: Bottleneck reduction ratio for the SE/CBAM variants.

    Returns:
        The configured ``nn.Module``.

    Raises:
        ValueError: If *attention_type* is not a recognised name.
    """
    normalized = (attention_type or "none").lower()
    factories = {
        "none": lambda: IdentityAttention(),
        "se": lambda: SEFeatureAttention(feature_dim=feature_dim, reduction=reduction),
        "cbam": lambda: CBAMFeatureAttention(feature_dim=feature_dim, reduction=reduction),
    }
    if normalized not in factories:
        raise ValueError(f"Unknown attention type: {attention_type}")
    return factories[normalized]()
class SoilTextureModel(nn.Module):
    """
    Multi-task model for soil texture analysis.
    Architecture:
        Image -> Backbone -> Shared Features -> Classification Head -> Texture Class
                                             -> Regression Head -> [Sand%, Silt%, Clay%]
    """
    # Friendly backbone key -> timm model id + width of the pooled feature vector.
    BACKBONE_CONFIGS = {
        'efficientnet_v2_s': {'feature_dim': 1280, 'pretrained': 'tf_efficientnetv2_s'},
        'convnext_tiny': {'feature_dim': 768, 'pretrained': 'convnext_tiny'},
        'mobilevit_s': {'feature_dim': 640, 'pretrained': 'mobilevit_s'},
        'swin_tiny': {'feature_dim': 768, 'pretrained': 'swin_tiny_patch4_window7_224'},
        'resnet50': {'feature_dim': 2048, 'pretrained': 'resnet50'},
    }

    def __init__(
        self,
        backbone_name: str = 'efficientnet_v2_s',
        num_classes: int = 12,
        dropout: float = 0.3,
        pretrained: bool = True,
        freeze_backbone: bool = False,
        attention_type: str = "none",
        attention_reduction: int = 16,
        task_attention: bool = False,
    ):
        """Build the backbone, attention blocks, and both task heads.

        Args:
            backbone_name: Key into BACKBONE_CONFIGS; unknown names silently
                fall back to 'efficientnet_v2_s'.
            num_classes: Output size of the classification head.
            dropout: Dropout rate for the first layer of each head; deeper
                head layers use half this rate.
            pretrained: Load pretrained weights for the backbone via timm.
            freeze_backbone: If True, set requires_grad=False on all
                backbone parameters (heads still train).
            attention_type: "none", "se", or "cbam".
            attention_reduction: Bottleneck reduction for SE/CBAM blocks.
            task_attention: If True, each head gets its own attention block
                after the shared one; otherwise the per-task blocks are
                pass-through identities.
        """
        super().__init__()
        self.backbone_name = backbone_name
        self.num_classes = num_classes
        # Get backbone configuration
        config = self.BACKBONE_CONFIGS.get(backbone_name, self.BACKBONE_CONFIGS['efficientnet_v2_s'])
        feature_dim = config['feature_dim']
        # Load pretrained backbone
        self.backbone = timm.create_model(
            config['pretrained'],
            pretrained=pretrained,
            num_classes=0,  # Remove classifier head
            global_pool='avg'
        )
        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
        # Attention over the pooled features, shared by both heads.
        self.shared_attention = build_attention_block(
            attention_type=attention_type,
            feature_dim=feature_dim,
            reduction=attention_reduction,
        )
        if task_attention:
            self.class_attention = build_attention_block(
                attention_type=attention_type,
                feature_dim=feature_dim,
                reduction=attention_reduction,
            )
            self.reg_attention = build_attention_block(
                attention_type=attention_type,
                feature_dim=feature_dim,
                reduction=attention_reduction,
            )
        else:
            self.class_attention = IdentityAttention()
            self.reg_attention = IdentityAttention()
        # Classification head (texture type)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes)
        )
        # Regression head (Sand, Silt, Clay percentages)
        self.regressor = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 3)  # Sand, Silt, Clay
        )
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init Linear layers and reset BatchNorm in heads/attention.

        The backbone keeps its own (pretrained or timm-default) init.
        """
        for m in [
            self.shared_attention,
            self.class_attention,
            self.reg_attention,
            self.classifier,
            self.regressor,
        ]:
            for layer in m.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)
                elif isinstance(layer, nn.BatchNorm1d):
                    nn.init.constant_(layer.weight, 1)
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
        """Forward pass.

        Args:
            x: Batch of input images.
            return_features: Also return the shared post-attention feature
                vectors under key 'features'.

        Returns:
            Dict with 'class_logits' (batch, num_classes) and
            'concentrations' (batch, 3): softmax-normalized Sand/Silt/Clay
            values scaled so each row sums to 100.
        """
        # Extract features
        features = self.backbone(x)
        features = self.shared_attention(features)
        cls_features = self.class_attention(features)
        reg_features = self.reg_attention(features)
        # Classification
        class_logits = self.classifier(cls_features)
        # Regression (with softmax to ensure sum = 100)
        reg_output = self.regressor(reg_features)
        concentrations = F.softmax(reg_output, dim=1) * 100  # Scale to percentages
        result = {
            'class_logits': class_logits,
            'concentrations': concentrations
        }
        if return_features:
            result['features'] = features
        return result
def create_model(
    model_type: str = 'full',
    backbone: str = 'efficientnet_v2_s',
    num_classes: int = 12,
    pretrained: bool = True,
    attention_type: str = "none",
    attention_reduction: int = 16,
    task_attention: bool = False,
) -> nn.Module:
    """Factory: build a :class:`SoilTextureModel` from keyword configuration.

    Note:
        ``model_type`` is accepted for API compatibility but currently
        unused — the multi-task ("full") architecture is always built.
    """
    return SoilTextureModel(
        backbone_name=backbone,
        num_classes=num_classes,
        pretrained=pretrained,
        attention_type=attention_type,
        attention_reduction=attention_reduction,
        task_attention=task_attention,
    )
def format_prediction_markdown(result: Dict) -> str:
    """Render an inference result dict as a short markdown report.

    Expects keys: 'class', 'confidence', 'sand', 'silt', 'clay', and
    'class_probabilities' (name -> probability). Lists the top 5 classes
    by probability.
    """
    header = [
        "### Prediction Result",
        f"- **Texture Class:** `{result['class']}`",
        f"- **Confidence:** `{result['confidence'] * 100:.2f}%`",
        f"- **Sand / Silt / Clay:** `{result['sand']:.2f}% / {result['silt']:.2f}% / {result['clay']:.2f}%`",
        "",
        "**Top Probabilities**",
    ]
    ranked = sorted(result["class_probabilities"].items(), key=lambda item: item[1], reverse=True)
    body = [f"- {name}: {prob * 100:.2f}%" for name, prob in ranked[:5]]
    return "\n".join(header + body)
| # ============================================================================ | |
| # SOIL TEXTURE TRIANGLE VISUALIZATION | |
| # ============================================================================ | |
def create_texture_triangle(sand: float, silt: float, clay: float, predicted_class: str,
                            confidence: float = None, top_probs: list = None) -> np.ndarray:
    """
    Create USDA Soil Texture Triangle visualization with correct boundaries.

    Args:
        sand: Sand percentage of the sample.
        silt: Silt percentage of the sample.
        clay: Clay percentage of the sample.
        predicted_class: Texture class name annotated at the sample point.
        confidence: Optional classifier confidence; rendered in an info box
            only when both confidence and top_probs are provided.
        top_probs: Optional list of (class_name, probability) pairs; the
            first five are listed in a second info box.

    Returns:
        The rendered figure as an RGB uint8 array of shape (H, W, 3).
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 12), facecolor='white', dpi=150)
    # Helper function to convert soil percentages to triangle coordinates
    # (barycentric -> cartesian: sand corner at origin, silt at (1, 0),
    # clay at the apex (0.5, sqrt(3)/2)).
    def soil_to_coords(sand_pct, silt_pct, clay_pct):
        x = silt_pct/100 + clay_pct/200
        y = clay_pct/100 * np.sqrt(3)/2
        return x, y
    # USDA Soil Texture Triangle regions with correct boundaries.
    # Each entry: (name, polygon vertices as (sand, silt, clay) %, fill color).
    regions = [
        ('Sand', [(100, 0, 0), (85, 15, 0), (90, 0, 10)], '#FFE4B5'),
        ('Loamy Sand', [(85, 15, 0), (70, 30, 0), (85, 0, 15), (90, 0, 10)], '#FFDAB9'),
        ('Sandy Loam', [(70, 30, 0), (50, 50, 0), (42.5, 50, 7.5), (52.5, 40, 7.5), (52.5, 27.5, 20), (80, 0, 20), (85, 0, 15)], '#F4A460'),
        ('Loam', [(42.5, 50, 7.5), (22.5, 50, 27.5), (45, 27.5, 27.5), (52.5, 27.5, 20), (52.5, 40, 7.5)], '#DEB887'),
        ('Silt Loam', [(50, 50, 0), (20, 80, 0), (7.5, 80, 12.5), (0, 87.5, 12.5), (0, 72.5, 27.5), (22.5, 50, 27.5)], '#D2B48C'),
        ('Silt', [(20, 80, 0), (0, 100, 0), (0, 87.5, 12.5), (7.5, 80, 12.5)], '#C0C0C0'),
        ('Sandy Clay Loam', [(80, 0, 20), (52.5, 27.5, 20), (45, 27.5, 27.5), (45, 20, 35), (65, 0, 35)], '#CD853F'),
        ('Clay Loam', [(45, 27.5, 27.5), (20, 52.5, 27.5), (20, 40, 40), (45, 15, 40)], '#D2691E'),
        ('Silty Clay Loam', [(0, 72.5, 27.5), (0, 60, 40), (20, 40, 40), (20, 52.5, 27.5)], '#B8860B'),
        ('Sandy Clay', [(65, 0, 35), (45, 20, 35), (45, 0, 55)], '#A0522D'),
        ('Silty Clay', [(20, 40, 40), (0, 60, 40), (0, 40, 60)], '#8B4513'),
        ('Clay', [(45, 15, 40), (20, 40, 40), (0, 40, 60), (0, 0, 100), (45, 0, 55)], '#654321'),
    ]
    # Draw colored regions with border lines
    for name, vertices_pct, color in regions:
        vertices_xy = [soil_to_coords(s, si, c) for s, si, c in vertices_pct]
        region_patch = Polygon(vertices_xy, facecolor=color, edgecolor='#333',
                               linewidth=1.2, alpha=0.8, zorder=1)
        ax.add_patch(region_patch)
        # Add label at the polygon's vertex centroid.
        center_x = np.mean([v[0] for v in vertices_xy])
        center_y = np.mean([v[1] for v in vertices_xy])
        ax.text(center_x, center_y, name, fontsize=12, ha='center',
                va='center', weight='bold', zorder=2)
    # Draw triangle outline
    triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]])
    tri_patch = Polygon(triangle, fill=False, edgecolor='black', linewidth=4, zorder=3)
    ax.add_patch(tri_patch)
    # Add corner labels
    ax.text(0, -0.05, '100% Sand', fontsize=16, ha='center', weight='bold')
    ax.text(1, -0.05, '100% Silt', fontsize=16, ha='center', weight='bold')
    ax.text(0.5, np.sqrt(3)/2 + 0.03, '100% Clay', fontsize=16, ha='center', weight='bold')
    # Add grid lines: one family per axis (clay horizontals, sand and silt
    # diagonals); every 10% is darker and labeled, every 5% is faint.
    for pct in range(5, 100, 5):
        y = pct/100 * np.sqrt(3)/2
        x_left = pct/200
        x_right = 1 - pct/200
        sand_pct = pct
        p1 = soil_to_coords(sand_pct, 0, 100-sand_pct)
        p2 = soil_to_coords(sand_pct, 100-sand_pct, 0)
        silt_pct = pct
        p3 = soil_to_coords(0, silt_pct, 100-silt_pct)
        p4 = soil_to_coords(100-silt_pct, silt_pct, 0)
        if pct % 10 == 0:
            ax.plot([x_left, x_right], [y, y], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
            ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
            ax.text(x_left - 0.03, y, f'{pct}', fontsize=11, alpha=0.7, weight='bold')
        else:
            ax.plot([x_left, x_right], [y, y], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
            ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
    # Plot prediction point
    pred_x, pred_y = soil_to_coords(sand, silt, clay)
    ax.plot(pred_x, pred_y, 'o', markersize=22, markerfacecolor='red',
            markeredgecolor='darkred', markeredgewidth=3.5, zorder=5)
    # Add annotation (flip the label to the left near the right edge).
    offset_x = 0.15 if pred_x < 0.7 else -0.15
    offset_y = 0.08
    ax.annotate(f'{predicted_class}\n({sand:.0f}%, {silt:.0f}%, {clay:.0f}%)',
                xy=(pred_x, pred_y), xytext=(pred_x + offset_x, pred_y + offset_y),
                fontsize=14, fontweight='bold',
                arrowprops=dict(arrowstyle='->', lw=2.5, color='darkred'),
                bbox=dict(boxstyle='round,pad=0.6', facecolor='white', edgecolor='darkred', lw=2.5),
                ha='center', zorder=6)
    # Add prediction information boxes
    if confidence is not None and top_probs is not None:
        # Left box - Prediction and Composition
        left_text = f"Predicted Class:\n{predicted_class}\n\n"
        left_text += f"Confidence: {confidence*100:.1f}%\n\n"
        left_text += f"Composition:\n"
        left_text += f"Sand: {sand:.1f}%\n"
        left_text += f"Silt: {silt:.1f}%\n"
        left_text += f"Clay: {clay:.1f}%"
        ax.text(0.05, 0.82, left_text,
                fontsize=16, verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
                          edgecolor='black', linewidth=2.5, alpha=0.95),
                zorder=7, family='monospace', weight='bold')
        # Right box - Top 5 Probabilities
        right_text = "Top 5 Probabilities:\n\n"
        for i, (cls, prob) in enumerate(top_probs[:5], 1):
            right_text += f"{i}. {cls}: {prob*100:.1f}%\n"
        ax.text(0.75, 0.82, right_text,
                fontsize=16, verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
                          edgecolor='black', linewidth=2.5, alpha=0.95),
                zorder=7, family='monospace', weight='bold')
    ax.set_xlim(-0.08, 1.08)
    ax.set_ylim(-0.08, np.sqrt(3)/2 + 0.06)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('USDA Soil Texture Triangle', fontsize=20, fontweight='bold', pad=8)
    fig.tight_layout()
    # Rasterize via the canvas; buffer_rgba requires an Agg-style backend
    # (NOTE(review): assumes the default/Agg backend — confirm in deployment).
    fig.canvas.draw()
    img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    img = img[:, :, :3]  # drop the alpha channel -> RGB
    plt.close(fig)  # free the figure; this function may run per-request
    return img
| # ============================================================================ | |
| # PREDICTOR CLASS | |
| # ============================================================================ | |
def classify_from_percentages(sand: float, silt: float, clay: float) -> str:
    """
    Determine USDA texture class from Sand/Silt/Clay percentages.
    Uses official USDA classification boundaries.

    Inputs are first rescaled so they sum to 100 (zero totals are left
    as-is). Bands are checked from most to least clay; within a band the
    branch order matters because region boundaries overlap.
    """
    # Normalize to ensure sum = 100
    total = sand + silt + clay
    if total > 0:
        sand = sand / total * 100
        silt = silt / total * 100
        clay = clay / total * 100
    # --- clay >= 40: Silty Clay / Sandy Clay / Clay ---
    if clay >= 40:
        if silt >= 40:
            return 'Silty Clay'
        return 'Sandy Clay' if sand >= 45 else 'Clay'
    # --- 35 <= clay < 40 ---
    if clay >= 35:
        if sand >= 45 or silt < 20:
            return 'Sandy Clay'
        return 'Clay Loam'
    # --- 27 <= clay < 35 ---
    if clay >= 27:
        if 20 <= sand < 45:
            return 'Clay Loam'
        if 28 <= silt < 40:
            return 'Clay Loam'
        if silt >= 40:
            return 'Silty Clay Loam'
        return 'Sandy Clay Loam'
    # --- 20 <= clay < 27 ---
    if clay >= 20:
        if sand >= 45:
            return 'Sandy Clay Loam'
        if silt >= 28:  # sand < 45 is guaranteed after the branch above
            return 'Clay Loam'
        if silt >= 50:
            return 'Silty Clay Loam'
        return 'Sandy Clay Loam'
    # --- 12 <= clay < 20 ---
    # (Silt requires clay < 12, so only Silt Loam / Sandy Loam / Loam
    # are reachable in this band.)
    if clay >= 12:
        if silt >= 50:
            return 'Silt Loam'
        return 'Sandy Loam' if sand >= 52 else 'Loam'
    # --- 7 <= clay < 12 ---
    if clay >= 7:
        if silt >= 50:
            return 'Silt Loam'
        if 28 <= silt < 50 and sand < 52:
            return 'Loam'
        return 'Sandy Loam'
    # --- clay < 7 ---
    if silt >= 80:
        return 'Silt'
    if silt >= 50:
        return 'Silt Loam'
    if sand >= 85 and silt + 1.5 * clay < 15:
        return 'Sand'
    if 70 <= sand < 85:
        return 'Loamy Sand'
    if 43 <= sand < 52:
        return 'Sandy Loam' if silt < 50 else 'Silt Loam'
    if sand >= 52:
        return 'Sandy Loam'
    return 'Loam'
class SoilTexturePredictor:
    """
    Inference wrapper for soil texture prediction.

    Builds the network, optionally loads a checkpoint, and exposes
    ``predict`` / ``predict_with_visualization`` for single PIL images.
    """
    # Label order must match the order used when training the checkpoint.
    CLASSES = [
        'Sand', 'Loamy Sand', 'Sandy Loam', 'Loam', 'Silt Loam', 'Silt',
        'Sandy Clay Loam', 'Clay Loam', 'Silty Clay Loam', 'Sandy Clay', 'Silty Clay', 'Clay'
    ]

    def __init__(
        self,
        checkpoint_path: str = None,
        device: str = None,
        attention_type: str = "none",
        attention_reduction: int = 16,
        task_attention: bool = False,
    ):
        """Create the model and evaluation transform.

        Args:
            checkpoint_path: Optional path to a .pth checkpoint. Missing or
                None keeps random weights (demo/debug only).
            device: Torch device string; defaults to CUDA when available.
            attention_type: Attention variant the checkpoint was trained with.
            attention_reduction: Attention bottleneck ratio.
            task_attention: Whether per-task attention blocks were used.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        # Create model
        self.model = create_model(
            model_type='full',
            backbone='efficientnet_v2_s',
            num_classes=len(self.CLASSES),
            pretrained=False,  # weights come from the checkpoint, not ImageNet
            attention_type=attention_type,
            attention_reduction=attention_reduction,
            task_attention=task_attention,
        )
        # Load checkpoint if provided
        if checkpoint_path and Path(checkpoint_path).exists():
            print(f"Loading checkpoint: {checkpoint_path}")
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
            # Accept both full training checkpoints and bare state dicts.
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)
        else:
            print("No checkpoint provided, using random weights (for demo)")
        self.model.to(self.device)
        self.model.eval()
        # Evaluation transform: fixed resize + ImageNet normalization.
        self.transform = transforms.Compose([
            transforms.Resize((500, 500)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def predict(self, image: Image.Image) -> Dict:
        """
        Predict soil texture class and concentrations.

        Args:
            image: Input PIL image.

        Returns:
            Dict with 'class', 'confidence', 'class_probabilities'
            (name -> float), and 'sand'/'silt'/'clay' percentages.
        """
        # Preprocess
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        # Forward pass. no_grad avoids building the autograd graph at
        # inference (was missing: it wasted time and memory per request).
        with torch.no_grad():
            output = self.model(img_tensor)
        # Class prediction from classification head (for reference)
        class_probs = F.softmax(output['class_logits'], dim=1).cpu().numpy()[0]
        # Concentration prediction
        concentrations = output['concentrations'].cpu().numpy()[0]
        sand, silt, clay = concentrations
        # Re-normalize to exactly 100 (the model's softmax output already
        # sums to ~100; this guards against float drift).
        total = sand + silt + clay
        sand = sand / total * 100
        silt = silt / total * 100
        clay = clay / total * 100
        # Derive class from percentages so the reported class is always
        # consistent with the reported composition.
        class_name = classify_from_percentages(sand, silt, clay)
        # Cast to plain float for consistency with class_probabilities.
        confidence = float(class_probs[self.CLASSES.index(class_name)])
        return {
            'class': class_name,
            'confidence': confidence,
            'class_probabilities': {self.CLASSES[i]: float(p) for i, p in enumerate(class_probs)},
            'sand': sand,
            'silt': silt,
            'clay': clay
        }

    def predict_with_visualization(self, image: Image.Image) -> Tuple[str, np.ndarray, Dict]:
        """Predict and create visualization.

        Returns:
            (markdown summary, texture-triangle RGB image, raw result dict).
        """
        result = self.predict(image)
        # Sort by probability and show top 5
        sorted_probs = sorted(result['class_probabilities'].items(), key=lambda x: x[1], reverse=True)[:5]
        # Create texture triangle
        triangle_img = create_texture_triangle(
            result['sand'], result['silt'], result['clay'], result['class'],
            confidence=result['confidence'],
            top_probs=sorted_probs
        )
        text_output = format_prediction_markdown(result)
        return text_output, triangle_img, result
| # ============================================================================ | |
| # GRADIO INTERFACE | |
| # ============================================================================ | |
def create_demo(
    checkpoint_path: str = None,
    attention_type: str = "none",
    attention_reduction: int = 16,
    task_attention: bool = False,
):
    """Create Gradio demo interface.

    Builds the predictor and data-collection manager once, defines the
    callback closures over them, and wires a three-tab Blocks UI
    (Inference / Contribute Data / Dataset Management).

    Args:
        checkpoint_path: Optional model weights path; None runs with random
            weights (demo/debug only).
        attention_type: Attention variant the checkpoint was trained with.
        attention_reduction: Attention bottleneck ratio.
        task_attention: Whether the checkpoint used per-task attention.

    Returns:
        The assembled ``gr.Blocks`` demo (not yet launched).
    """
    # Initialize predictor
    predictor = SoilTexturePredictor(
        checkpoint_path=checkpoint_path,
        attention_type=attention_type,
        attention_reduction=attention_reduction,
        task_attention=task_attention,
    )
    # Community-data storage plus a background export scheduler.
    collection_manager = DataCollectionManager()
    collection_manager.ensure_storage()
    collection_manager.start_scheduler()

    def to_pil_image(image):
        """Convert possible Gradio image input to PIL."""
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        raise ValueError("Unsupported image format.")

    def predict_fn(image):
        """Gradio prediction function."""
        if image is None:
            return "Please upload an image.", None
        image = to_pil_image(image)
        # Get prediction
        text_output, triangle_img, _ = predictor.predict_with_visualization(image)
        return text_output, triangle_img

    def submit_contribution_fn(
        image,
        sand,
        silt,
        clay,
        weak_label,
        strong_label,
        sample_source,
        location,
        notes,
        consent
    ):
        """Persist user-contributed image + composition for future training."""
        if image is None:
            return "Submission failed: please upload a soil image."
        image = to_pil_image(image)
        # Manager-side validation (composition ranges, consent, image limits).
        validation = collection_manager.validate_submission(
            sand=sand,
            silt=silt,
            clay=clay,
            consent=consent,
            image=image,
        )
        if not validation.ok:
            return f"Submission failed: {validation.message}"
        # Store the model's own prediction alongside the user labels.
        prediction = predictor.predict(image)
        user_class = classify_from_percentages_simple(sand, silt, clay)
        submission_id = collection_manager.create_submission_id()
        save_result = collection_manager.save_submission(
            image=image,
            submission_id=submission_id,
            sand=sand,
            silt=silt,
            clay=clay,
            user_class=user_class,
            weak_label=weak_label,
            strong_label=strong_label,
            prediction=prediction,
            sample_source=sample_source,
            location=location,
            notes=notes,
            total=validation.total,
        )
        image_path = save_result.get("image_path", "")
        # save_submission reports duplicate detection via string flags.
        is_duplicate = save_result.get("is_duplicate", "0") == "1"
        duplicate_of_submission = save_result.get("duplicate_of_submission", "")
        export_bundles = collection_manager.maybe_trigger_exports()
        export_note = ""
        if export_bundles:
            export_note = "\n- Auto-export triggered:\n" + "\n".join([f" - `{bundle}`" for bundle in export_bundles])
        dedup_note = ""
        if is_duplicate:
            dedup_note = f"\n- Duplicate image detected. Reused existing sample from `{duplicate_of_submission}`."
        return (
            "### Submission Saved\n"
            f"- Submission ID: `{submission_id}`\n"
            f"- Stored image: `{image_path}`\n"
            f"- User label class: `{user_class}`\n"
            f"- Model prediction: `{prediction['class']}` ({prediction['confidence'] * 100:.2f}%)\n"
            f"- Weak label: `{weak_label or ''}`\n"
            f"- Strong label: `{strong_label or ''}`\n"
            "- Data was appended to `data/community_submissions/submissions.csv`.\n"
            "- Daily export uses background scheduler; high disk usage triggers immediate export."
            f"{dedup_note}"
            f"{export_note}"
        )

    def get_dataset_stats_fn():
        """Get statistics about the current dataset."""
        cfg = collection_manager.config
        num_submissions = 0
        if cfg.csv_path.exists():
            with cfg.csv_path.open("r", encoding="utf-8") as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the header row
                num_submissions = sum(1 for _ in reader)
        num_images = 0
        total_size_bytes = 0
        if cfg.images_dir.exists():
            for p in cfg.images_dir.iterdir():
                if p.is_file():
                    num_images += 1
                    total_size_bytes += p.stat().st_size
        total_size_mb = total_size_bytes / (1024 * 1024)
        return (
            f"### Dataset Statistics\n"
            f"- **Total submissions:** {num_submissions}\n"
            f"- **Total images:** {num_images}\n"
            f"- **Total image size:** {total_size_mb:.1f} MB\n"
        )

    def upload_dataset_fn(zip_file, upload_consent):
        """Process uploaded ZIP dataset with images and CSV."""
        if zip_file is None:
            return "Please upload a ZIP file."
        if not upload_consent:
            return "Please confirm consent before uploading."
        # Gradio may hand us a path string or a tempfile-like object.
        zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
        if not zipfile.is_zipfile(zip_path):
            return "Invalid ZIP file."
        # Hard limits against zip bombs / oversized uploads.
        max_entries = 10000
        max_total_size = 500 * 1024 * 1024
        results = {"added": 0, "skipped": 0, "errors": []}
        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                entries = zf.infolist()
                if len(entries) > max_entries:
                    return f"ZIP has too many entries ({len(entries)}). Max: {max_entries}."
                total_size = sum(e.file_size for e in entries)
                if total_size > max_total_size:
                    return f"ZIP too large ({total_size / 1024 / 1024:.0f} MB). Max: {max_total_size // (1024 * 1024)} MB."
                # First CSV entry (ignoring macOS "__MACOSX" metadata) drives the import.
                csv_entries = [
                    e for e in entries
                    if e.filename.endswith(".csv") and not e.filename.startswith("__")
                ]
                if not csv_entries:
                    return "No CSV found in ZIP. Expected CSV with columns: filename, sand, silt, clay."
                with zf.open(csv_entries[0]) as csv_file:
                    content = csv_file.read().decode("utf-8")
                reader = csv.DictReader(io.StringIO(content))
                headers = set(reader.fieldnames or [])
                required = {"filename", "sand", "silt", "clay"}
                if not required.issubset(headers):
                    return (
                        f"CSV must have columns: {', '.join(sorted(required))}. "
                        f"Found: {', '.join(sorted(headers))}"
                    )
                # Per-row processing: validate, locate image, predict, save.
                # Failures skip the row and are collected, never abort the batch.
                for row in reader:
                    try:
                        fname = row["filename"].strip()
                        sand = float(row["sand"])
                        silt = float(row["silt"])
                        clay = float(row["clay"])
                        vals = [sand, silt, clay]
                        if any(v < 0 or v > 100 for v in vals):
                            results["errors"].append(f"{fname}: values out of range")
                            results["skipped"] += 1
                            continue
                        total = sand + silt + clay
                        if abs(total - 100.0) > 1.0:
                            results["errors"].append(f"{fname}: sum={total:.1f}, must be ~100")
                            results["skipped"] += 1
                            continue
                        # Match by basename so nested folders in the ZIP still work.
                        matches = [e for e in entries if Path(e.filename).name == fname]
                        if not matches:
                            results["errors"].append(f"Image not found in ZIP: {fname}")
                            results["skipped"] += 1
                            continue
                        with zf.open(matches[0]) as img_bytes:
                            image = Image.open(img_bytes).convert("RGB")
                        if image.width * image.height > collection_manager.config.max_image_pixels:
                            results["errors"].append(f"{fname}: image too large")
                            results["skipped"] += 1
                            continue
                        prediction = predictor.predict(image)
                        user_class = classify_from_percentages_simple(sand, silt, clay)
                        submission_id = collection_manager.create_submission_id()
                        collection_manager.save_submission(
                            image=image,
                            submission_id=submission_id,
                            sand=sand, silt=silt, clay=clay,
                            user_class=user_class,
                            weak_label=row.get("weak_label", ""),
                            strong_label=row.get("strong_label", ""),
                            prediction=prediction,
                            sample_source=row.get("source", ""),
                            location=row.get("location", ""),
                            notes=row.get("notes", ""),
                            total=total,
                        )
                        results["added"] += 1
                    except Exception as e:
                        results["errors"].append(f"{row.get('filename', '?')}: {e}")
                        results["skipped"] += 1
        except Exception as e:
            return f"Failed to process ZIP: {e}"
        # Summarize at most 20 row errors to keep the status message short.
        error_summary = ""
        if results["errors"]:
            shown = results["errors"][:20]
            error_summary = "\n\n**Errors:**\n" + "\n".join(f"- {e}" for e in shown)
            if len(results["errors"]) > 20:
                error_summary += f"\n- ... and {len(results['errors']) - 20} more"
        return (
            f"### Upload Complete\n"
            f"- **Added:** {results['added']} submissions\n"
            f"- **Skipped:** {results['skipped']}\n"
            f"{error_summary}"
        )

    # Create interface
    with gr.Blocks(title="Soil Texture Classifier") as demo:
        gr.Markdown("""
        # Soil Texture Classification
        1. Use **Inference** to predict texture class and composition from image.
        2. Use **Contribute Data** to upload image + measured Sand/Silt/Clay for future training.
        3. Use **Dataset Management** to bulk-upload a ZIP dataset for model improvement.
        """)
        with gr.Tabs():
            with gr.Tab("Inference"):
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(label="Upload Soil Image", type="pil")
                        predict_btn = gr.Button("Analyze", variant="primary")
                        gr.Markdown("""
                        **Tips:**
                        - Use close-up images of soil surface
                        - Ensure good lighting
                        - Avoid shadows and reflections
                        """)
                    with gr.Column():
                        output_text = gr.Markdown(label="Results")
                        output_triangle = gr.Image(label="USDA Texture Triangle")
            with gr.Tab("Contribute Data"):
                gr.Markdown("""
                Upload a soil image with measured Sand/Silt/Clay percentages.
                This data will be stored for manual quality checks and future retraining.
                You can optionally submit weak/strong labels for better curation quality.
                """)
                with gr.Row():
                    with gr.Column():
                        contribution_image = gr.Image(label="Soil Image for Contribution", type="pil")
                        weak_label = gr.Dropdown(
                            choices=[""] + SoilTexturePredictor.CLASSES,
                            value="",
                            allow_custom_value=True,
                            label="Weak Label (Optional)"
                        )
                        strong_label = gr.Dropdown(
                            choices=[""] + SoilTexturePredictor.CLASSES,
                            value="",
                            allow_custom_value=True,
                            label="Strong Label (Optional)"
                        )
                        sample_source = gr.Textbox(
                            label="Sample Source",
                            placeholder="e.g., field site, experiment ID, sample batch"
                        )
                        location = gr.Textbox(
                            label="Location (Optional)",
                            placeholder="e.g., Iowa, USA"
                        )
                        notes = gr.Textbox(
                            label="Notes (Optional)",
                            lines=4,
                            placeholder="Any observation, sampling method, moisture condition, etc."
                        )
                    with gr.Column():
                        sand_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Sand (%)")
                        silt_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Silt (%)")
                        clay_input = gr.Slider(0, 100, value=33.4, step=0.1, label="Clay (%)")
                        consent = gr.Checkbox(
                            label="I confirm this image and labels can be used for model improvement.",
                            value=False
                        )
                        submit_btn = gr.Button("Submit Contribution", variant="primary")
                        contribution_status = gr.Markdown(label="Submission Status")
            with gr.Tab("Dataset Management"):
                gr.Markdown("""
                **Upload** a dataset (ZIP) to contribute bulk data for model improvement.
                **Upload format:** ZIP containing a CSV file and image files.
                CSV columns: `filename`, `sand`, `silt`, `clay` (required).
                Optional: `weak_label`, `strong_label`, `source`, `location`, `notes`.
                """)
                with gr.Row():
                    with gr.Column():
                        upload_file = gr.File(label="ZIP Dataset", file_types=[".zip"])
                        upload_consent = gr.Checkbox(
                            label="I confirm these images and labels can be used for model improvement.",
                            value=False,
                        )
                        upload_btn = gr.Button("Upload Dataset", variant="primary")
                        upload_status = gr.Markdown(label="Upload Status")
                    with gr.Column():
                        stats_btn = gr.Button("Refresh Statistics")
                        stats_display = gr.Markdown(label="Statistics")
        # Event handlers
        predict_btn.click(
            fn=predict_fn,
            inputs=input_image,
            outputs=[output_text, output_triangle]
        )
        # Also run prediction automatically when a new image is selected.
        input_image.change(
            fn=predict_fn,
            inputs=input_image,
            outputs=[output_text, output_triangle]
        )
        submit_btn.click(
            fn=submit_contribution_fn,
            inputs=[
                contribution_image,
                sand_input,
                silt_input,
                clay_input,
                weak_label,
                strong_label,
                sample_source,
                location,
                notes,
                consent,
            ],
            outputs=[contribution_status]
        )
        upload_btn.click(
            fn=upload_dataset_fn,
            inputs=[upload_file, upload_consent],
            outputs=[upload_status],
        )
        stats_btn.click(
            fn=get_dataset_stats_fn,
            inputs=[],
            outputs=[stats_display],
        )
    return demo
| # ============================================================================ | |
| # MAIN | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Command-line entry point: parse options, resolve checkpoint, serve.
    cli = argparse.ArgumentParser(description="Soil texture inference and contribution app")
    cli.add_argument("--checkpoint", type=str, default="finetuned_best.pth",
                     help="Path to model checkpoint")
    cli.add_argument("--server_name", type=str, default="0.0.0.0",
                     help="Gradio server host")
    cli.add_argument("--server_port", type=int, default=7860,
                     help="Gradio server port")
    cli.add_argument("--share", action="store_true",
                     help="Create a public share link")
    cli.add_argument("--attention_type", type=str, default="none", choices=["none", "se", "cbam"],
                     help="Attention block used by inference model")
    cli.add_argument("--attention_reduction", type=int, default=16,
                     help="Attention reduction ratio")
    cli.add_argument("--task_attention", action="store_true",
                     help="Enable task-specific attention blocks")
    cli.add_argument("--allow_random_weights", action="store_true",
                     help="Allow launching without checkpoint (debug only)")
    opts = cli.parse_args()

    # Refuse to serve random weights unless explicitly allowed for debugging.
    checkpoint_path = opts.checkpoint
    if not Path(checkpoint_path).exists():
        if not opts.allow_random_weights:
            raise FileNotFoundError(
                f"Checkpoint not found at {checkpoint_path}. "
                "Pass --allow_random_weights only for debugging."
            )
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        print("Running with random weights for debug purposes.")
        checkpoint_path = None

    # Create and launch demo
    app = create_demo(
        checkpoint_path=checkpoint_path,
        attention_type=opts.attention_type,
        attention_reduction=opts.attention_reduction,
        task_attention=opts.task_attention,
    )
    app.launch(
        server_name=opts.server_name,
        server_port=opts.server_port,
        share=opts.share
    )