"""Soil texture inference and data-contribution Gradio app.

Contains the embedded multi-task model (texture class + Sand/Silt/Clay
regression), a USDA texture-triangle renderer, and the Gradio UI for
inference, single-sample contribution and bulk ZIP upload.
"""
import argparse
import csv
import io
import math
import os
import zipfile
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import timm
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

try:
    from src.data_collection import DataCollectionManager, classify_from_percentages_simple
except ImportError:
    import sys
    sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
    sys.path.insert(0, str(Path(__file__).resolve().parent))
    from data_collection import DataCollectionManager, classify_from_percentages_simple


# ============================================================================
# MODEL ARCHITECTURE (Embedded)
# ============================================================================

class IdentityAttention(nn.Module):
    """No-op attention block."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class SEFeatureAttention(nn.Module):
    """Squeeze-and-Excitation style attention for vector features.

    Gates every feature with a sigmoid weight produced by a two-layer
    bottleneck MLP (assumes x is (batch, feature_dim) pooled features).
    """

    def __init__(self, feature_dim: int, reduction: int = 16):
        super().__init__()
        hidden_dim = max(8, feature_dim // reduction)
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, feature_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.fc(x)


class CBAMFeatureAttention(nn.Module):
    """CBAM-inspired attention for vector features.

    The "average" descriptor is the MLP applied to x itself; the "max"
    descriptor is the MLP applied to each sample's maximum feature value
    broadcast over all features. Their sum is passed through a sigmoid gate.
    The computation is kept exactly as-is so trained checkpoints stay valid.
    """

    def __init__(self, feature_dim: int, reduction: int = 16):
        super().__init__()
        hidden_dim = max(8, feature_dim // reduction)
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, feature_dim),
        )
        self.gate = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_desc = self.mlp(x)
        max_pool = x.max(dim=1, keepdim=True).values.expand_as(x)
        max_desc = self.mlp(max_pool)
        return x * self.gate(avg_desc + max_desc)


def build_attention_block(attention_type: str, feature_dim: int, reduction: int = 16) -> nn.Module:
    """Return the attention module named by ``attention_type``.

    Args:
        attention_type: ``"none"``, ``"se"`` or ``"cbam"`` (case-insensitive;
            ``None``/empty means ``"none"``).
        feature_dim: Width of the feature vectors the block operates on.
        reduction: Bottleneck reduction ratio for SE/CBAM.

    Raises:
        ValueError: If ``attention_type`` is not recognised.
    """
    key = (attention_type or "none").lower()
    if key == "none":
        return IdentityAttention()
    if key == "se":
        return SEFeatureAttention(feature_dim=feature_dim, reduction=reduction)
    if key == "cbam":
        return CBAMFeatureAttention(feature_dim=feature_dim, reduction=reduction)
    raise ValueError(f"Unknown attention type: {attention_type}")


class SoilTextureModel(nn.Module):
    """
    Multi-task model for soil texture analysis.

    Architecture:
        Image -> Backbone -> Shared Features -> Classification Head -> Texture Class
                                             -> Regression Head -> [Sand%, Silt%, Clay%]
    """

    BACKBONE_CONFIGS = {
        'efficientnet_v2_s': {'feature_dim': 1280, 'pretrained': 'tf_efficientnetv2_s'},
        'convnext_tiny': {'feature_dim': 768, 'pretrained': 'convnext_tiny'},
        'mobilevit_s': {'feature_dim': 640, 'pretrained': 'mobilevit_s'},
        'swin_tiny': {'feature_dim': 768, 'pretrained': 'swin_tiny_patch4_window7_224'},
        'resnet50': {'feature_dim': 2048, 'pretrained': 'resnet50'},
    }

    def __init__(
        self,
        backbone_name: str = 'efficientnet_v2_s',
        num_classes: int = 12,
        dropout: float = 0.3,
        pretrained: bool = True,
        freeze_backbone: bool = False,
        attention_type: str = "none",
        attention_reduction: int = 16,
        task_attention: bool = False,
    ):
        """
        Args:
            backbone_name: Key of ``BACKBONE_CONFIGS``; unknown names fall back
                to the efficientnet_v2_s configuration.
            num_classes: Number of texture classes for the classification head.
            dropout: Dropout rate of the first head layer (half of it is used
                for the second dropout layer).
            pretrained: Load timm pretrained backbone weights.
            freeze_backbone: Disable gradients for all backbone parameters.
            attention_type: Attention block applied to shared features.
            attention_reduction: Reduction ratio for the attention blocks.
            task_attention: Add separate attention blocks before each head.
        """
        super().__init__()

        self.backbone_name = backbone_name
        self.num_classes = num_classes

        # Get backbone configuration (silently falls back to the default).
        config = self.BACKBONE_CONFIGS.get(backbone_name, self.BACKBONE_CONFIGS['efficientnet_v2_s'])
        feature_dim = config['feature_dim']

        # Load backbone; num_classes=0 removes timm's classifier so the output
        # is the globally average-pooled feature vector.
        self.backbone = timm.create_model(
            config['pretrained'],
            pretrained=pretrained,
            num_classes=0,  # Remove classifier head
            global_pool='avg'
        )

        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        self.shared_attention = build_attention_block(
            attention_type=attention_type,
            feature_dim=feature_dim,
            reduction=attention_reduction,
        )
        if task_attention:
            self.class_attention = build_attention_block(
                attention_type=attention_type,
                feature_dim=feature_dim,
                reduction=attention_reduction,
            )
            self.reg_attention = build_attention_block(
                attention_type=attention_type,
                feature_dim=feature_dim,
                reduction=attention_reduction,
            )
        else:
            self.class_attention = IdentityAttention()
            self.reg_attention = IdentityAttention()

        # Classification head (texture type)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes)
        )

        # Regression head (Sand, Silt, Clay percentages)
        self.regressor = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 3)  # Sand, Silt, Clay
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialise attention and head layers; the backbone is left untouched."""
        for m in [
            self.shared_attention,
            self.class_attention,
            self.reg_attention,
            self.classifier,
            self.regressor,
        ]:
            for layer in m.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)
                elif isinstance(layer, nn.BatchNorm1d):
                    nn.init.constant_(layer.weight, 1)
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
        """Forward pass.

        Returns:
            Dict with ``'class_logits'`` (batch, num_classes) and
            ``'concentrations'`` (batch, 3) percentages summing to 100; when
            ``return_features`` is set, also ``'features'`` (shared features
            after the shared attention block).
        """
        # Extract features
        features = self.backbone(x)
        features = self.shared_attention(features)
        cls_features = self.class_attention(features)
        reg_features = self.reg_attention(features)

        # Classification
        class_logits = self.classifier(cls_features)

        # Regression (with softmax to ensure sum = 100)
        reg_output = self.regressor(reg_features)
        concentrations = F.softmax(reg_output, dim=1) * 100  # Scale to percentages

        result = {
            'class_logits': class_logits,
            'concentrations': concentrations
        }

        if return_features:
            result['features'] = features

        return result


def create_model(
    model_type: str = 'full',
    backbone: str = 'efficientnet_v2_s',
    num_classes: int = 12,
    pretrained: bool = True,
    attention_type: str = "none",
    attention_reduction: int = 16,
    task_attention: bool = False,
) -> nn.Module:
    """Factory function to create model.

    ``model_type`` is accepted for API compatibility but currently ignored;
    a ``SoilTextureModel`` is always returned.
    """
    model = SoilTextureModel(
        backbone_name=backbone,
        num_classes=num_classes,
        pretrained=pretrained,
        attention_type=attention_type,
        attention_reduction=attention_reduction,
        task_attention=task_attention,
    )
    return model


def format_prediction_markdown(result: Dict) -> str:
    """Create markdown output for inference results.

    Args:
        result: Dict as returned by ``SoilTexturePredictor.predict``.

    Returns:
        Markdown summary with the class, confidence, composition and the five
        most probable classes from the classification head.
    """
    sorted_probs = sorted(result["class_probabilities"].items(), key=lambda x: x[1], reverse=True)
    lines = [
        "### Prediction Result",
        f"- **Texture Class:** `{result['class']}`",
        f"- **Confidence:** `{result['confidence'] * 100:.2f}%`",
        f"- **Sand / Silt / Clay:** `{result['sand']:.2f}% / {result['silt']:.2f}% / {result['clay']:.2f}%`",
        "",
        "**Top Probabilities**",
    ]
    for class_name, prob in sorted_probs[:5]:
        lines.append(f"- {class_name}: {prob * 100:.2f}%")
    return "\n".join(lines)


# ============================================================================
# SOIL TEXTURE TRIANGLE VISUALIZATION
# ============================================================================

def create_texture_triangle(sand: float, silt: float, clay: float,
                            predicted_class: str,
                            confidence: Optional[float] = None,
                            top_probs: Optional[list] = None) -> np.ndarray:
    """
    Create USDA Soil Texture Triangle visualization with correct boundaries.

    Args:
        sand, silt, clay: Composition (percent) of the point to mark.
        predicted_class: Label written next to the marked point.
        confidence: Optional probability (0-1) shown in the info box.
        top_probs: Optional list of ``(class_name, probability)`` pairs; the
            first five are listed. The info boxes are only drawn when both
            ``confidence`` and ``top_probs`` are given.

    Returns:
        The rendered figure as a ``uint8`` RGB array of shape (height, width, 3).
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 12), facecolor='white', dpi=150)

    # Map percentages to 2-D coordinates: 100% sand is the bottom-left corner,
    # 100% silt the bottom-right corner and 100% clay the apex.
    def soil_to_coords(sand_pct, silt_pct, clay_pct):
        x = silt_pct / 100 + clay_pct / 200
        y = clay_pct / 100 * np.sqrt(3) / 2
        return x, y

    # USDA Soil Texture Triangle regions as (sand, silt, clay) vertices.
    regions = [
        ('Sand', [(100, 0, 0), (85, 15, 0), (90, 0, 10)], '#FFE4B5'),
        ('Loamy Sand', [(85, 15, 0), (70, 30, 0), (85, 0, 15), (90, 0, 10)], '#FFDAB9'),
        ('Sandy Loam', [(70, 30, 0), (50, 50, 0), (42.5, 50, 7.5), (52.5, 40, 7.5),
                        (52.5, 27.5, 20), (80, 0, 20), (85, 0, 15)], '#F4A460'),
        ('Loam', [(42.5, 50, 7.5), (22.5, 50, 27.5), (45, 27.5, 27.5),
                  (52.5, 27.5, 20), (52.5, 40, 7.5)], '#DEB887'),
        ('Silt Loam', [(50, 50, 0), (20, 80, 0), (7.5, 80, 12.5), (0, 87.5, 12.5),
                       (0, 72.5, 27.5), (22.5, 50, 27.5)], '#D2B48C'),
        ('Silt', [(20, 80, 0), (0, 100, 0), (0, 87.5, 12.5), (7.5, 80, 12.5)], '#C0C0C0'),
        ('Sandy Clay Loam', [(80, 0, 20), (52.5, 27.5, 20), (45, 27.5, 27.5),
                             (45, 20, 35), (65, 0, 35)], '#CD853F'),
        ('Clay Loam', [(45, 27.5, 27.5), (20, 52.5, 27.5), (20, 40, 40), (45, 15, 40)], '#D2691E'),
        ('Silty Clay Loam', [(0, 72.5, 27.5), (0, 60, 40), (20, 40, 40), (20, 52.5, 27.5)], '#B8860B'),
        ('Sandy Clay', [(65, 0, 35), (45, 20, 35), (45, 0, 55)], '#A0522D'),
        ('Silty Clay', [(20, 40, 40), (0, 60, 40), (0, 40, 60)], '#8B4513'),
        ('Clay', [(45, 15, 40), (20, 40, 40), (0, 40, 60), (0, 0, 100), (45, 0, 55)], '#654321'),
    ]

    # Draw colored regions with border lines
    for name, vertices_pct, color in regions:
        vertices_xy = [soil_to_coords(s, si, c) for s, si, c in vertices_pct]
        region_patch = Polygon(vertices_xy, facecolor=color, edgecolor='#333',
                               linewidth=1.2, alpha=0.8, zorder=1)
        ax.add_patch(region_patch)

        # Label at the mean of the polygon's vertices
        center_x = np.mean([v[0] for v in vertices_xy])
        center_y = np.mean([v[1] for v in vertices_xy])
        ax.text(center_x, center_y, name, fontsize=12, ha='center', va='center',
                weight='bold', zorder=2)

    # Draw triangle outline
    triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3) / 2]])
    tri_patch = Polygon(triangle, fill=False, edgecolor='black', linewidth=4, zorder=3)
    ax.add_patch(tri_patch)

    # Add corner labels
    ax.text(0, -0.05, '100% Sand', fontsize=16, ha='center', weight='bold')
    ax.text(1, -0.05, '100% Silt', fontsize=16, ha='center', weight='bold')
    ax.text(0.5, np.sqrt(3) / 2 + 0.03, '100% Clay', fontsize=16, ha='center', weight='bold')

    # Grid lines every 5%: constant clay (horizontal), constant sand and
    # constant silt. Every 10% line is drawn stronger and labelled.
    for pct in range(5, 100, 5):
        y = pct / 100 * np.sqrt(3) / 2
        x_left = pct / 200
        x_right = 1 - pct / 200

        sand_pct = pct
        p1 = soil_to_coords(sand_pct, 0, 100 - sand_pct)
        p2 = soil_to_coords(sand_pct, 100 - sand_pct, 0)

        silt_pct = pct
        p3 = soil_to_coords(0, silt_pct, 100 - silt_pct)
        p4 = soil_to_coords(100 - silt_pct, silt_pct, 0)

        is_major = pct % 10 == 0
        alpha, lw = (0.3, 1.0) if is_major else (0.15, 0.6)
        ax.plot([x_left, x_right], [y, y], 'k-', alpha=alpha, linewidth=lw, zorder=0)
        ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=alpha, linewidth=lw, zorder=0)
        ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=alpha, linewidth=lw, zorder=0)
        if is_major:
            ax.text(x_left - 0.03, y, f'{pct}', fontsize=11, alpha=0.7, weight='bold')

    # Plot prediction point
    pred_x, pred_y = soil_to_coords(sand, silt, clay)
    ax.plot(pred_x, pred_y, 'o', markersize=22, markerfacecolor='red',
            markeredgecolor='darkred', markeredgewidth=3.5, zorder=5)

    # Annotation placed to the right, or to the left near the silt corner.
    offset_x = 0.15 if pred_x < 0.7 else -0.15
    offset_y = 0.08
    ax.annotate(f'{predicted_class}\n({sand:.0f}%, {silt:.0f}%, {clay:.0f}%)',
                xy=(pred_x, pred_y),
                xytext=(pred_x + offset_x, pred_y + offset_y),
                fontsize=14, fontweight='bold',
                arrowprops=dict(arrowstyle='->', lw=2.5, color='darkred'),
                bbox=dict(boxstyle='round,pad=0.6', facecolor='white',
                          edgecolor='darkred', lw=2.5),
                ha='center', zorder=6)

    # Add prediction information boxes
    if confidence is not None and top_probs is not None:
        # Left box - Prediction and Composition
        left_text = f"Predicted Class:\n{predicted_class}\n\n"
        left_text += f"Confidence: {confidence*100:.1f}%\n\n"
        left_text += "Composition:\n"
        left_text += f"Sand: {sand:.1f}%\n"
        left_text += f"Silt: {silt:.1f}%\n"
        left_text += f"Clay: {clay:.1f}%"

        ax.text(0.05, 0.82, left_text, fontsize=16, verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
                          edgecolor='black', linewidth=2.5, alpha=0.95),
                zorder=7, family='monospace', weight='bold')

        # Right box - Top 5 Probabilities
        right_text = "Top 5 Probabilities:\n\n"
        for i, (cls, prob) in enumerate(top_probs[:5], 1):
            right_text += f"{i}. {cls}: {prob*100:.1f}%\n"

        ax.text(0.75, 0.82, right_text, fontsize=16, verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
                          edgecolor='black', linewidth=2.5, alpha=0.95),
                zorder=7, family='monospace', weight='bold')

    ax.set_xlim(-0.08, 1.08)
    ax.set_ylim(-0.08, np.sqrt(3) / 2 + 0.06)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('USDA Soil Texture Triangle', fontsize=20, fontweight='bold', pad=8)

    fig.tight_layout()
    fig.canvas.draw()
    # buffer_rgba() already carries its (height, width, 4) shape in physical
    # pixels; reshaping with get_width_height() (logical pixels) breaks when
    # the canvas has a device pixel ratio != 1. Copy before closing the figure
    # so the array does not alias the renderer's buffer.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    img = rgba[:, :, :3].copy()
    plt.close(fig)
    return img


# ============================================================================
# PREDICTOR CLASS
# ============================================================================

def classify_from_percentages(sand: float, silt: float, clay: float) -> str:
    """
    Determine USDA texture class from Sand/Silt/Clay percentages.

    Implements the USDA-NRCS textural class criteria. The inputs are first
    rescaled to sum to 100 (when their sum is positive) so slightly
    inconsistent compositions still classify.

    Args:
        sand: Sand fraction in percent.
        silt: Silt fraction in percent.
        clay: Clay fraction in percent.

    Returns:
        One of the 12 USDA class names (same spelling as
        ``SoilTexturePredictor.CLASSES``). Inputs that match no rule (only
        possible for NaN/negative values or float edge cases) return 'Loam'.
    """
    # Normalize to ensure sum = 100
    total = sand + silt + clay
    if total > 0:
        sand = sand / total * 100
        silt = silt / total * 100
        clay = clay / total * 100

    # Rules are evaluated in order. Conditions implied by an earlier failed
    # rule are omitted (e.g. Loamy Sand implies silt + 1.5*clay >= 15, and
    # Sandy Loam implies silt + 2*clay >= 30).
    if silt + 1.5 * clay < 15:
        return 'Sand'
    if silt + 2 * clay < 30:
        return 'Loamy Sand'
    if (7 <= clay < 20 and sand > 52) or (clay < 7 and silt < 50):
        return 'Sandy Loam'
    if 7 <= clay < 27 and 28 <= silt < 50 and sand <= 52:
        return 'Loam'
    if (silt >= 50 and 12 <= clay < 27) or (50 <= silt < 80 and clay < 12):
        return 'Silt Loam'
    if silt >= 80 and clay < 12:
        return 'Silt'
    if 20 <= clay < 35 and silt < 28 and sand > 45:
        return 'Sandy Clay Loam'
    if 27 <= clay < 40 and 20 < sand <= 45:
        return 'Clay Loam'
    if 27 <= clay < 40 and sand <= 20:
        return 'Silty Clay Loam'
    if clay >= 35 and sand > 45:
        return 'Sandy Clay'
    if clay >= 40 and silt >= 40:
        return 'Silty Clay'
    if clay >= 40:
        # sand <= 45 and silt < 40 are implied by the two rules above.
        return 'Clay'
    return 'Loam'


class SoilTexturePredictor:
    """
    Inference wrapper for soil texture prediction.
    """

    CLASSES = [
        'Sand', 'Loamy Sand', 'Sandy Loam', 'Loam', 'Silt Loam', 'Silt',
        'Sandy Clay Loam', 'Clay Loam', 'Silty Clay Loam', 'Sandy Clay',
        'Silty Clay', 'Clay'
    ]

    def __init__(
        self,
        checkpoint_path: str = None,
        device: str = None,
        attention_type: str = "none",
        attention_reduction: int = 16,
        task_attention: bool = False,
    ):
        """
        Build the model and optionally load trained weights.

        Args:
            checkpoint_path: Checkpoint saved with ``torch.save``: either a raw
                state dict or a dict containing ``'model_state_dict'``. When
                ``None`` or the file is missing, the model keeps random weights.
            device: Torch device string; defaults to CUDA when available.
            attention_type, attention_reduction, task_attention: Forwarded to
                ``create_model``; must match the checkpoint's architecture.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Create model
        self.model = create_model(
            model_type='full',
            backbone='efficientnet_v2_s',
            num_classes=len(self.CLASSES),
            pretrained=False,
            attention_type=attention_type,
            attention_reduction=attention_reduction,
            task_attention=task_attention,
        )

        # Load checkpoint if provided
        if checkpoint_path and Path(checkpoint_path).exists():
            print(f"Loading checkpoint: {checkpoint_path}")
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)
        elif checkpoint_path:
            print(f"Checkpoint not found at {checkpoint_path}, using random weights (for demo)")
        else:
            print("No checkpoint provided, using random weights (for demo)")

        self.model.to(self.device)
        self.model.eval()

        # Resize + standard ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((500, 500)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @torch.no_grad()
    def predict(self, image: Image.Image) -> Dict:
        """
        Predict soil texture class and concentrations.

        The class is derived from the regressed Sand/Silt/Clay percentages so
        class and composition always agree; ``confidence`` is the
        classification head's probability for that derived class.

        Args:
            image: PIL image; non-RGB modes are converted to RGB.

        Returns:
            Dict with ``'class'``, ``'confidence'``, ``'class_probabilities'``
            (class name -> probability), ``'sand'``, ``'silt'`` and ``'clay'``.
        """
        # The 3-channel normalization fails on RGBA/L/P inputs.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Preprocess
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Forward pass
        output = self.model(img_tensor)

        # Class probabilities from the classification head
        class_probs = F.softmax(output['class_logits'], dim=1).cpu().numpy()[0]

        # Concentration prediction
        concentrations = output['concentrations'].cpu().numpy()[0]
        sand, silt, clay = concentrations

        # Renormalize to remove float drift from the softmax scaling.
        total = sand + silt + clay
        sand = sand / total * 100
        silt = silt / total * 100
        clay = clay / total * 100

        # Derive class from percentages to ensure consistency
        class_name = classify_from_percentages(sand, silt, clay)
        confidence = class_probs[self.CLASSES.index(class_name)]

        return {
            'class': class_name,
            'confidence': confidence,
            'class_probabilities': {self.CLASSES[i]: float(p) for i, p in enumerate(class_probs)},
            'sand': sand,
            'silt': silt,
            'clay': clay
        }

    def predict_with_visualization(self, image: Image.Image) -> Tuple[str, np.ndarray, Dict]:
        """Predict and create visualization.

        Returns:
            ``(markdown_summary, triangle_rgb_image, raw_prediction_dict)``.
        """
        result = self.predict(image)

        # Sort by probability and show top 5
        sorted_probs = sorted(result['class_probabilities'].items(),
                              key=lambda x: x[1], reverse=True)[:5]

        # Create texture triangle
        triangle_img = create_texture_triangle(
            result['sand'], result['silt'], result['clay'],
            result['class'],
            confidence=result['confidence'],
            top_probs=sorted_probs
        )

        text_output = format_prediction_markdown(result)

        return text_output, triangle_img, result


# ============================================================================
# GRADIO INTERFACE
# ============================================================================

def create_demo(
    checkpoint_path: str = None,
    attention_type: str = "none",
    attention_reduction: int = 16,
    task_attention: bool = False,
):
    """Create Gradio demo interface.

    Builds the predictor and a ``DataCollectionManager`` (creating storage and
    starting its scheduler), then wires the Inference, Contribute Data and
    Dataset Management tabs.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """

    # Initialize predictor
    predictor = SoilTexturePredictor(
        checkpoint_path=checkpoint_path,
        attention_type=attention_type,
        attention_reduction=attention_reduction,
        task_attention=task_attention,
    )
    collection_manager = DataCollectionManager()
    collection_manager.ensure_storage()
    collection_manager.start_scheduler()

    def to_pil_image(image):
        """Convert possible Gradio image input to PIL."""
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        raise ValueError("Unsupported image format.")

    def predict_fn(image):
        """Gradio prediction function."""
        if image is None:
            return "Please upload an image.", None

        image = to_pil_image(image)

        # Get prediction
        text_output, triangle_img, _ = predictor.predict_with_visualization(image)

        return text_output, triangle_img

    def submit_contribution_fn(
        image,
        sand,
        silt,
        clay,
        weak_label,
        strong_label,
        sample_source,
        location,
        notes,
        consent
    ):
        """Persist user-contributed image + composition for future training.

        Returns:
            Markdown status message (success summary or failure reason).
        """
        if image is None:
            return "Submission failed: please upload a soil image."

        image = to_pil_image(image)
        validation = collection_manager.validate_submission(
            sand=sand,
            silt=silt,
            clay=clay,
            consent=consent,
            image=image,
        )
        if not validation.ok:
            return f"Submission failed: {validation.message}"

        prediction = predictor.predict(image)
        user_class = classify_from_percentages_simple(sand, silt, clay)
        submission_id = collection_manager.create_submission_id()
        save_result = collection_manager.save_submission(
            image=image,
            submission_id=submission_id,
            sand=sand,
            silt=silt,
            clay=clay,
            user_class=user_class,
            weak_label=weak_label,
            strong_label=strong_label,
            prediction=prediction,
            sample_source=sample_source,
            location=location,
            notes=notes,
            total=validation.total,
        )
        image_path = save_result.get("image_path", "")
        is_duplicate = save_result.get("is_duplicate", "0") == "1"
        duplicate_of_submission = save_result.get("duplicate_of_submission", "")
        export_bundles = collection_manager.maybe_trigger_exports()

        export_note = ""
        if export_bundles:
            export_note = "\n- Auto-export triggered:\n" + "\n".join([f" - `{bundle}`" for bundle in export_bundles])
        dedup_note = ""
        if is_duplicate:
            dedup_note = f"\n- Duplicate image detected. Reused existing sample from `{duplicate_of_submission}`."

        return (
            "### Submission Saved\n"
            f"- Submission ID: `{submission_id}`\n"
            f"- Stored image: `{image_path}`\n"
            f"- User label class: `{user_class}`\n"
            f"- Model prediction: `{prediction['class']}` ({prediction['confidence'] * 100:.2f}%)\n"
            f"- Weak label: `{weak_label or ''}`\n"
            f"- Strong label: `{strong_label or ''}`\n"
            "- Data was appended to `data/community_submissions/submissions.csv`.\n"
            "- Daily export uses background scheduler; high disk usage triggers immediate export."
            f"{dedup_note}"
            f"{export_note}"
        )

    def get_dataset_stats_fn():
        """Get statistics about the current dataset.

        Counts CSV data rows (header excluded) and regular files in the image
        directory of ``collection_manager.config``.
        """
        cfg = collection_manager.config
        num_submissions = 0
        if cfg.csv_path.exists():
            # newline="" lets the csv module handle line breaks inside quoted fields.
            with cfg.csv_path.open("r", encoding="utf-8", newline="") as f:
                reader = csv.reader(f)
                next(reader, None)
                num_submissions = sum(1 for _ in reader)

        num_images = 0
        total_size_bytes = 0
        if cfg.images_dir.exists():
            for p in cfg.images_dir.iterdir():
                if p.is_file():
                    num_images += 1
                    total_size_bytes += p.stat().st_size

        total_size_mb = total_size_bytes / (1024 * 1024)
        return (
            f"### Dataset Statistics\n"
            f"- **Total submissions:** {num_submissions}\n"
            f"- **Total images:** {num_images}\n"
            f"- **Total image size:** {total_size_mb:.1f} MB\n"
        )

    def upload_dataset_fn(zip_file, upload_consent):
        """Process uploaded ZIP dataset with images and CSV.

        The first CSV in the archive (ignoring ``__``-prefixed paths such as
        ``__MACOSX``) must have ``filename``, ``sand``, ``silt`` and ``clay``
        columns. Each valid row is stored via
        ``collection_manager.save_submission``; invalid rows are skipped and
        reported.

        Returns:
            Markdown summary of added/skipped rows, or an error message.
        """
        if zip_file is None:
            return "Please upload a ZIP file."
        if not upload_consent:
            return "Please confirm consent before uploading."

        # Gradio may pass either a path string or a tempfile-like object.
        zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
        if not zipfile.is_zipfile(zip_path):
            return "Invalid ZIP file."

        max_entries = 10000
        max_total_size = 500 * 1024 * 1024
        results = {"added": 0, "skipped": 0, "errors": []}

        def skip(message):
            """Record a skipped row together with its reason."""
            results["errors"].append(message)
            results["skipped"] += 1

        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                entries = zf.infolist()
                if len(entries) > max_entries:
                    return f"ZIP has too many entries ({len(entries)}). Max: {max_entries}."

                # Declared uncompressed sizes from the archive directory.
                total_size = sum(e.file_size for e in entries)
                if total_size > max_total_size:
                    return f"ZIP too large ({total_size / 1024 / 1024:.0f} MB). Max: {max_total_size // (1024 * 1024)} MB."

                csv_entries = [
                    e for e in entries
                    if e.filename.lower().endswith(".csv") and not e.filename.startswith("__")
                ]
                if not csv_entries:
                    return "No CSV found in ZIP. Expected CSV with columns: filename, sand, silt, clay."

                with zf.open(csv_entries[0]) as csv_file:
                    # utf-8-sig strips the BOM that Excel prepends to UTF-8
                    # exports; otherwise the first header reads "\ufefffilename".
                    content = csv_file.read().decode("utf-8-sig")

                reader = csv.DictReader(io.StringIO(content))
                headers = set(reader.fieldnames or [])
                required = {"filename", "sand", "silt", "clay"}
                if not required.issubset(headers):
                    return (
                        f"CSV must have columns: {', '.join(sorted(required))}. "
                        f"Found: {', '.join(sorted(headers))}"
                    )

                # Index file entries once (instead of scanning the archive for
                # every row). Rows may reference the full archive path or just
                # the basename; for duplicate basenames the first entry wins.
                entries_by_path = {}
                entries_by_name = {}
                for entry in entries:
                    if entry.is_dir():
                        continue
                    entries_by_path.setdefault(entry.filename, entry)
                    entries_by_name.setdefault(Path(entry.filename).name, entry)

                for row in reader:
                    try:
                        fname = row["filename"].strip()
                        sand = float(row["sand"])
                        silt = float(row["silt"])
                        clay = float(row["clay"])

                        vals = [sand, silt, clay]
                        # float() accepts "nan"/"inf"; NaN fails every comparison,
                        # so non-finite values must be rejected explicitly.
                        if not all(math.isfinite(v) for v in vals) or any(v < 0 or v > 100 for v in vals):
                            skip(f"{fname}: values out of range")
                            continue

                        total = sand + silt + clay
                        if abs(total - 100.0) > 1.0:
                            skip(f"{fname}: sum={total:.1f}, must be ~100")
                            continue

                        match = entries_by_path.get(fname)
                        if match is None:
                            match = entries_by_name.get(fname)
                        if match is None:
                            skip(f"Image not found in ZIP: {fname}")
                            continue

                        with zf.open(match) as img_bytes:
                            # Image.open is lazy (header only), so the size check
                            # runs before any pixel data is decoded.
                            raw_image = Image.open(img_bytes)
                            if raw_image.width * raw_image.height > collection_manager.config.max_image_pixels:
                                skip(f"{fname}: image too large")
                                continue
                            image = raw_image.convert("RGB")

                        prediction = predictor.predict(image)
                        user_class = classify_from_percentages_simple(sand, silt, clay)
                        submission_id = collection_manager.create_submission_id()
                        collection_manager.save_submission(
                            image=image,
                            submission_id=submission_id,
                            sand=sand,
                            silt=silt,
                            clay=clay,
                            user_class=user_class,
                            weak_label=row.get("weak_label", ""),
                            strong_label=row.get("strong_label", ""),
                            prediction=prediction,
                            sample_source=row.get("source", ""),
                            location=row.get("location", ""),
                            notes=row.get("notes", ""),
                            total=total,
                        )
                        results["added"] += 1
                    except Exception as e:
                        # Best-effort per row: report the failure and continue.
                        skip(f"{row.get('filename', '?')}: {e}")

        except Exception as e:
            return f"Failed to process ZIP: {e}"

        error_summary = ""
        if results["errors"]:
            shown = results["errors"][:20]
            error_summary = "\n\n**Errors:**\n" + "\n".join(f"- {e}" for e in shown)
            if len(results["errors"]) > 20:
                error_summary += f"\n- ... and {len(results['errors']) - 20} more"

        return (
            f"### Upload Complete\n"
            f"- **Added:** {results['added']} submissions\n"
            f"- **Skipped:** {results['skipped']}\n"
            f"{error_summary}"
        )

    # Create interface
    with gr.Blocks(title="Soil Texture Classifier") as demo:
        gr.Markdown("""
        # Soil Texture Classification

        1. Use **Inference** to predict texture class and composition from image.
        2. Use **Contribute Data** to upload image + measured Sand/Silt/Clay for future training.
        3. Use **Dataset Management** to bulk-upload a ZIP dataset for model improvement.
        """)

        with gr.Tabs():
            with gr.Tab("Inference"):
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(label="Upload Soil Image", type="pil")
                        predict_btn = gr.Button("Analyze", variant="primary")

                        gr.Markdown("""
                        **Tips:**
                        - Use close-up images of soil surface
                        - Ensure good lighting
                        - Avoid shadows and reflections
                        """)

                    with gr.Column():
                        output_text = gr.Markdown(label="Results")
                        output_triangle = gr.Image(label="USDA Texture Triangle")

            with gr.Tab("Contribute Data"):
                gr.Markdown("""
                Upload a soil image with measured Sand/Silt/Clay percentages.
                This data will be stored for manual quality checks and future retraining.
                You can optionally submit weak/strong labels for better curation quality.
                """)
                with gr.Row():
                    with gr.Column():
                        contribution_image = gr.Image(label="Soil Image for Contribution", type="pil")
                        weak_label = gr.Dropdown(
                            choices=[""] + SoilTexturePredictor.CLASSES,
                            value="",
                            allow_custom_value=True,
                            label="Weak Label (Optional)"
                        )
                        strong_label = gr.Dropdown(
                            choices=[""] + SoilTexturePredictor.CLASSES,
                            value="",
                            allow_custom_value=True,
                            label="Strong Label (Optional)"
                        )
                        sample_source = gr.Textbox(
                            label="Sample Source",
                            placeholder="e.g., field site, experiment ID, sample batch"
                        )
                        location = gr.Textbox(
                            label="Location (Optional)",
                            placeholder="e.g., Iowa, USA"
                        )
                        notes = gr.Textbox(
                            label="Notes (Optional)",
                            lines=4,
                            placeholder="Any observation, sampling method, moisture condition, etc."
                        )
                    with gr.Column():
                        sand_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Sand (%)")
                        silt_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Silt (%)")
                        clay_input = gr.Slider(0, 100, value=33.4, step=0.1, label="Clay (%)")
                        consent = gr.Checkbox(
                            label="I confirm this image and labels can be used for model improvement.",
                            value=False
                        )
                        submit_btn = gr.Button("Submit Contribution", variant="primary")
                        contribution_status = gr.Markdown(label="Submission Status")

            with gr.Tab("Dataset Management"):
                gr.Markdown("""
                **Upload** a dataset (ZIP) to contribute bulk data for model improvement.

                **Upload format:** ZIP containing a CSV file and image files.
                CSV columns: `filename`, `sand`, `silt`, `clay` (required).
                Optional: `weak_label`, `strong_label`, `source`, `location`, `notes`.
                """)
                with gr.Row():
                    with gr.Column():
                        upload_file = gr.File(label="ZIP Dataset", file_types=[".zip"])
                        upload_consent = gr.Checkbox(
                            label="I confirm these images and labels can be used for model improvement.",
                            value=False,
                        )
                        upload_btn = gr.Button("Upload Dataset", variant="primary")
                        upload_status = gr.Markdown(label="Upload Status")
                    with gr.Column():
                        stats_btn = gr.Button("Refresh Statistics")
                        stats_display = gr.Markdown(label="Statistics")

        # Event handlers
        predict_btn.click(
            fn=predict_fn,
            inputs=input_image,
            outputs=[output_text, output_triangle]
        )

        input_image.change(
            fn=predict_fn,
            inputs=input_image,
            outputs=[output_text, output_triangle]
        )

        submit_btn.click(
            fn=submit_contribution_fn,
            inputs=[
                contribution_image,
                sand_input,
                silt_input,
                clay_input,
                weak_label,
                strong_label,
                sample_source,
                location,
                notes,
                consent,
            ],
            outputs=[contribution_status]
        )

        upload_btn.click(
            fn=upload_dataset_fn,
            inputs=[upload_file, upload_consent],
            outputs=[upload_status],
        )

        stats_btn.click(
            fn=get_dataset_stats_fn,
            inputs=[],
            outputs=[stats_display],
        )

    return demo


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Soil texture inference and contribution app")
    parser.add_argument("--checkpoint", type=str, default="finetuned_best.pth", help="Path to model checkpoint")
    parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Gradio server host")
    parser.add_argument("--server_port", type=int, default=7860, help="Gradio server port")
    parser.add_argument("--share", action="store_true", help="Create a public share link")
    parser.add_argument("--attention_type", type=str, default="none", choices=["none", "se", "cbam"],
                        help="Attention block used by inference model")
    parser.add_argument("--attention_reduction", type=int, default=16, help="Attention reduction ratio")
    parser.add_argument("--task_attention", action="store_true", help="Enable task-specific attention blocks")
    parser.add_argument("--allow_random_weights", action="store_true",
                        help="Allow launching without checkpoint (debug only)")
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    if not Path(checkpoint_path).exists():
        if not args.allow_random_weights:
            raise FileNotFoundError(
                f"Checkpoint not found at {checkpoint_path}. "
                "Pass --allow_random_weights only for debugging."
            )
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        print("Running with random weights for debug purposes.")
        checkpoint_path = None

    # Create and launch demo
    demo = create_demo(
        checkpoint_path=checkpoint_path,
        attention_type=args.attention_type,
        attention_reduction=args.attention_reduction,
        task_attention=args.task_attention,
    )
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share
    )