Iridium-193 commited on
Commit
49dd243
·
verified ·
1 Parent(s): 58969d7

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. README.md +5 -7
  2. app.py +1014 -1013
  3. collection_common.py +64 -0
  4. data_collection.py +728 -0
  5. finetuned_best.pth +2 -2
  6. requirements.txt +18 -18
README.md CHANGED
@@ -1,12 +1,10 @@
1
  ---
2
- title: SoilTextureClassification
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.2.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Soil Texture Classification
3
+ emoji: 🌍
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: "4.44.0"
8
  app_file: app.py
9
  pinned: false
10
  ---
 
 
app.py CHANGED
@@ -1,1013 +1,1014 @@
1
- import argparse
2
- import csv
3
- import io
4
- import os
5
- import zipfile
6
- from pathlib import Path
7
- from typing import Tuple, Dict
8
- import numpy as np
9
- from PIL import Image
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- from torchvision import transforms
14
- import timm
15
- import gradio as gr
16
- import matplotlib.pyplot as plt
17
- from matplotlib.patches import Polygon
18
-
19
- try:
20
- from src.data_collection import DataCollectionManager, classify_from_percentages_simple
21
- except ImportError:
22
- import sys
23
-
24
- sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
25
- from data_collection import DataCollectionManager, classify_from_percentages_simple
26
-
27
-
28
- # ============================================================================
29
- # MODEL ARCHITECTURE (Embedded)
30
- # ============================================================================
31
-
32
- class IdentityAttention(nn.Module):
33
- """No-op attention block."""
34
-
35
- def forward(self, x: torch.Tensor) -> torch.Tensor:
36
- return x
37
-
38
-
39
- class SEFeatureAttention(nn.Module):
40
- """Squeeze-and-Excitation style attention for vector features."""
41
-
42
- def __init__(self, feature_dim: int, reduction: int = 16):
43
- super().__init__()
44
- hidden_dim = max(8, feature_dim // reduction)
45
- self.fc = nn.Sequential(
46
- nn.Linear(feature_dim, hidden_dim),
47
- nn.ReLU(inplace=True),
48
- nn.Linear(hidden_dim, feature_dim),
49
- nn.Sigmoid(),
50
- )
51
-
52
- def forward(self, x: torch.Tensor) -> torch.Tensor:
53
- return x * self.fc(x)
54
-
55
-
56
- class CBAMFeatureAttention(nn.Module):
57
- """CBAM-inspired attention for vector features."""
58
-
59
- def __init__(self, feature_dim: int, reduction: int = 16):
60
- super().__init__()
61
- hidden_dim = max(8, feature_dim // reduction)
62
- self.mlp = nn.Sequential(
63
- nn.Linear(feature_dim, hidden_dim),
64
- nn.ReLU(inplace=True),
65
- nn.Linear(hidden_dim, feature_dim),
66
- )
67
- self.gate = nn.Sigmoid()
68
-
69
- def forward(self, x: torch.Tensor) -> torch.Tensor:
70
- avg_desc = self.mlp(x)
71
- max_pool = x.max(dim=1, keepdim=True).values.expand_as(x)
72
- max_desc = self.mlp(max_pool)
73
- return x * self.gate(avg_desc + max_desc)
74
-
75
-
76
- def build_attention_block(attention_type: str, feature_dim: int, reduction: int = 16) -> nn.Module:
77
- key = (attention_type or "none").lower()
78
- if key == "none":
79
- return IdentityAttention()
80
- if key == "se":
81
- return SEFeatureAttention(feature_dim=feature_dim, reduction=reduction)
82
- if key == "cbam":
83
- return CBAMFeatureAttention(feature_dim=feature_dim, reduction=reduction)
84
- raise ValueError(f"Unknown attention type: {attention_type}")
85
-
86
-
87
- class SoilTextureModel(nn.Module):
88
- """
89
- Multi-task model for soil texture analysis.
90
-
91
- Architecture:
92
- Image -> Backbone -> Shared Features -> Classification Head -> Texture Class
93
- -> Regression Head -> [Sand%, Silt%, Clay%]
94
- """
95
-
96
- BACKBONE_CONFIGS = {
97
- 'efficientnet_v2_s': {'feature_dim': 1280, 'pretrained': 'tf_efficientnetv2_s'},
98
- 'convnext_tiny': {'feature_dim': 768, 'pretrained': 'convnext_tiny'},
99
- 'mobilevit_s': {'feature_dim': 640, 'pretrained': 'mobilevit_s'},
100
- 'swin_tiny': {'feature_dim': 768, 'pretrained': 'swin_tiny_patch4_window7_224'},
101
- 'resnet50': {'feature_dim': 2048, 'pretrained': 'resnet50'},
102
- }
103
-
104
- def __init__(
105
- self,
106
- backbone_name: str = 'efficientnet_v2_s',
107
- num_classes: int = 12,
108
- dropout: float = 0.3,
109
- pretrained: bool = True,
110
- freeze_backbone: bool = False,
111
- attention_type: str = "none",
112
- attention_reduction: int = 16,
113
- task_attention: bool = False,
114
- ):
115
- super().__init__()
116
-
117
- self.backbone_name = backbone_name
118
- self.num_classes = num_classes
119
-
120
- # Get backbone configuration
121
- config = self.BACKBONE_CONFIGS.get(backbone_name, self.BACKBONE_CONFIGS['efficientnet_v2_s'])
122
- feature_dim = config['feature_dim']
123
-
124
- # Load pretrained backbone
125
- self.backbone = timm.create_model(
126
- config['pretrained'],
127
- pretrained=pretrained,
128
- num_classes=0, # Remove classifier head
129
- global_pool='avg'
130
- )
131
-
132
- # Freeze backbone if specified
133
- if freeze_backbone:
134
- for param in self.backbone.parameters():
135
- param.requires_grad = False
136
-
137
- self.shared_attention = build_attention_block(
138
- attention_type=attention_type,
139
- feature_dim=feature_dim,
140
- reduction=attention_reduction,
141
- )
142
- if task_attention:
143
- self.class_attention = build_attention_block(
144
- attention_type=attention_type,
145
- feature_dim=feature_dim,
146
- reduction=attention_reduction,
147
- )
148
- self.reg_attention = build_attention_block(
149
- attention_type=attention_type,
150
- feature_dim=feature_dim,
151
- reduction=attention_reduction,
152
- )
153
- else:
154
- self.class_attention = IdentityAttention()
155
- self.reg_attention = IdentityAttention()
156
-
157
- # Classification head (texture type)
158
- self.classifier = nn.Sequential(
159
- nn.Dropout(dropout),
160
- nn.Linear(feature_dim, 512),
161
- nn.BatchNorm1d(512),
162
- nn.ReLU(inplace=True),
163
- nn.Dropout(dropout * 0.5),
164
- nn.Linear(512, 256),
165
- nn.ReLU(inplace=True),
166
- nn.Linear(256, num_classes)
167
- )
168
-
169
- # Regression head (Sand, Silt, Clay percentages)
170
- self.regressor = nn.Sequential(
171
- nn.Dropout(dropout),
172
- nn.Linear(feature_dim, 512),
173
- nn.BatchNorm1d(512),
174
- nn.ReLU(inplace=True),
175
- nn.Dropout(dropout * 0.5),
176
- nn.Linear(512, 256),
177
- nn.ReLU(inplace=True),
178
- nn.Linear(256, 3) # Sand, Silt, Clay
179
- )
180
-
181
- # Initialize weights
182
- self._init_weights()
183
-
184
- def _init_weights(self):
185
- for m in [
186
- self.shared_attention,
187
- self.class_attention,
188
- self.reg_attention,
189
- self.classifier,
190
- self.regressor,
191
- ]:
192
- for layer in m.modules():
193
- if isinstance(layer, nn.Linear):
194
- nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
195
- if layer.bias is not None:
196
- nn.init.constant_(layer.bias, 0)
197
- elif isinstance(layer, nn.BatchNorm1d):
198
- nn.init.constant_(layer.weight, 1)
199
- nn.init.constant_(layer.bias, 0)
200
-
201
- def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
202
- """Forward pass."""
203
- # Extract features
204
- features = self.backbone(x)
205
- features = self.shared_attention(features)
206
- cls_features = self.class_attention(features)
207
- reg_features = self.reg_attention(features)
208
-
209
- # Classification
210
- class_logits = self.classifier(cls_features)
211
-
212
- # Regression (with softmax to ensure sum = 100)
213
- reg_output = self.regressor(reg_features)
214
- concentrations = F.softmax(reg_output, dim=1) * 100 # Scale to percentages
215
-
216
- result = {
217
- 'class_logits': class_logits,
218
- 'concentrations': concentrations
219
- }
220
-
221
- if return_features:
222
- result['features'] = features
223
-
224
- return result
225
-
226
-
227
- def create_model(
228
- model_type: str = 'full',
229
- backbone: str = 'efficientnet_v2_s',
230
- num_classes: int = 12,
231
- pretrained: bool = True,
232
- attention_type: str = "none",
233
- attention_reduction: int = 16,
234
- task_attention: bool = False,
235
- ) -> nn.Module:
236
- """Factory function to create model."""
237
- model = SoilTextureModel(
238
- backbone_name=backbone,
239
- num_classes=num_classes,
240
- pretrained=pretrained,
241
- attention_type=attention_type,
242
- attention_reduction=attention_reduction,
243
- task_attention=task_attention,
244
- )
245
- return model
246
-
247
-
248
- def format_prediction_markdown(result: Dict) -> str:
249
- """Create markdown output for inference results."""
250
- sorted_probs = sorted(result["class_probabilities"].items(), key=lambda x: x[1], reverse=True)
251
- lines = [
252
- "### Prediction Result",
253
- f"- **Texture Class:** `{result['class']}`",
254
- f"- **Confidence:** `{result['confidence'] * 100:.2f}%`",
255
- f"- **Sand / Silt / Clay:** `{result['sand']:.2f}% / {result['silt']:.2f}% / {result['clay']:.2f}%`",
256
- "",
257
- "**Top Probabilities**",
258
- ]
259
- for class_name, prob in sorted_probs[:5]:
260
- lines.append(f"- {class_name}: {prob * 100:.2f}%")
261
-
262
- return "\n".join(lines)
263
-
264
-
265
- # ============================================================================
266
- # SOIL TEXTURE TRIANGLE VISUALIZATION
267
- # ============================================================================
268
-
269
- def create_texture_triangle(sand: float, silt: float, clay: float, predicted_class: str,
270
- confidence: float = None, top_probs: list = None) -> np.ndarray:
271
- """
272
- Create USDA Soil Texture Triangle visualization with correct boundaries.
273
- """
274
- fig, ax = plt.subplots(1, 1, figsize=(14, 12), facecolor='white', dpi=150)
275
-
276
- # Helper function to convert soil percentages to triangle coordinates
277
- def soil_to_coords(sand_pct, silt_pct, clay_pct):
278
- x = silt_pct/100 + clay_pct/200
279
- y = clay_pct/100 * np.sqrt(3)/2
280
- return x, y
281
-
282
- # USDA Soil Texture Triangle regions with correct boundaries
283
- regions = [
284
- ('Sand', [(100, 0, 0), (85, 15, 0), (90, 0, 10)], '#FFE4B5'),
285
- ('Loamy Sand', [(85, 15, 0), (70, 30, 0), (85, 0, 15), (90, 0, 10)], '#FFDAB9'),
286
- ('Sandy Loam', [(70, 30, 0), (50, 50, 0), (42.5, 50, 7.5), (52.5, 40, 7.5), (52.5, 27.5, 20), (80, 0, 20), (85, 0, 15)], '#F4A460'),
287
- ('Loam', [(42.5, 50, 7.5), (22.5, 50, 27.5), (45, 27.5, 27.5), (52.5, 27.5, 20), (52.5, 40, 7.5)], '#DEB887'),
288
- ('Silt Loam', [(50, 50, 0), (20, 80, 0), (7.5, 80, 12.5), (0, 87.5, 12.5), (0, 72.5, 27.5), (22.5, 50, 27.5)], '#D2B48C'),
289
- ('Silt', [(20, 80, 0), (0, 100, 0), (0, 87.5, 12.5), (7.5, 80, 12.5)], '#C0C0C0'),
290
- ('Sandy Clay Loam', [(80, 0, 20), (52.5, 27.5, 20), (45, 27.5, 27.5), (45, 20, 35), (65, 0, 35)], '#CD853F'),
291
- ('Clay Loam', [(45, 27.5, 27.5), (20, 52.5, 27.5), (20, 40, 40), (45, 15, 40)], '#D2691E'),
292
- ('Silty Clay Loam', [(0, 72.5, 27.5), (0, 60, 40), (20, 40, 40), (20, 52.5, 27.5)], '#B8860B'),
293
- ('Sandy Clay', [(65, 0, 35), (45, 20, 35), (45, 0, 55)], '#A0522D'),
294
- ('Silty Clay', [(20, 40, 40), (0, 60, 40), (0, 40, 60)], '#8B4513'),
295
- ('Clay', [(45, 15, 40), (20, 40, 40), (0, 40, 60), (0, 0, 100), (45, 0, 55)], '#654321'),
296
- ]
297
-
298
- # Draw colored regions with border lines
299
- for name, vertices_pct, color in regions:
300
- vertices_xy = [soil_to_coords(s, si, c) for s, si, c in vertices_pct]
301
- region_patch = Polygon(vertices_xy, facecolor=color, edgecolor='#333',
302
- linewidth=1.2, alpha=0.8, zorder=1)
303
- ax.add_patch(region_patch)
304
- # Add label
305
- center_x = np.mean([v[0] for v in vertices_xy])
306
- center_y = np.mean([v[1] for v in vertices_xy])
307
- ax.text(center_x, center_y, name, fontsize=12, ha='center',
308
- va='center', weight='bold', zorder=2)
309
-
310
- # Draw triangle outline
311
- triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]])
312
- tri_patch = Polygon(triangle, fill=False, edgecolor='black', linewidth=4, zorder=3)
313
- ax.add_patch(tri_patch)
314
-
315
- # Add corner labels
316
- ax.text(0, -0.05, '100% Sand', fontsize=16, ha='center', weight='bold')
317
- ax.text(1, -0.05, '100% Silt', fontsize=16, ha='center', weight='bold')
318
- ax.text(0.5, np.sqrt(3)/2 + 0.03, '100% Clay', fontsize=16, ha='center', weight='bold')
319
-
320
- # Add grid lines
321
- for pct in range(5, 100, 5):
322
- y = pct/100 * np.sqrt(3)/2
323
- x_left = pct/200
324
- x_right = 1 - pct/200
325
- sand_pct = pct
326
- p1 = soil_to_coords(sand_pct, 0, 100-sand_pct)
327
- p2 = soil_to_coords(sand_pct, 100-sand_pct, 0)
328
- silt_pct = pct
329
- p3 = soil_to_coords(0, silt_pct, 100-silt_pct)
330
- p4 = soil_to_coords(100-silt_pct, silt_pct, 0)
331
-
332
- if pct % 10 == 0:
333
- ax.plot([x_left, x_right], [y, y], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
334
- ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
335
- ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
336
- ax.text(x_left - 0.03, y, f'{pct}', fontsize=11, alpha=0.7, weight='bold')
337
- else:
338
- ax.plot([x_left, x_right], [y, y], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
339
- ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
340
- ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
341
-
342
- # Plot prediction point
343
- pred_x, pred_y = soil_to_coords(sand, silt, clay)
344
- ax.plot(pred_x, pred_y, 'o', markersize=22, markerfacecolor='red',
345
- markeredgecolor='darkred', markeredgewidth=3.5, zorder=5)
346
-
347
- # Add annotation
348
- offset_x = 0.15 if pred_x < 0.7 else -0.15
349
- offset_y = 0.08
350
- ax.annotate(f'{predicted_class}\n({sand:.0f}%, {silt:.0f}%, {clay:.0f}%)',
351
- xy=(pred_x, pred_y), xytext=(pred_x + offset_x, pred_y + offset_y),
352
- fontsize=14, fontweight='bold',
353
- arrowprops=dict(arrowstyle='->', lw=2.5, color='darkred'),
354
- bbox=dict(boxstyle='round,pad=0.6', facecolor='white', edgecolor='darkred', lw=2.5),
355
- ha='center', zorder=6)
356
-
357
- # Add prediction information boxes
358
- if confidence is not None and top_probs is not None:
359
- # Left box - Prediction and Composition
360
- left_text = f"Predicted Class:\n{predicted_class}\n\n"
361
- left_text += f"Confidence: {confidence*100:.1f}%\n\n"
362
- left_text += f"Composition:\n"
363
- left_text += f"Sand: {sand:.1f}%\n"
364
- left_text += f"Silt: {silt:.1f}%\n"
365
- left_text += f"Clay: {clay:.1f}%"
366
-
367
- ax.text(0.05, 0.82, left_text,
368
- fontsize=16, verticalalignment='top',
369
- bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
370
- edgecolor='black', linewidth=2.5, alpha=0.95),
371
- zorder=7, family='monospace', weight='bold')
372
-
373
- # Right box - Top 5 Probabilities
374
- right_text = "Top 5 Probabilities:\n\n"
375
- for i, (cls, prob) in enumerate(top_probs[:5], 1):
376
- right_text += f"{i}. {cls}: {prob*100:.1f}%\n"
377
-
378
- ax.text(0.75, 0.82, right_text,
379
- fontsize=16, verticalalignment='top',
380
- bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
381
- edgecolor='black', linewidth=2.5, alpha=0.95),
382
- zorder=7, family='monospace', weight='bold')
383
-
384
- ax.set_xlim(-0.08, 1.08)
385
- ax.set_ylim(-0.08, np.sqrt(3)/2 + 0.06)
386
- ax.set_aspect('equal')
387
- ax.axis('off')
388
- ax.set_title('USDA Soil Texture Triangle', fontsize=20, fontweight='bold', pad=8)
389
-
390
- fig.tight_layout()
391
- fig.canvas.draw()
392
- img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
393
- img = img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
394
- img = img[:, :, :3]
395
- plt.close(fig)
396
-
397
- return img
398
-
399
-
400
- # ============================================================================
401
- # PREDICTOR CLASS
402
- # ============================================================================
403
-
404
- def classify_from_percentages(sand: float, silt: float, clay: float) -> str:
405
- """
406
- Determine USDA texture class from Sand/Silt/Clay percentages.
407
- Uses official USDA classification boundaries.
408
- """
409
- # Normalize to ensure sum = 100
410
- total = sand + silt + clay
411
- if total > 0:
412
- sand = sand / total * 100
413
- silt = silt / total * 100
414
- clay = clay / total * 100
415
-
416
- # USDA classification rules (order matters for overlapping boundaries)
417
- if clay >= 40:
418
- if silt >= 40:
419
- return 'Silty Clay'
420
- elif sand >= 45:
421
- return 'Sandy Clay'
422
- else:
423
- return 'Clay'
424
- elif clay >= 35:
425
- if sand >= 45:
426
- return 'Sandy Clay'
427
- elif silt < 20:
428
- return 'Sandy Clay'
429
- else:
430
- return 'Clay Loam'
431
- elif clay >= 27:
432
- if sand >= 20 and sand < 45:
433
- return 'Clay Loam'
434
- elif silt >= 28 and silt < 40:
435
- return 'Clay Loam'
436
- elif silt >= 40:
437
- return 'Silty Clay Loam'
438
- else:
439
- return 'Sandy Clay Loam'
440
- elif clay >= 20:
441
- if sand >= 45:
442
- return 'Sandy Clay Loam'
443
- elif silt >= 28 and sand < 45:
444
- return 'Clay Loam'
445
- elif silt >= 50:
446
- return 'Silty Clay Loam'
447
- else:
448
- return 'Sandy Clay Loam'
449
- elif clay >= 12:
450
- if silt >= 50 and clay >= 12 and clay < 27:
451
- return 'Silt Loam'
452
- elif silt >= 50 and silt < 80:
453
- return 'Silt Loam'
454
- elif silt >= 80 and clay < 12:
455
- return 'Silt'
456
- elif sand >= 52:
457
- return 'Sandy Loam'
458
- else:
459
- return 'Loam'
460
- elif clay >= 7:
461
- if silt >= 50:
462
- return 'Silt Loam'
463
- elif silt >= 28 and silt < 50 and sand < 52:
464
- return 'Loam'
465
- else:
466
- return 'Sandy Loam'
467
- else:
468
- # clay < 7
469
- if silt >= 80:
470
- return 'Silt'
471
- elif silt >= 50:
472
- return 'Silt Loam'
473
- elif sand >= 85 and silt + 1.5 * clay < 15:
474
- return 'Sand'
475
- elif sand >= 70 and sand < 85:
476
- return 'Loamy Sand'
477
- elif sand >= 43 and sand < 52:
478
- return 'Sandy Loam' if silt < 50 else 'Silt Loam'
479
- elif sand >= 52:
480
- return 'Sandy Loam'
481
- else:
482
- return 'Loam'
483
-
484
-
485
- class SoilTexturePredictor:
486
- """
487
- Inference wrapper for soil texture prediction.
488
- """
489
-
490
- CLASSES = [
491
- 'Sand', 'Loamy Sand', 'Sandy Loam', 'Loam', 'Silt Loam', 'Silt',
492
- 'Sandy Clay Loam', 'Clay Loam', 'Silty Clay Loam', 'Sandy Clay', 'Silty Clay', 'Clay'
493
- ]
494
-
495
- def __init__(
496
- self,
497
- checkpoint_path: str = None,
498
- device: str = None,
499
- attention_type: str = "none",
500
- attention_reduction: int = 16,
501
- task_attention: bool = False,
502
- ):
503
- self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
504
-
505
- # Create model
506
- self.model = create_model(
507
- model_type='full',
508
- backbone='efficientnet_v2_s',
509
- num_classes=len(self.CLASSES),
510
- pretrained=False,
511
- attention_type=attention_type,
512
- attention_reduction=attention_reduction,
513
- task_attention=task_attention,
514
- )
515
-
516
- # Load checkpoint if provided
517
- if checkpoint_path and Path(checkpoint_path).exists():
518
- print(f"Loading checkpoint: {checkpoint_path}")
519
- checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
520
- if 'model_state_dict' in checkpoint:
521
- self.model.load_state_dict(checkpoint['model_state_dict'])
522
- else:
523
- self.model.load_state_dict(checkpoint)
524
- else:
525
- print("No checkpoint provided, using random weights (for demo)")
526
-
527
- self.model.to(self.device)
528
- self.model.eval()
529
-
530
- # Transform
531
- self.transform = transforms.Compose([
532
- transforms.Resize((500, 500)),
533
- transforms.ToTensor(),
534
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
535
- ])
536
-
537
- @torch.no_grad()
538
- def predict(self, image: Image.Image) -> Dict:
539
- """
540
- Predict soil texture class and concentrations.
541
- """
542
- # Preprocess
543
- img_tensor = self.transform(image).unsqueeze(0).to(self.device)
544
-
545
- # Forward pass
546
- output = self.model(img_tensor)
547
-
548
- # Class prediction from classification head (for reference)
549
- class_probs = F.softmax(output['class_logits'], dim=1).cpu().numpy()[0]
550
-
551
- # Concentration prediction
552
- concentrations = output['concentrations'].cpu().numpy()[0]
553
- sand, silt, clay = concentrations
554
-
555
- # Ensure they sum to 100
556
- total = sand + silt + clay
557
- sand = sand / total * 100
558
- silt = silt / total * 100
559
- clay = clay / total * 100
560
-
561
- # Derive class from percentages to ensure consistency
562
- class_name = classify_from_percentages(sand, silt, clay)
563
- confidence = class_probs[self.CLASSES.index(class_name)]
564
-
565
- return {
566
- 'class': class_name,
567
- 'confidence': confidence,
568
- 'class_probabilities': {self.CLASSES[i]: float(p) for i, p in enumerate(class_probs)},
569
- 'sand': sand,
570
- 'silt': silt,
571
- 'clay': clay
572
- }
573
-
574
- def predict_with_visualization(self, image: Image.Image) -> Tuple[str, np.ndarray, Dict]:
575
- """Predict and create visualization."""
576
- result = self.predict(image)
577
-
578
- # Sort by probability and show top 5
579
- sorted_probs = sorted(result['class_probabilities'].items(), key=lambda x: x[1], reverse=True)[:5]
580
-
581
- # Create texture triangle
582
- triangle_img = create_texture_triangle(
583
- result['sand'], result['silt'], result['clay'], result['class'],
584
- confidence=result['confidence'],
585
- top_probs=sorted_probs
586
- )
587
-
588
- text_output = format_prediction_markdown(result)
589
- return text_output, triangle_img, result
590
-
591
-
592
- # ============================================================================
593
- # GRADIO INTERFACE
594
- # ============================================================================
595
-
596
- def create_demo(
597
- checkpoint_path: str = None,
598
- attention_type: str = "none",
599
- attention_reduction: int = 16,
600
- task_attention: bool = False,
601
- ):
602
- """Create Gradio demo interface."""
603
-
604
- # Initialize predictor
605
- predictor = SoilTexturePredictor(
606
- checkpoint_path=checkpoint_path,
607
- attention_type=attention_type,
608
- attention_reduction=attention_reduction,
609
- task_attention=task_attention,
610
- )
611
- collection_manager = DataCollectionManager()
612
- collection_manager.ensure_storage()
613
- collection_manager.start_scheduler()
614
-
615
- def to_pil_image(image):
616
- """Convert possible Gradio image input to PIL."""
617
- if isinstance(image, Image.Image):
618
- return image.convert("RGB")
619
- if isinstance(image, np.ndarray):
620
- return Image.fromarray(image).convert("RGB")
621
- raise ValueError("Unsupported image format.")
622
-
623
- def predict_fn(image):
624
- """Gradio prediction function."""
625
- if image is None:
626
- return "Please upload an image.", None
627
-
628
- image = to_pil_image(image)
629
-
630
- # Get prediction
631
- text_output, triangle_img, _ = predictor.predict_with_visualization(image)
632
-
633
- return text_output, triangle_img
634
-
635
- def submit_contribution_fn(
636
- image,
637
- sand,
638
- silt,
639
- clay,
640
- weak_label,
641
- strong_label,
642
- sample_source,
643
- location,
644
- notes,
645
- consent
646
- ):
647
- """Persist user-contributed image + composition for future training."""
648
- if image is None:
649
- return "Submission failed: please upload a soil image."
650
-
651
- image = to_pil_image(image)
652
- validation = collection_manager.validate_submission(
653
- sand=sand,
654
- silt=silt,
655
- clay=clay,
656
- consent=consent,
657
- image=image,
658
- )
659
- if not validation.ok:
660
- return f"Submission failed: {validation.message}"
661
-
662
- prediction = predictor.predict(image)
663
- user_class = classify_from_percentages_simple(sand, silt, clay)
664
- submission_id = collection_manager.create_submission_id()
665
- save_result = collection_manager.save_submission(
666
- image=image,
667
- submission_id=submission_id,
668
- sand=sand,
669
- silt=silt,
670
- clay=clay,
671
- user_class=user_class,
672
- weak_label=weak_label,
673
- strong_label=strong_label,
674
- prediction=prediction,
675
- sample_source=sample_source,
676
- location=location,
677
- notes=notes,
678
- total=validation.total,
679
- )
680
- image_path = save_result.get("image_path", "")
681
- is_duplicate = save_result.get("is_duplicate", "0") == "1"
682
- duplicate_of_submission = save_result.get("duplicate_of_submission", "")
683
- export_bundles = collection_manager.maybe_trigger_exports()
684
- export_note = ""
685
- if export_bundles:
686
- export_note = "\n- Auto-export triggered:\n" + "\n".join([f" - `{bundle}`" for bundle in export_bundles])
687
- dedup_note = ""
688
- if is_duplicate:
689
- dedup_note = f"\n- Duplicate image detected. Reused existing sample from `{duplicate_of_submission}`."
690
-
691
- return (
692
- "### Submission Saved\n"
693
- f"- Submission ID: `{submission_id}`\n"
694
- f"- Stored image: `{image_path}`\n"
695
- f"- User label class: `{user_class}`\n"
696
- f"- Model prediction: `{prediction['class']}` ({prediction['confidence'] * 100:.2f}%)\n"
697
- f"- Weak label: `{weak_label or ''}`\n"
698
- f"- Strong label: `{strong_label or ''}`\n"
699
- "- Data was appended to `data/community_submissions/submissions.csv`.\n"
700
- "- Daily export uses background scheduler; high disk usage triggers immediate export."
701
- f"{dedup_note}"
702
- f"{export_note}"
703
- )
704
-
705
- def get_dataset_stats_fn():
706
- """Get statistics about the current dataset."""
707
- cfg = collection_manager.config
708
- num_submissions = 0
709
- if cfg.csv_path.exists():
710
- with cfg.csv_path.open("r", encoding="utf-8") as f:
711
- reader = csv.reader(f)
712
- next(reader, None)
713
- num_submissions = sum(1 for _ in reader)
714
- num_images = 0
715
- total_size_bytes = 0
716
- if cfg.images_dir.exists():
717
- for p in cfg.images_dir.iterdir():
718
- if p.is_file():
719
- num_images += 1
720
- total_size_bytes += p.stat().st_size
721
- total_size_mb = total_size_bytes / (1024 * 1024)
722
- return (
723
- f"### Dataset Statistics\n"
724
- f"- **Total submissions:** {num_submissions}\n"
725
- f"- **Total images:** {num_images}\n"
726
- f"- **Total image size:** {total_size_mb:.1f} MB\n"
727
- )
728
-
729
- def upload_dataset_fn(zip_file, upload_consent):
730
- """Process uploaded ZIP dataset with images and CSV."""
731
- if zip_file is None:
732
- return "Please upload a ZIP file."
733
- if not upload_consent:
734
- return "Please confirm consent before uploading."
735
- zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
736
- if not zipfile.is_zipfile(zip_path):
737
- return "Invalid ZIP file."
738
- max_entries = 10000
739
- max_total_size = 500 * 1024 * 1024
740
- results = {"added": 0, "skipped": 0, "errors": []}
741
- try:
742
- with zipfile.ZipFile(zip_path, "r") as zf:
743
- entries = zf.infolist()
744
- if len(entries) > max_entries:
745
- return f"ZIP has too many entries ({len(entries)}). Max: {max_entries}."
746
- total_size = sum(e.file_size for e in entries)
747
- if total_size > max_total_size:
748
- return f"ZIP too large ({total_size / 1024 / 1024:.0f} MB). Max: {max_total_size // (1024 * 1024)} MB."
749
- csv_entries = [
750
- e for e in entries
751
- if e.filename.endswith(".csv") and not e.filename.startswith("__")
752
- ]
753
- if not csv_entries:
754
- return "No CSV found in ZIP. Expected CSV with columns: filename, sand, silt, clay."
755
- with zf.open(csv_entries[0]) as csv_file:
756
- content = csv_file.read().decode("utf-8")
757
- reader = csv.DictReader(io.StringIO(content))
758
- headers = set(reader.fieldnames or [])
759
- required = {"filename", "sand", "silt", "clay"}
760
- if not required.issubset(headers):
761
- return (
762
- f"CSV must have columns: {', '.join(sorted(required))}. "
763
- f"Found: {', '.join(sorted(headers))}"
764
- )
765
- for row in reader:
766
- try:
767
- fname = row["filename"].strip()
768
- sand = float(row["sand"])
769
- silt = float(row["silt"])
770
- clay = float(row["clay"])
771
- vals = [sand, silt, clay]
772
- if any(v < 0 or v > 100 for v in vals):
773
- results["errors"].append(f"{fname}: values out of range")
774
- results["skipped"] += 1
775
- continue
776
- total = sand + silt + clay
777
- if abs(total - 100.0) > 1.0:
778
- results["errors"].append(f"{fname}: sum={total:.1f}, must be ~100")
779
- results["skipped"] += 1
780
- continue
781
- matches = [e for e in entries if Path(e.filename).name == fname]
782
- if not matches:
783
- results["errors"].append(f"Image not found in ZIP: {fname}")
784
- results["skipped"] += 1
785
- continue
786
- with zf.open(matches[0]) as img_bytes:
787
- image = Image.open(img_bytes).convert("RGB")
788
- if image.width * image.height > collection_manager.config.max_image_pixels:
789
- results["errors"].append(f"{fname}: image too large")
790
- results["skipped"] += 1
791
- continue
792
- prediction = predictor.predict(image)
793
- user_class = classify_from_percentages_simple(sand, silt, clay)
794
- submission_id = collection_manager.create_submission_id()
795
- collection_manager.save_submission(
796
- image=image,
797
- submission_id=submission_id,
798
- sand=sand, silt=silt, clay=clay,
799
- user_class=user_class,
800
- weak_label=row.get("weak_label", ""),
801
- strong_label=row.get("strong_label", ""),
802
- prediction=prediction,
803
- sample_source=row.get("source", ""),
804
- location=row.get("location", ""),
805
- notes=row.get("notes", ""),
806
- total=total,
807
- )
808
- results["added"] += 1
809
- except Exception as e:
810
- results["errors"].append(f"{row.get('filename', '?')}: {e}")
811
- results["skipped"] += 1
812
- except Exception as e:
813
- return f"Failed to process ZIP: {e}"
814
- error_summary = ""
815
- if results["errors"]:
816
- shown = results["errors"][:20]
817
- error_summary = "\n\n**Errors:**\n" + "\n".join(f"- {e}" for e in shown)
818
- if len(results["errors"]) > 20:
819
- error_summary += f"\n- ... and {len(results['errors']) - 20} more"
820
- return (
821
- f"### Upload Complete\n"
822
- f"- **Added:** {results['added']} submissions\n"
823
- f"- **Skipped:** {results['skipped']}\n"
824
- f"{error_summary}"
825
- )
826
-
827
- # Create interface
828
- with gr.Blocks(title="Soil Texture Classifier") as demo:
829
- gr.Markdown("""
830
- # Soil Texture Classification
831
-
832
- 1. Use **Inference** to predict texture class and composition from image.
833
- 2. Use **Contribute Data** to upload image + measured Sand/Silt/Clay for future training.
834
- 3. Use **Dataset Management** to bulk-upload a ZIP dataset for model improvement.
835
- """)
836
-
837
- with gr.Tabs():
838
- with gr.Tab("Inference"):
839
- with gr.Row():
840
- with gr.Column():
841
- input_image = gr.Image(label="Upload Soil Image", type="pil")
842
- predict_btn = gr.Button("Analyze", variant="primary")
843
-
844
- gr.Markdown("""
845
- **Tips:**
846
- - Use close-up images of soil surface
847
- - Ensure good lighting
848
- - Avoid shadows and reflections
849
- """)
850
-
851
- with gr.Column():
852
- output_text = gr.Markdown(label="Results")
853
- output_triangle = gr.Image(label="USDA Texture Triangle")
854
-
855
- with gr.Tab("Contribute Data"):
856
- gr.Markdown("""
857
- Upload a soil image with measured Sand/Silt/Clay percentages.
858
- This data will be stored for manual quality checks and future retraining.
859
- You can optionally submit weak/strong labels for better curation quality.
860
- """)
861
- with gr.Row():
862
- with gr.Column():
863
- contribution_image = gr.Image(label="Soil Image for Contribution", type="pil")
864
- weak_label = gr.Dropdown(
865
- choices=[""] + SoilTexturePredictor.CLASSES,
866
- value="",
867
- allow_custom_value=True,
868
- label="Weak Label (Optional)"
869
- )
870
- strong_label = gr.Dropdown(
871
- choices=[""] + SoilTexturePredictor.CLASSES,
872
- value="",
873
- allow_custom_value=True,
874
- label="Strong Label (Optional)"
875
- )
876
- sample_source = gr.Textbox(
877
- label="Sample Source",
878
- placeholder="e.g., field site, experiment ID, sample batch"
879
- )
880
- location = gr.Textbox(
881
- label="Location (Optional)",
882
- placeholder="e.g., Iowa, USA"
883
- )
884
- notes = gr.Textbox(
885
- label="Notes (Optional)",
886
- lines=4,
887
- placeholder="Any observation, sampling method, moisture condition, etc."
888
- )
889
- with gr.Column():
890
- sand_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Sand (%)")
891
- silt_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Silt (%)")
892
- clay_input = gr.Slider(0, 100, value=33.4, step=0.1, label="Clay (%)")
893
- consent = gr.Checkbox(
894
- label="I confirm this image and labels can be used for model improvement.",
895
- value=False
896
- )
897
- submit_btn = gr.Button("Submit Contribution", variant="primary")
898
- contribution_status = gr.Markdown(label="Submission Status")
899
-
900
- with gr.Tab("Dataset Management"):
901
- gr.Markdown("""
902
- **Upload** a dataset (ZIP) to contribute bulk data for model improvement.
903
-
904
- **Upload format:** ZIP containing a CSV file and image files.
905
- CSV columns: `filename`, `sand`, `silt`, `clay` (required).
906
- Optional: `weak_label`, `strong_label`, `source`, `location`, `notes`.
907
- """)
908
- with gr.Row():
909
- with gr.Column():
910
- upload_file = gr.File(label="ZIP Dataset", file_types=[".zip"])
911
- upload_consent = gr.Checkbox(
912
- label="I confirm these images and labels can be used for model improvement.",
913
- value=False,
914
- )
915
- upload_btn = gr.Button("Upload Dataset", variant="primary")
916
- upload_status = gr.Markdown(label="Upload Status")
917
- with gr.Column():
918
- stats_btn = gr.Button("Refresh Statistics")
919
- stats_display = gr.Markdown(label="Statistics")
920
-
921
- # Event handlers
922
- predict_btn.click(
923
- fn=predict_fn,
924
- inputs=input_image,
925
- outputs=[output_text, output_triangle]
926
- )
927
-
928
- input_image.change(
929
- fn=predict_fn,
930
- inputs=input_image,
931
- outputs=[output_text, output_triangle]
932
- )
933
-
934
- submit_btn.click(
935
- fn=submit_contribution_fn,
936
- inputs=[
937
- contribution_image,
938
- sand_input,
939
- silt_input,
940
- clay_input,
941
- weak_label,
942
- strong_label,
943
- sample_source,
944
- location,
945
- notes,
946
- consent,
947
- ],
948
- outputs=[contribution_status]
949
- )
950
-
951
- upload_btn.click(
952
- fn=upload_dataset_fn,
953
- inputs=[upload_file, upload_consent],
954
- outputs=[upload_status],
955
- )
956
-
957
- stats_btn.click(
958
- fn=get_dataset_stats_fn,
959
- inputs=[],
960
- outputs=[stats_display],
961
- )
962
-
963
- return demo
964
-
965
-
966
- # ============================================================================
967
- # MAIN
968
- # ============================================================================
969
-
970
- if __name__ == "__main__":
971
- parser = argparse.ArgumentParser(description="Soil texture inference and contribution app")
972
- parser.add_argument("--checkpoint", type=str, default="finetuned_best.pth",
973
- help="Path to model checkpoint")
974
- parser.add_argument("--server_name", type=str, default="0.0.0.0",
975
- help="Gradio server host")
976
- parser.add_argument("--server_port", type=int, default=7860,
977
- help="Gradio server port")
978
- parser.add_argument("--share", action="store_true",
979
- help="Create a public share link")
980
- parser.add_argument("--attention_type", type=str, default="none", choices=["none", "se", "cbam"],
981
- help="Attention block used by inference model")
982
- parser.add_argument("--attention_reduction", type=int, default=16,
983
- help="Attention reduction ratio")
984
- parser.add_argument("--task_attention", action="store_true",
985
- help="Enable task-specific attention blocks")
986
- parser.add_argument("--allow_random_weights", action="store_true",
987
- help="Allow launching without checkpoint (debug only)")
988
- args = parser.parse_args()
989
-
990
- checkpoint_path = args.checkpoint
991
-
992
- if not Path(checkpoint_path).exists():
993
- if not args.allow_random_weights:
994
- raise FileNotFoundError(
995
- f"Checkpoint not found at {checkpoint_path}. "
996
- "Pass --allow_random_weights only for debugging."
997
- )
998
- print(f"Warning: Checkpoint not found at {checkpoint_path}")
999
- print("Running with random weights for debug purposes.")
1000
- checkpoint_path = None
1001
-
1002
- # Create and launch demo
1003
- demo = create_demo(
1004
- checkpoint_path=checkpoint_path,
1005
- attention_type=args.attention_type,
1006
- attention_reduction=args.attention_reduction,
1007
- task_attention=args.task_attention,
1008
- )
1009
- demo.launch(
1010
- server_name=args.server_name,
1011
- server_port=args.server_port,
1012
- share=args.share
1013
- )
 
 
1
+ import argparse
2
+ import csv
3
+ import io
4
+ import os
5
+ import zipfile
6
+ from pathlib import Path
7
+ from typing import Tuple, Dict
8
+ import numpy as np
9
+ from PIL import Image
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torchvision import transforms
14
+ import timm
15
+ import gradio as gr
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib.patches import Polygon
18
+
19
+ try:
20
+ from src.data_collection import DataCollectionManager, classify_from_percentages_simple
21
+ except ImportError:
22
+ import sys
23
+
24
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
25
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
26
+ from data_collection import DataCollectionManager, classify_from_percentages_simple
27
+
28
+
29
+ # ============================================================================
30
+ # MODEL ARCHITECTURE (Embedded)
31
+ # ============================================================================
32
+
33
+ class IdentityAttention(nn.Module):
34
+ """No-op attention block."""
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ return x
38
+
39
+
40
+ class SEFeatureAttention(nn.Module):
41
+ """Squeeze-and-Excitation style attention for vector features."""
42
+
43
+ def __init__(self, feature_dim: int, reduction: int = 16):
44
+ super().__init__()
45
+ hidden_dim = max(8, feature_dim // reduction)
46
+ self.fc = nn.Sequential(
47
+ nn.Linear(feature_dim, hidden_dim),
48
+ nn.ReLU(inplace=True),
49
+ nn.Linear(hidden_dim, feature_dim),
50
+ nn.Sigmoid(),
51
+ )
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ return x * self.fc(x)
55
+
56
+
57
+ class CBAMFeatureAttention(nn.Module):
58
+ """CBAM-inspired attention for vector features."""
59
+
60
+ def __init__(self, feature_dim: int, reduction: int = 16):
61
+ super().__init__()
62
+ hidden_dim = max(8, feature_dim // reduction)
63
+ self.mlp = nn.Sequential(
64
+ nn.Linear(feature_dim, hidden_dim),
65
+ nn.ReLU(inplace=True),
66
+ nn.Linear(hidden_dim, feature_dim),
67
+ )
68
+ self.gate = nn.Sigmoid()
69
+
70
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
71
+ avg_desc = self.mlp(x)
72
+ max_pool = x.max(dim=1, keepdim=True).values.expand_as(x)
73
+ max_desc = self.mlp(max_pool)
74
+ return x * self.gate(avg_desc + max_desc)
75
+
76
+
77
+ def build_attention_block(attention_type: str, feature_dim: int, reduction: int = 16) -> nn.Module:
78
+ key = (attention_type or "none").lower()
79
+ if key == "none":
80
+ return IdentityAttention()
81
+ if key == "se":
82
+ return SEFeatureAttention(feature_dim=feature_dim, reduction=reduction)
83
+ if key == "cbam":
84
+ return CBAMFeatureAttention(feature_dim=feature_dim, reduction=reduction)
85
+ raise ValueError(f"Unknown attention type: {attention_type}")
86
+
87
+
88
+ class SoilTextureModel(nn.Module):
89
+ """
90
+ Multi-task model for soil texture analysis.
91
+
92
+ Architecture:
93
+ Image -> Backbone -> Shared Features -> Classification Head -> Texture Class
94
+ -> Regression Head -> [Sand%, Silt%, Clay%]
95
+ """
96
+
97
+ BACKBONE_CONFIGS = {
98
+ 'efficientnet_v2_s': {'feature_dim': 1280, 'pretrained': 'tf_efficientnetv2_s'},
99
+ 'convnext_tiny': {'feature_dim': 768, 'pretrained': 'convnext_tiny'},
100
+ 'mobilevit_s': {'feature_dim': 640, 'pretrained': 'mobilevit_s'},
101
+ 'swin_tiny': {'feature_dim': 768, 'pretrained': 'swin_tiny_patch4_window7_224'},
102
+ 'resnet50': {'feature_dim': 2048, 'pretrained': 'resnet50'},
103
+ }
104
+
105
+ def __init__(
106
+ self,
107
+ backbone_name: str = 'efficientnet_v2_s',
108
+ num_classes: int = 12,
109
+ dropout: float = 0.3,
110
+ pretrained: bool = True,
111
+ freeze_backbone: bool = False,
112
+ attention_type: str = "none",
113
+ attention_reduction: int = 16,
114
+ task_attention: bool = False,
115
+ ):
116
+ super().__init__()
117
+
118
+ self.backbone_name = backbone_name
119
+ self.num_classes = num_classes
120
+
121
+ # Get backbone configuration
122
+ config = self.BACKBONE_CONFIGS.get(backbone_name, self.BACKBONE_CONFIGS['efficientnet_v2_s'])
123
+ feature_dim = config['feature_dim']
124
+
125
+ # Load pretrained backbone
126
+ self.backbone = timm.create_model(
127
+ config['pretrained'],
128
+ pretrained=pretrained,
129
+ num_classes=0, # Remove classifier head
130
+ global_pool='avg'
131
+ )
132
+
133
+ # Freeze backbone if specified
134
+ if freeze_backbone:
135
+ for param in self.backbone.parameters():
136
+ param.requires_grad = False
137
+
138
+ self.shared_attention = build_attention_block(
139
+ attention_type=attention_type,
140
+ feature_dim=feature_dim,
141
+ reduction=attention_reduction,
142
+ )
143
+ if task_attention:
144
+ self.class_attention = build_attention_block(
145
+ attention_type=attention_type,
146
+ feature_dim=feature_dim,
147
+ reduction=attention_reduction,
148
+ )
149
+ self.reg_attention = build_attention_block(
150
+ attention_type=attention_type,
151
+ feature_dim=feature_dim,
152
+ reduction=attention_reduction,
153
+ )
154
+ else:
155
+ self.class_attention = IdentityAttention()
156
+ self.reg_attention = IdentityAttention()
157
+
158
+ # Classification head (texture type)
159
+ self.classifier = nn.Sequential(
160
+ nn.Dropout(dropout),
161
+ nn.Linear(feature_dim, 512),
162
+ nn.BatchNorm1d(512),
163
+ nn.ReLU(inplace=True),
164
+ nn.Dropout(dropout * 0.5),
165
+ nn.Linear(512, 256),
166
+ nn.ReLU(inplace=True),
167
+ nn.Linear(256, num_classes)
168
+ )
169
+
170
+ # Regression head (Sand, Silt, Clay percentages)
171
+ self.regressor = nn.Sequential(
172
+ nn.Dropout(dropout),
173
+ nn.Linear(feature_dim, 512),
174
+ nn.BatchNorm1d(512),
175
+ nn.ReLU(inplace=True),
176
+ nn.Dropout(dropout * 0.5),
177
+ nn.Linear(512, 256),
178
+ nn.ReLU(inplace=True),
179
+ nn.Linear(256, 3) # Sand, Silt, Clay
180
+ )
181
+
182
+ # Initialize weights
183
+ self._init_weights()
184
+
185
+ def _init_weights(self):
186
+ for m in [
187
+ self.shared_attention,
188
+ self.class_attention,
189
+ self.reg_attention,
190
+ self.classifier,
191
+ self.regressor,
192
+ ]:
193
+ for layer in m.modules():
194
+ if isinstance(layer, nn.Linear):
195
+ nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
196
+ if layer.bias is not None:
197
+ nn.init.constant_(layer.bias, 0)
198
+ elif isinstance(layer, nn.BatchNorm1d):
199
+ nn.init.constant_(layer.weight, 1)
200
+ nn.init.constant_(layer.bias, 0)
201
+
202
+ def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
203
+ """Forward pass."""
204
+ # Extract features
205
+ features = self.backbone(x)
206
+ features = self.shared_attention(features)
207
+ cls_features = self.class_attention(features)
208
+ reg_features = self.reg_attention(features)
209
+
210
+ # Classification
211
+ class_logits = self.classifier(cls_features)
212
+
213
+ # Regression (with softmax to ensure sum = 100)
214
+ reg_output = self.regressor(reg_features)
215
+ concentrations = F.softmax(reg_output, dim=1) * 100 # Scale to percentages
216
+
217
+ result = {
218
+ 'class_logits': class_logits,
219
+ 'concentrations': concentrations
220
+ }
221
+
222
+ if return_features:
223
+ result['features'] = features
224
+
225
+ return result
226
+
227
+
228
+ def create_model(
229
+ model_type: str = 'full',
230
+ backbone: str = 'efficientnet_v2_s',
231
+ num_classes: int = 12,
232
+ pretrained: bool = True,
233
+ attention_type: str = "none",
234
+ attention_reduction: int = 16,
235
+ task_attention: bool = False,
236
+ ) -> nn.Module:
237
+ """Factory function to create model."""
238
+ model = SoilTextureModel(
239
+ backbone_name=backbone,
240
+ num_classes=num_classes,
241
+ pretrained=pretrained,
242
+ attention_type=attention_type,
243
+ attention_reduction=attention_reduction,
244
+ task_attention=task_attention,
245
+ )
246
+ return model
247
+
248
+
249
+ def format_prediction_markdown(result: Dict) -> str:
250
+ """Create markdown output for inference results."""
251
+ sorted_probs = sorted(result["class_probabilities"].items(), key=lambda x: x[1], reverse=True)
252
+ lines = [
253
+ "### Prediction Result",
254
+ f"- **Texture Class:** `{result['class']}`",
255
+ f"- **Confidence:** `{result['confidence'] * 100:.2f}%`",
256
+ f"- **Sand / Silt / Clay:** `{result['sand']:.2f}% / {result['silt']:.2f}% / {result['clay']:.2f}%`",
257
+ "",
258
+ "**Top Probabilities**",
259
+ ]
260
+ for class_name, prob in sorted_probs[:5]:
261
+ lines.append(f"- {class_name}: {prob * 100:.2f}%")
262
+
263
+ return "\n".join(lines)
264
+
265
+
266
+ # ============================================================================
267
+ # SOIL TEXTURE TRIANGLE VISUALIZATION
268
+ # ============================================================================
269
+
270
+ def create_texture_triangle(sand: float, silt: float, clay: float, predicted_class: str,
271
+ confidence: float = None, top_probs: list = None) -> np.ndarray:
272
+ """
273
+ Create USDA Soil Texture Triangle visualization with correct boundaries.
274
+ """
275
+ fig, ax = plt.subplots(1, 1, figsize=(14, 12), facecolor='white', dpi=150)
276
+
277
+ # Helper function to convert soil percentages to triangle coordinates
278
+ def soil_to_coords(sand_pct, silt_pct, clay_pct):
279
+ x = silt_pct/100 + clay_pct/200
280
+ y = clay_pct/100 * np.sqrt(3)/2
281
+ return x, y
282
+
283
+ # USDA Soil Texture Triangle regions with correct boundaries
284
+ regions = [
285
+ ('Sand', [(100, 0, 0), (85, 15, 0), (90, 0, 10)], '#FFE4B5'),
286
+ ('Loamy Sand', [(85, 15, 0), (70, 30, 0), (85, 0, 15), (90, 0, 10)], '#FFDAB9'),
287
+ ('Sandy Loam', [(70, 30, 0), (50, 50, 0), (42.5, 50, 7.5), (52.5, 40, 7.5), (52.5, 27.5, 20), (80, 0, 20), (85, 0, 15)], '#F4A460'),
288
+ ('Loam', [(42.5, 50, 7.5), (22.5, 50, 27.5), (45, 27.5, 27.5), (52.5, 27.5, 20), (52.5, 40, 7.5)], '#DEB887'),
289
+ ('Silt Loam', [(50, 50, 0), (20, 80, 0), (7.5, 80, 12.5), (0, 87.5, 12.5), (0, 72.5, 27.5), (22.5, 50, 27.5)], '#D2B48C'),
290
+ ('Silt', [(20, 80, 0), (0, 100, 0), (0, 87.5, 12.5), (7.5, 80, 12.5)], '#C0C0C0'),
291
+ ('Sandy Clay Loam', [(80, 0, 20), (52.5, 27.5, 20), (45, 27.5, 27.5), (45, 20, 35), (65, 0, 35)], '#CD853F'),
292
+ ('Clay Loam', [(45, 27.5, 27.5), (20, 52.5, 27.5), (20, 40, 40), (45, 15, 40)], '#D2691E'),
293
+ ('Silty Clay Loam', [(0, 72.5, 27.5), (0, 60, 40), (20, 40, 40), (20, 52.5, 27.5)], '#B8860B'),
294
+ ('Sandy Clay', [(65, 0, 35), (45, 20, 35), (45, 0, 55)], '#A0522D'),
295
+ ('Silty Clay', [(20, 40, 40), (0, 60, 40), (0, 40, 60)], '#8B4513'),
296
+ ('Clay', [(45, 15, 40), (20, 40, 40), (0, 40, 60), (0, 0, 100), (45, 0, 55)], '#654321'),
297
+ ]
298
+
299
+ # Draw colored regions with border lines
300
+ for name, vertices_pct, color in regions:
301
+ vertices_xy = [soil_to_coords(s, si, c) for s, si, c in vertices_pct]
302
+ region_patch = Polygon(vertices_xy, facecolor=color, edgecolor='#333',
303
+ linewidth=1.2, alpha=0.8, zorder=1)
304
+ ax.add_patch(region_patch)
305
+ # Add label
306
+ center_x = np.mean([v[0] for v in vertices_xy])
307
+ center_y = np.mean([v[1] for v in vertices_xy])
308
+ ax.text(center_x, center_y, name, fontsize=12, ha='center',
309
+ va='center', weight='bold', zorder=2)
310
+
311
+ # Draw triangle outline
312
+ triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2]])
313
+ tri_patch = Polygon(triangle, fill=False, edgecolor='black', linewidth=4, zorder=3)
314
+ ax.add_patch(tri_patch)
315
+
316
+ # Add corner labels
317
+ ax.text(0, -0.05, '100% Sand', fontsize=16, ha='center', weight='bold')
318
+ ax.text(1, -0.05, '100% Silt', fontsize=16, ha='center', weight='bold')
319
+ ax.text(0.5, np.sqrt(3)/2 + 0.03, '100% Clay', fontsize=16, ha='center', weight='bold')
320
+
321
+ # Add grid lines
322
+ for pct in range(5, 100, 5):
323
+ y = pct/100 * np.sqrt(3)/2
324
+ x_left = pct/200
325
+ x_right = 1 - pct/200
326
+ sand_pct = pct
327
+ p1 = soil_to_coords(sand_pct, 0, 100-sand_pct)
328
+ p2 = soil_to_coords(sand_pct, 100-sand_pct, 0)
329
+ silt_pct = pct
330
+ p3 = soil_to_coords(0, silt_pct, 100-silt_pct)
331
+ p4 = soil_to_coords(100-silt_pct, silt_pct, 0)
332
+
333
+ if pct % 10 == 0:
334
+ ax.plot([x_left, x_right], [y, y], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
335
+ ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
336
+ ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=0.3, linewidth=1.0, zorder=0)
337
+ ax.text(x_left - 0.03, y, f'{pct}', fontsize=11, alpha=0.7, weight='bold')
338
+ else:
339
+ ax.plot([x_left, x_right], [y, y], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
340
+ ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
341
+ ax.plot([p3[0], p4[0]], [p3[1], p4[1]], 'k-', alpha=0.15, linewidth=0.6, zorder=0)
342
+
343
+ # Plot prediction point
344
+ pred_x, pred_y = soil_to_coords(sand, silt, clay)
345
+ ax.plot(pred_x, pred_y, 'o', markersize=22, markerfacecolor='red',
346
+ markeredgecolor='darkred', markeredgewidth=3.5, zorder=5)
347
+
348
+ # Add annotation
349
+ offset_x = 0.15 if pred_x < 0.7 else -0.15
350
+ offset_y = 0.08
351
+ ax.annotate(f'{predicted_class}\n({sand:.0f}%, {silt:.0f}%, {clay:.0f}%)',
352
+ xy=(pred_x, pred_y), xytext=(pred_x + offset_x, pred_y + offset_y),
353
+ fontsize=14, fontweight='bold',
354
+ arrowprops=dict(arrowstyle='->', lw=2.5, color='darkred'),
355
+ bbox=dict(boxstyle='round,pad=0.6', facecolor='white', edgecolor='darkred', lw=2.5),
356
+ ha='center', zorder=6)
357
+
358
+ # Add prediction information boxes
359
+ if confidence is not None and top_probs is not None:
360
+ # Left box - Prediction and Composition
361
+ left_text = f"Predicted Class:\n{predicted_class}\n\n"
362
+ left_text += f"Confidence: {confidence*100:.1f}%\n\n"
363
+ left_text += f"Composition:\n"
364
+ left_text += f"Sand: {sand:.1f}%\n"
365
+ left_text += f"Silt: {silt:.1f}%\n"
366
+ left_text += f"Clay: {clay:.1f}%"
367
+
368
+ ax.text(0.05, 0.82, left_text,
369
+ fontsize=16, verticalalignment='top',
370
+ bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
371
+ edgecolor='black', linewidth=2.5, alpha=0.95),
372
+ zorder=7, family='monospace', weight='bold')
373
+
374
+ # Right box - Top 5 Probabilities
375
+ right_text = "Top 5 Probabilities:\n\n"
376
+ for i, (cls, prob) in enumerate(top_probs[:5], 1):
377
+ right_text += f"{i}. {cls}: {prob*100:.1f}%\n"
378
+
379
+ ax.text(0.75, 0.82, right_text,
380
+ fontsize=16, verticalalignment='top',
381
+ bbox=dict(boxstyle='round,pad=0.9', facecolor='white',
382
+ edgecolor='black', linewidth=2.5, alpha=0.95),
383
+ zorder=7, family='monospace', weight='bold')
384
+
385
+ ax.set_xlim(-0.08, 1.08)
386
+ ax.set_ylim(-0.08, np.sqrt(3)/2 + 0.06)
387
+ ax.set_aspect('equal')
388
+ ax.axis('off')
389
+ ax.set_title('USDA Soil Texture Triangle', fontsize=20, fontweight='bold', pad=8)
390
+
391
+ fig.tight_layout()
392
+ fig.canvas.draw()
393
+ img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
394
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (4,))
395
+ img = img[:, :, :3]
396
+ plt.close(fig)
397
+
398
+ return img
399
+
400
+
401
+ # ============================================================================
402
+ # PREDICTOR CLASS
403
+ # ============================================================================
404
+
405
+ def classify_from_percentages(sand: float, silt: float, clay: float) -> str:
406
+ """
407
+ Determine USDA texture class from Sand/Silt/Clay percentages.
408
+ Uses official USDA classification boundaries.
409
+ """
410
+ # Normalize to ensure sum = 100
411
+ total = sand + silt + clay
412
+ if total > 0:
413
+ sand = sand / total * 100
414
+ silt = silt / total * 100
415
+ clay = clay / total * 100
416
+
417
+ # USDA classification rules (order matters for overlapping boundaries)
418
+ if clay >= 40:
419
+ if silt >= 40:
420
+ return 'Silty Clay'
421
+ elif sand >= 45:
422
+ return 'Sandy Clay'
423
+ else:
424
+ return 'Clay'
425
+ elif clay >= 35:
426
+ if sand >= 45:
427
+ return 'Sandy Clay'
428
+ elif silt < 20:
429
+ return 'Sandy Clay'
430
+ else:
431
+ return 'Clay Loam'
432
+ elif clay >= 27:
433
+ if sand >= 20 and sand < 45:
434
+ return 'Clay Loam'
435
+ elif silt >= 28 and silt < 40:
436
+ return 'Clay Loam'
437
+ elif silt >= 40:
438
+ return 'Silty Clay Loam'
439
+ else:
440
+ return 'Sandy Clay Loam'
441
+ elif clay >= 20:
442
+ if sand >= 45:
443
+ return 'Sandy Clay Loam'
444
+ elif silt >= 28 and sand < 45:
445
+ return 'Clay Loam'
446
+ elif silt >= 50:
447
+ return 'Silty Clay Loam'
448
+ else:
449
+ return 'Sandy Clay Loam'
450
+ elif clay >= 12:
451
+ if silt >= 50 and clay >= 12 and clay < 27:
452
+ return 'Silt Loam'
453
+ elif silt >= 50 and silt < 80:
454
+ return 'Silt Loam'
455
+ elif silt >= 80 and clay < 12:
456
+ return 'Silt'
457
+ elif sand >= 52:
458
+ return 'Sandy Loam'
459
+ else:
460
+ return 'Loam'
461
+ elif clay >= 7:
462
+ if silt >= 50:
463
+ return 'Silt Loam'
464
+ elif silt >= 28 and silt < 50 and sand < 52:
465
+ return 'Loam'
466
+ else:
467
+ return 'Sandy Loam'
468
+ else:
469
+ # clay < 7
470
+ if silt >= 80:
471
+ return 'Silt'
472
+ elif silt >= 50:
473
+ return 'Silt Loam'
474
+ elif sand >= 85 and silt + 1.5 * clay < 15:
475
+ return 'Sand'
476
+ elif sand >= 70 and sand < 85:
477
+ return 'Loamy Sand'
478
+ elif sand >= 43 and sand < 52:
479
+ return 'Sandy Loam' if silt < 50 else 'Silt Loam'
480
+ elif sand >= 52:
481
+ return 'Sandy Loam'
482
+ else:
483
+ return 'Loam'
484
+
485
+
486
+ class SoilTexturePredictor:
487
+ """
488
+ Inference wrapper for soil texture prediction.
489
+ """
490
+
491
+ CLASSES = [
492
+ 'Sand', 'Loamy Sand', 'Sandy Loam', 'Loam', 'Silt Loam', 'Silt',
493
+ 'Sandy Clay Loam', 'Clay Loam', 'Silty Clay Loam', 'Sandy Clay', 'Silty Clay', 'Clay'
494
+ ]
495
+
496
+ def __init__(
497
+ self,
498
+ checkpoint_path: str = None,
499
+ device: str = None,
500
+ attention_type: str = "none",
501
+ attention_reduction: int = 16,
502
+ task_attention: bool = False,
503
+ ):
504
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
505
+
506
+ # Create model
507
+ self.model = create_model(
508
+ model_type='full',
509
+ backbone='efficientnet_v2_s',
510
+ num_classes=len(self.CLASSES),
511
+ pretrained=False,
512
+ attention_type=attention_type,
513
+ attention_reduction=attention_reduction,
514
+ task_attention=task_attention,
515
+ )
516
+
517
+ # Load checkpoint if provided
518
+ if checkpoint_path and Path(checkpoint_path).exists():
519
+ print(f"Loading checkpoint: {checkpoint_path}")
520
+ checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
521
+ if 'model_state_dict' in checkpoint:
522
+ self.model.load_state_dict(checkpoint['model_state_dict'])
523
+ else:
524
+ self.model.load_state_dict(checkpoint)
525
+ else:
526
+ print("No checkpoint provided, using random weights (for demo)")
527
+
528
+ self.model.to(self.device)
529
+ self.model.eval()
530
+
531
+ # Transform
532
+ self.transform = transforms.Compose([
533
+ transforms.Resize((500, 500)),
534
+ transforms.ToTensor(),
535
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
536
+ ])
537
+
538
+ @torch.no_grad()
539
+ def predict(self, image: Image.Image) -> Dict:
540
+ """
541
+ Predict soil texture class and concentrations.
542
+ """
543
+ # Preprocess
544
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
545
+
546
+ # Forward pass
547
+ output = self.model(img_tensor)
548
+
549
+ # Class prediction from classification head (for reference)
550
+ class_probs = F.softmax(output['class_logits'], dim=1).cpu().numpy()[0]
551
+
552
+ # Concentration prediction
553
+ concentrations = output['concentrations'].cpu().numpy()[0]
554
+ sand, silt, clay = concentrations
555
+
556
+ # Ensure they sum to 100
557
+ total = sand + silt + clay
558
+ sand = sand / total * 100
559
+ silt = silt / total * 100
560
+ clay = clay / total * 100
561
+
562
+ # Derive class from percentages to ensure consistency
563
+ class_name = classify_from_percentages(sand, silt, clay)
564
+ confidence = class_probs[self.CLASSES.index(class_name)]
565
+
566
+ return {
567
+ 'class': class_name,
568
+ 'confidence': confidence,
569
+ 'class_probabilities': {self.CLASSES[i]: float(p) for i, p in enumerate(class_probs)},
570
+ 'sand': sand,
571
+ 'silt': silt,
572
+ 'clay': clay
573
+ }
574
+
575
+ def predict_with_visualization(self, image: Image.Image) -> Tuple[str, np.ndarray, Dict]:
576
+ """Predict and create visualization."""
577
+ result = self.predict(image)
578
+
579
+ # Sort by probability and show top 5
580
+ sorted_probs = sorted(result['class_probabilities'].items(), key=lambda x: x[1], reverse=True)[:5]
581
+
582
+ # Create texture triangle
583
+ triangle_img = create_texture_triangle(
584
+ result['sand'], result['silt'], result['clay'], result['class'],
585
+ confidence=result['confidence'],
586
+ top_probs=sorted_probs
587
+ )
588
+
589
+ text_output = format_prediction_markdown(result)
590
+ return text_output, triangle_img, result
591
+
592
+
593
+ # ============================================================================
594
+ # GRADIO INTERFACE
595
+ # ============================================================================
596
+
597
+ def create_demo(
598
+ checkpoint_path: str = None,
599
+ attention_type: str = "none",
600
+ attention_reduction: int = 16,
601
+ task_attention: bool = False,
602
+ ):
603
+ """Create Gradio demo interface."""
604
+
605
+ # Initialize predictor
606
+ predictor = SoilTexturePredictor(
607
+ checkpoint_path=checkpoint_path,
608
+ attention_type=attention_type,
609
+ attention_reduction=attention_reduction,
610
+ task_attention=task_attention,
611
+ )
612
+ collection_manager = DataCollectionManager()
613
+ collection_manager.ensure_storage()
614
+ collection_manager.start_scheduler()
615
+
616
+ def to_pil_image(image):
617
+ """Convert possible Gradio image input to PIL."""
618
+ if isinstance(image, Image.Image):
619
+ return image.convert("RGB")
620
+ if isinstance(image, np.ndarray):
621
+ return Image.fromarray(image).convert("RGB")
622
+ raise ValueError("Unsupported image format.")
623
+
624
+ def predict_fn(image):
625
+ """Gradio prediction function."""
626
+ if image is None:
627
+ return "Please upload an image.", None
628
+
629
+ image = to_pil_image(image)
630
+
631
+ # Get prediction
632
+ text_output, triangle_img, _ = predictor.predict_with_visualization(image)
633
+
634
+ return text_output, triangle_img
635
+
636
+ def submit_contribution_fn(
637
+ image,
638
+ sand,
639
+ silt,
640
+ clay,
641
+ weak_label,
642
+ strong_label,
643
+ sample_source,
644
+ location,
645
+ notes,
646
+ consent
647
+ ):
648
+ """Persist user-contributed image + composition for future training."""
649
+ if image is None:
650
+ return "Submission failed: please upload a soil image."
651
+
652
+ image = to_pil_image(image)
653
+ validation = collection_manager.validate_submission(
654
+ sand=sand,
655
+ silt=silt,
656
+ clay=clay,
657
+ consent=consent,
658
+ image=image,
659
+ )
660
+ if not validation.ok:
661
+ return f"Submission failed: {validation.message}"
662
+
663
+ prediction = predictor.predict(image)
664
+ user_class = classify_from_percentages_simple(sand, silt, clay)
665
+ submission_id = collection_manager.create_submission_id()
666
+ save_result = collection_manager.save_submission(
667
+ image=image,
668
+ submission_id=submission_id,
669
+ sand=sand,
670
+ silt=silt,
671
+ clay=clay,
672
+ user_class=user_class,
673
+ weak_label=weak_label,
674
+ strong_label=strong_label,
675
+ prediction=prediction,
676
+ sample_source=sample_source,
677
+ location=location,
678
+ notes=notes,
679
+ total=validation.total,
680
+ )
681
+ image_path = save_result.get("image_path", "")
682
+ is_duplicate = save_result.get("is_duplicate", "0") == "1"
683
+ duplicate_of_submission = save_result.get("duplicate_of_submission", "")
684
+ export_bundles = collection_manager.maybe_trigger_exports()
685
+ export_note = ""
686
+ if export_bundles:
687
+ export_note = "\n- Auto-export triggered:\n" + "\n".join([f" - `{bundle}`" for bundle in export_bundles])
688
+ dedup_note = ""
689
+ if is_duplicate:
690
+ dedup_note = f"\n- Duplicate image detected. Reused existing sample from `{duplicate_of_submission}`."
691
+
692
+ return (
693
+ "### Submission Saved\n"
694
+ f"- Submission ID: `{submission_id}`\n"
695
+ f"- Stored image: `{image_path}`\n"
696
+ f"- User label class: `{user_class}`\n"
697
+ f"- Model prediction: `{prediction['class']}` ({prediction['confidence'] * 100:.2f}%)\n"
698
+ f"- Weak label: `{weak_label or ''}`\n"
699
+ f"- Strong label: `{strong_label or ''}`\n"
700
+ "- Data was appended to `data/community_submissions/submissions.csv`.\n"
701
+ "- Daily export uses background scheduler; high disk usage triggers immediate export."
702
+ f"{dedup_note}"
703
+ f"{export_note}"
704
+ )
705
+
706
+ def get_dataset_stats_fn():
707
+ """Get statistics about the current dataset."""
708
+ cfg = collection_manager.config
709
+ num_submissions = 0
710
+ if cfg.csv_path.exists():
711
+ with cfg.csv_path.open("r", encoding="utf-8") as f:
712
+ reader = csv.reader(f)
713
+ next(reader, None)
714
+ num_submissions = sum(1 for _ in reader)
715
+ num_images = 0
716
+ total_size_bytes = 0
717
+ if cfg.images_dir.exists():
718
+ for p in cfg.images_dir.iterdir():
719
+ if p.is_file():
720
+ num_images += 1
721
+ total_size_bytes += p.stat().st_size
722
+ total_size_mb = total_size_bytes / (1024 * 1024)
723
+ return (
724
+ f"### Dataset Statistics\n"
725
+ f"- **Total submissions:** {num_submissions}\n"
726
+ f"- **Total images:** {num_images}\n"
727
+ f"- **Total image size:** {total_size_mb:.1f} MB\n"
728
+ )
729
+
730
+ def upload_dataset_fn(zip_file, upload_consent):
731
+ """Process uploaded ZIP dataset with images and CSV."""
732
+ if zip_file is None:
733
+ return "Please upload a ZIP file."
734
+ if not upload_consent:
735
+ return "Please confirm consent before uploading."
736
+ zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
737
+ if not zipfile.is_zipfile(zip_path):
738
+ return "Invalid ZIP file."
739
+ max_entries = 10000
740
+ max_total_size = 500 * 1024 * 1024
741
+ results = {"added": 0, "skipped": 0, "errors": []}
742
+ try:
743
+ with zipfile.ZipFile(zip_path, "r") as zf:
744
+ entries = zf.infolist()
745
+ if len(entries) > max_entries:
746
+ return f"ZIP has too many entries ({len(entries)}). Max: {max_entries}."
747
+ total_size = sum(e.file_size for e in entries)
748
+ if total_size > max_total_size:
749
+ return f"ZIP too large ({total_size / 1024 / 1024:.0f} MB). Max: {max_total_size // (1024 * 1024)} MB."
750
+ csv_entries = [
751
+ e for e in entries
752
+ if e.filename.endswith(".csv") and not e.filename.startswith("__")
753
+ ]
754
+ if not csv_entries:
755
+ return "No CSV found in ZIP. Expected CSV with columns: filename, sand, silt, clay."
756
+ with zf.open(csv_entries[0]) as csv_file:
757
+ content = csv_file.read().decode("utf-8")
758
+ reader = csv.DictReader(io.StringIO(content))
759
+ headers = set(reader.fieldnames or [])
760
+ required = {"filename", "sand", "silt", "clay"}
761
+ if not required.issubset(headers):
762
+ return (
763
+ f"CSV must have columns: {', '.join(sorted(required))}. "
764
+ f"Found: {', '.join(sorted(headers))}"
765
+ )
766
+ for row in reader:
767
+ try:
768
+ fname = row["filename"].strip()
769
+ sand = float(row["sand"])
770
+ silt = float(row["silt"])
771
+ clay = float(row["clay"])
772
+ vals = [sand, silt, clay]
773
+ if any(v < 0 or v > 100 for v in vals):
774
+ results["errors"].append(f"{fname}: values out of range")
775
+ results["skipped"] += 1
776
+ continue
777
+ total = sand + silt + clay
778
+ if abs(total - 100.0) > 1.0:
779
+ results["errors"].append(f"{fname}: sum={total:.1f}, must be ~100")
780
+ results["skipped"] += 1
781
+ continue
782
+ matches = [e for e in entries if Path(e.filename).name == fname]
783
+ if not matches:
784
+ results["errors"].append(f"Image not found in ZIP: {fname}")
785
+ results["skipped"] += 1
786
+ continue
787
+ with zf.open(matches[0]) as img_bytes:
788
+ image = Image.open(img_bytes).convert("RGB")
789
+ if image.width * image.height > collection_manager.config.max_image_pixels:
790
+ results["errors"].append(f"{fname}: image too large")
791
+ results["skipped"] += 1
792
+ continue
793
+ prediction = predictor.predict(image)
794
+ user_class = classify_from_percentages_simple(sand, silt, clay)
795
+ submission_id = collection_manager.create_submission_id()
796
+ collection_manager.save_submission(
797
+ image=image,
798
+ submission_id=submission_id,
799
+ sand=sand, silt=silt, clay=clay,
800
+ user_class=user_class,
801
+ weak_label=row.get("weak_label", ""),
802
+ strong_label=row.get("strong_label", ""),
803
+ prediction=prediction,
804
+ sample_source=row.get("source", ""),
805
+ location=row.get("location", ""),
806
+ notes=row.get("notes", ""),
807
+ total=total,
808
+ )
809
+ results["added"] += 1
810
+ except Exception as e:
811
+ results["errors"].append(f"{row.get('filename', '?')}: {e}")
812
+ results["skipped"] += 1
813
+ except Exception as e:
814
+ return f"Failed to process ZIP: {e}"
815
+ error_summary = ""
816
+ if results["errors"]:
817
+ shown = results["errors"][:20]
818
+ error_summary = "\n\n**Errors:**\n" + "\n".join(f"- {e}" for e in shown)
819
+ if len(results["errors"]) > 20:
820
+ error_summary += f"\n- ... and {len(results['errors']) - 20} more"
821
+ return (
822
+ f"### Upload Complete\n"
823
+ f"- **Added:** {results['added']} submissions\n"
824
+ f"- **Skipped:** {results['skipped']}\n"
825
+ f"{error_summary}"
826
+ )
827
+
828
+ # Create interface
829
+ with gr.Blocks(title="Soil Texture Classifier") as demo:
830
+ gr.Markdown("""
831
+ # Soil Texture Classification
832
+
833
+ 1. Use **Inference** to predict texture class and composition from image.
834
+ 2. Use **Contribute Data** to upload image + measured Sand/Silt/Clay for future training.
835
+ 3. Use **Dataset Management** to bulk-upload a ZIP dataset for model improvement.
836
+ """)
837
+
838
+ with gr.Tabs():
839
+ with gr.Tab("Inference"):
840
+ with gr.Row():
841
+ with gr.Column():
842
+ input_image = gr.Image(label="Upload Soil Image", type="pil")
843
+ predict_btn = gr.Button("Analyze", variant="primary")
844
+
845
+ gr.Markdown("""
846
+ **Tips:**
847
+ - Use close-up images of soil surface
848
+ - Ensure good lighting
849
+ - Avoid shadows and reflections
850
+ """)
851
+
852
+ with gr.Column():
853
+ output_text = gr.Markdown(label="Results")
854
+ output_triangle = gr.Image(label="USDA Texture Triangle")
855
+
856
+ with gr.Tab("Contribute Data"):
857
+ gr.Markdown("""
858
+ Upload a soil image with measured Sand/Silt/Clay percentages.
859
+ This data will be stored for manual quality checks and future retraining.
860
+ You can optionally submit weak/strong labels for better curation quality.
861
+ """)
862
+ with gr.Row():
863
+ with gr.Column():
864
+ contribution_image = gr.Image(label="Soil Image for Contribution", type="pil")
865
+ weak_label = gr.Dropdown(
866
+ choices=[""] + SoilTexturePredictor.CLASSES,
867
+ value="",
868
+ allow_custom_value=True,
869
+ label="Weak Label (Optional)"
870
+ )
871
+ strong_label = gr.Dropdown(
872
+ choices=[""] + SoilTexturePredictor.CLASSES,
873
+ value="",
874
+ allow_custom_value=True,
875
+ label="Strong Label (Optional)"
876
+ )
877
+ sample_source = gr.Textbox(
878
+ label="Sample Source",
879
+ placeholder="e.g., field site, experiment ID, sample batch"
880
+ )
881
+ location = gr.Textbox(
882
+ label="Location (Optional)",
883
+ placeholder="e.g., Iowa, USA"
884
+ )
885
+ notes = gr.Textbox(
886
+ label="Notes (Optional)",
887
+ lines=4,
888
+ placeholder="Any observation, sampling method, moisture condition, etc."
889
+ )
890
+ with gr.Column():
891
+ sand_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Sand (%)")
892
+ silt_input = gr.Slider(0, 100, value=33.3, step=0.1, label="Silt (%)")
893
+ clay_input = gr.Slider(0, 100, value=33.4, step=0.1, label="Clay (%)")
894
+ consent = gr.Checkbox(
895
+ label="I confirm this image and labels can be used for model improvement.",
896
+ value=False
897
+ )
898
+ submit_btn = gr.Button("Submit Contribution", variant="primary")
899
+ contribution_status = gr.Markdown(label="Submission Status")
900
+
901
+ with gr.Tab("Dataset Management"):
902
+ gr.Markdown("""
903
+ **Upload** a dataset (ZIP) to contribute bulk data for model improvement.
904
+
905
+ **Upload format:** ZIP containing a CSV file and image files.
906
+ CSV columns: `filename`, `sand`, `silt`, `clay` (required).
907
+ Optional: `weak_label`, `strong_label`, `source`, `location`, `notes`.
908
+ """)
909
+ with gr.Row():
910
+ with gr.Column():
911
+ upload_file = gr.File(label="ZIP Dataset", file_types=[".zip"])
912
+ upload_consent = gr.Checkbox(
913
+ label="I confirm these images and labels can be used for model improvement.",
914
+ value=False,
915
+ )
916
+ upload_btn = gr.Button("Upload Dataset", variant="primary")
917
+ upload_status = gr.Markdown(label="Upload Status")
918
+ with gr.Column():
919
+ stats_btn = gr.Button("Refresh Statistics")
920
+ stats_display = gr.Markdown(label="Statistics")
921
+
922
+ # Event handlers
923
+ predict_btn.click(
924
+ fn=predict_fn,
925
+ inputs=input_image,
926
+ outputs=[output_text, output_triangle]
927
+ )
928
+
929
+ input_image.change(
930
+ fn=predict_fn,
931
+ inputs=input_image,
932
+ outputs=[output_text, output_triangle]
933
+ )
934
+
935
+ submit_btn.click(
936
+ fn=submit_contribution_fn,
937
+ inputs=[
938
+ contribution_image,
939
+ sand_input,
940
+ silt_input,
941
+ clay_input,
942
+ weak_label,
943
+ strong_label,
944
+ sample_source,
945
+ location,
946
+ notes,
947
+ consent,
948
+ ],
949
+ outputs=[contribution_status]
950
+ )
951
+
952
+ upload_btn.click(
953
+ fn=upload_dataset_fn,
954
+ inputs=[upload_file, upload_consent],
955
+ outputs=[upload_status],
956
+ )
957
+
958
+ stats_btn.click(
959
+ fn=get_dataset_stats_fn,
960
+ inputs=[],
961
+ outputs=[stats_display],
962
+ )
963
+
964
+ return demo
965
+
966
+
967
+ # ============================================================================
968
+ # MAIN
969
+ # ============================================================================
970
+
971
+ if __name__ == "__main__":
972
+ parser = argparse.ArgumentParser(description="Soil texture inference and contribution app")
973
+ parser.add_argument("--checkpoint", type=str, default="finetuned_best.pth",
974
+ help="Path to model checkpoint")
975
+ parser.add_argument("--server_name", type=str, default="0.0.0.0",
976
+ help="Gradio server host")
977
+ parser.add_argument("--server_port", type=int, default=7860,
978
+ help="Gradio server port")
979
+ parser.add_argument("--share", action="store_true",
980
+ help="Create a public share link")
981
+ parser.add_argument("--attention_type", type=str, default="none", choices=["none", "se", "cbam"],
982
+ help="Attention block used by inference model")
983
+ parser.add_argument("--attention_reduction", type=int, default=16,
984
+ help="Attention reduction ratio")
985
+ parser.add_argument("--task_attention", action="store_true",
986
+ help="Enable task-specific attention blocks")
987
+ parser.add_argument("--allow_random_weights", action="store_true",
988
+ help="Allow launching without checkpoint (debug only)")
989
+ args = parser.parse_args()
990
+
991
+ checkpoint_path = args.checkpoint
992
+
993
+ if not Path(checkpoint_path).exists():
994
+ if not args.allow_random_weights:
995
+ raise FileNotFoundError(
996
+ f"Checkpoint not found at {checkpoint_path}. "
997
+ "Pass --allow_random_weights only for debugging."
998
+ )
999
+ print(f"Warning: Checkpoint not found at {checkpoint_path}")
1000
+ print("Running with random weights for debug purposes.")
1001
+ checkpoint_path = None
1002
+
1003
+ # Create and launch demo
1004
+ demo = create_demo(
1005
+ checkpoint_path=checkpoint_path,
1006
+ attention_type=args.attention_type,
1007
+ attention_reduction=args.attention_reduction,
1008
+ task_attention=args.task_attention,
1009
+ )
1010
+ demo.launch(
1011
+ server_name=args.server_name,
1012
+ server_port=args.server_port,
1013
+ share=args.share
1014
+ )
collection_common.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared helpers for collection/curation pipeline.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import re
9
+ from pathlib import Path
10
+ from typing import List, Optional
11
+
12
+
13
+ ALLOWED_LABEL_PRIORITY = ("strong", "weak", "user")
14
+
15
+
16
+ def parse_label_priority(value: str) -> List[str]:
17
+ """
18
+ Parse and validate comma-separated label priority list.
19
+ Returns de-duplicated values while preserving order.
20
+ """
21
+ raw_items = [item.strip() for item in str(value).split(",") if item.strip()]
22
+ if not raw_items:
23
+ raise ValueError("label priority cannot be empty")
24
+
25
+ invalid = [item for item in raw_items if item not in ALLOWED_LABEL_PRIORITY]
26
+ if invalid:
27
+ raise ValueError(f"Invalid label priority values: {invalid}")
28
+
29
+ deduped = []
30
+ seen = set()
31
+ for item in raw_items:
32
+ if item in seen:
33
+ continue
34
+ deduped.append(item)
35
+ seen.add(item)
36
+ return deduped
37
+
38
+
39
+ def safe_resolve_in_dir(base_dir: Path, filename: str) -> Optional[Path]:
40
+ """
41
+ Resolve a filename safely under base_dir.
42
+ Reject nested paths and path traversal patterns.
43
+ """
44
+ raw_name = str(filename).strip()
45
+ if not raw_name:
46
+ return None
47
+ safe_name = Path(raw_name).name
48
+ if safe_name != raw_name:
49
+ return None
50
+
51
+ root = base_dir.resolve()
52
+ candidate = (base_dir / safe_name).resolve()
53
+ if os.path.commonpath([str(root), str(candidate)]) != str(root):
54
+ return None
55
+ return candidate
56
+
57
+
58
+ def sanitize_identifier(value: str, fallback: str, max_len: int = 64) -> str:
59
+ """
60
+ Sanitize identifier for filesystem-safe filenames.
61
+ """
62
+ clean = re.sub(r"[^A-Za-z0-9_-]", "_", str(value).strip())
63
+ clean = clean[:max_len]
64
+ return clean if clean else fallback
data_collection.py ADDED
@@ -0,0 +1,728 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Collection Pipeline
3
+ ------------------------
4
+ Collection-only module for Space uploads.
5
+ Keeps collection logic separated from model training code.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import hashlib
12
+ import io
13
+ import json
14
+ import os
15
+ import shutil
16
+ import tarfile
17
+ import threading
18
+ import time
19
+ import uuid
20
+ from contextlib import contextmanager
21
+ from dataclasses import dataclass
22
+ from datetime import datetime, timezone
23
+ from pathlib import Path
24
+ from typing import Dict, List, Optional, Tuple
25
+
26
+ from PIL import Image
27
+ from collection_common import safe_resolve_in_dir
28
+
29
+
30
+ USDA_CLASSES = [
31
+ "Sand",
32
+ "Loamy Sand",
33
+ "Sandy Loam",
34
+ "Loam",
35
+ "Silt Loam",
36
+ "Silt",
37
+ "Sandy Clay Loam",
38
+ "Clay Loam",
39
+ "Silty Clay Loam",
40
+ "Sandy Clay",
41
+ "Silty Clay",
42
+ "Clay",
43
+ ]
44
+
45
+
46
+ CONTRIBUTION_FIELDS = [
47
+ "submission_id",
48
+ "timestamp_utc",
49
+ "image_filename",
50
+ "image_sha256",
51
+ "is_duplicate",
52
+ "duplicate_of_submission",
53
+ "user_sand",
54
+ "user_silt",
55
+ "user_clay",
56
+ "user_total",
57
+ "user_class",
58
+ "weak_label",
59
+ "strong_label",
60
+ "predicted_class",
61
+ "predicted_confidence",
62
+ "pred_sand",
63
+ "pred_silt",
64
+ "pred_clay",
65
+ "sample_source",
66
+ "location",
67
+ "notes",
68
+ ]
69
+
70
+
71
+ @contextmanager
72
+ def _file_lock(lock_path: Path):
73
+ """Best-effort cross-process lock for unix-like environments."""
74
+ lock_path.parent.mkdir(parents=True, exist_ok=True)
75
+ with lock_path.open("a+") as lock_file:
76
+ try:
77
+ import fcntl # type: ignore
78
+
79
+ fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
80
+ yield
81
+ finally:
82
+ try:
83
+ import fcntl # type: ignore
84
+
85
+ fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
86
+ except Exception:
87
+ pass
88
+
89
+
90
+ def sanitize_text(value: Optional[str], max_len: int = 500) -> str:
91
+ """Sanitize free-form user text and neutralize CSV formula injection."""
92
+ if value is None:
93
+ return ""
94
+ clean = str(value).replace("\r", " ").replace("\n", " ").strip()
95
+ clean = " ".join(clean.split())
96
+ if clean and clean[0] in ("=", "+", "-", "@"):
97
+ clean = "'" + clean
98
+ return clean[:max_len]
99
+
100
+
101
+ def normalize_optional_label(label: Optional[str]) -> str:
102
+ """Normalize optional weak/strong labels."""
103
+ clean = sanitize_text(label, max_len=64)
104
+ if not clean:
105
+ return ""
106
+
107
+ normalized = clean.lower().replace("_", " ")
108
+ class_map = {c.lower(): c for c in USDA_CLASSES}
109
+ if normalized in class_map:
110
+ return class_map[normalized]
111
+
112
+ titled = " ".join(word.capitalize() for word in normalized.split())
113
+ return titled
114
+
115
+
116
+ def encode_jpeg_bytes(image: Image.Image, quality: int = 92) -> bytes:
117
+ """Encode image to JPEG bytes once for deterministic hashing and persistence."""
118
+ buffer = io.BytesIO()
119
+ image.save(buffer, format="JPEG", quality=quality)
120
+ return buffer.getvalue()
121
+
122
+
123
+ def compute_bytes_sha256(content: bytes) -> str:
124
+ return hashlib.sha256(content).hexdigest()
125
+
126
+
127
+ @dataclass
128
+ class SubmissionValidationResult:
129
+ ok: bool
130
+ message: str
131
+ total: float
132
+
133
+
134
+ @dataclass
135
+ class DataCollectionConfig:
136
+ root_dir: Path
137
+ images_dir: Path
138
+ csv_path: Path
139
+ lock_path: Path
140
+ state_path: Path
141
+ exports_dir: Path
142
+ disk_usage_threshold_percent: float
143
+ max_image_pixels: int
144
+ min_submit_interval_sec: float
145
+ daily_export_hour_utc: int
146
+ daily_export_minute_utc: int
147
+ schedule_check_interval_sec: int
148
+ hf_dataset_repo: str
149
+ hf_export_prefix: str
150
+ storage_quota_bytes: int
151
+ deduplicate_images: bool
152
+ prune_after_export: bool
153
+ max_hash_index_entries: int
154
+
155
+ @staticmethod
156
+ def from_env() -> "DataCollectionConfig":
157
+ root = Path(os.getenv("CONTRIBUTION_DATA_DIR", "data/community_submissions"))
158
+ return DataCollectionConfig(
159
+ root_dir=root,
160
+ images_dir=root / "images",
161
+ csv_path=root / "submissions.csv",
162
+ lock_path=root / ".submission.lock",
163
+ state_path=root / "collection_state.json",
164
+ exports_dir=root / "exports",
165
+ disk_usage_threshold_percent=float(os.getenv("CONTRIBUTION_MAX_USAGE_PERCENT", "90")),
166
+ max_image_pixels=int(os.getenv("CONTRIBUTION_MAX_IMAGE_PIXELS", str(20_000_000))),
167
+ min_submit_interval_sec=float(os.getenv("CONTRIBUTION_MIN_SUBMIT_INTERVAL_SEC", "0.5")),
168
+ daily_export_hour_utc=int(os.getenv("CONTRIBUTION_DAILY_EXPORT_HOUR_UTC", "23")),
169
+ daily_export_minute_utc=int(os.getenv("CONTRIBUTION_DAILY_EXPORT_MINUTE_UTC", "50")),
170
+ schedule_check_interval_sec=int(os.getenv("CONTRIBUTION_SCHEDULE_CHECK_SEC", "60")),
171
+ hf_dataset_repo=os.getenv("HF_CONTRIB_DATASET_REPO", "").strip(),
172
+ hf_export_prefix=os.getenv("HF_CONTRIB_EXPORT_PREFIX", "space_exports").strip() or "space_exports",
173
+ storage_quota_bytes=int(os.getenv("CONTRIBUTION_STORAGE_QUOTA_BYTES", "0")),
174
+ deduplicate_images=os.getenv("CONTRIBUTION_DEDUPLICATE_IMAGES", "1").strip() != "0",
175
+ prune_after_export=os.getenv("CONTRIBUTION_PRUNE_AFTER_EXPORT", "0").strip() == "1",
176
+ max_hash_index_entries=int(os.getenv("CONTRIBUTION_MAX_HASH_INDEX_ENTRIES", "50000")),
177
+ )
178
+
179
+
180
+ class DataCollectionManager:
181
+ """Manage submission persistence and export scheduling in Space."""
182
+
183
+ def __init__(self, config: Optional[DataCollectionConfig] = None):
184
+ self.config = config or DataCollectionConfig.from_env()
185
+ self._thread: Optional[threading.Thread] = None
186
+ self._stop_event = threading.Event()
187
+ self._mem_lock = threading.Lock()
188
+ self._last_submit_ts = 0.0
189
+
190
+ def ensure_storage(self) -> None:
191
+ cfg = self.config
192
+ cfg.images_dir.mkdir(parents=True, exist_ok=True)
193
+ cfg.exports_dir.mkdir(parents=True, exist_ok=True)
194
+
195
+ if not cfg.csv_path.exists():
196
+ with _file_lock(cfg.lock_path):
197
+ if not cfg.csv_path.exists():
198
+ with cfg.csv_path.open("w", newline="", encoding="utf-8") as f:
199
+ writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
200
+ writer.writeheader()
201
+
202
+ if not cfg.state_path.exists():
203
+ self._save_state({
204
+ "last_daily_export_date": "",
205
+ "last_pressure_export_at": "",
206
+ "last_uploaded_bundle": "",
207
+ "image_hash_map": {},
208
+ })
209
+
210
+ def start_scheduler(self) -> None:
211
+ """Start background scheduler for timed export checks."""
212
+ if self._thread and self._thread.is_alive():
213
+ return
214
+
215
+ self._thread = threading.Thread(target=self._scheduler_loop, name="collection-scheduler", daemon=True)
216
+ self._thread.start()
217
+
218
+ def stop_scheduler(self) -> None:
219
+ self._stop_event.set()
220
+ if self._thread and self._thread.is_alive():
221
+ self._thread.join(timeout=2)
222
+
223
+ def validate_submission(
224
+ self,
225
+ sand: float,
226
+ silt: float,
227
+ clay: float,
228
+ consent: bool,
229
+ image: Image.Image,
230
+ ) -> SubmissionValidationResult:
231
+ if image.width * image.height > self.config.max_image_pixels:
232
+ return SubmissionValidationResult(
233
+ ok=False,
234
+ message=f"Image too large. Max pixels: {self.config.max_image_pixels}.",
235
+ total=sand + silt + clay,
236
+ )
237
+
238
+ if not consent:
239
+ return SubmissionValidationResult(ok=False, message="Consent is required.", total=sand + silt + clay)
240
+
241
+ values = [sand, silt, clay]
242
+ if any(v < 0 or v > 100 for v in values):
243
+ return SubmissionValidationResult(ok=False, message="Sand/Silt/Clay must be in [0, 100].", total=sum(values))
244
+
245
+ total = sand + silt + clay
246
+ if abs(total - 100.0) > 1.0:
247
+ return SubmissionValidationResult(
248
+ ok=False,
249
+ message=f"Sand + Silt + Clay should be close to 100 (current: {total:.2f}).",
250
+ total=total,
251
+ )
252
+
253
+ with self._mem_lock:
254
+ now_ts = time.time()
255
+ if now_ts - self._last_submit_ts < self.config.min_submit_interval_sec:
256
+ return SubmissionValidationResult(
257
+ ok=False,
258
+ message="Submission too fast. Please wait a moment and retry.",
259
+ total=total,
260
+ )
261
+ self._last_submit_ts = now_ts
262
+
263
+ return SubmissionValidationResult(ok=True, message="", total=total)
264
+
265
+ def create_submission_id(self) -> str:
266
+ return f"sub_{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}_{uuid.uuid4().hex[:8]}"
267
+
268
+ def _resolve_submission_image(
269
+ self,
270
+ submission_id: str,
271
+ encoded_image: bytes,
272
+ image_hash: str,
273
+ hash_map: Dict[str, str],
274
+ ) -> Tuple[str, Path, str, str, Dict[str, str]]:
275
+ """
276
+ Resolve image storage path with optional hash-based deduplication.
277
+ Returns image metadata and updated hash map.
278
+ """
279
+ cfg = self.config
280
+ image_filename = f"{submission_id}.jpg"
281
+ image_path = cfg.images_dir / image_filename
282
+ duplicate_of_submission = ""
283
+ is_duplicate = "0"
284
+
285
+ if cfg.deduplicate_images and image_hash in hash_map:
286
+ duplicate_of_submission = str(hash_map[image_hash]).strip()
287
+ candidate_filename = f"{duplicate_of_submission}.jpg"
288
+ candidate_path = cfg.images_dir / candidate_filename
289
+ if duplicate_of_submission and candidate_path.exists():
290
+ image_filename = candidate_filename
291
+ image_path = candidate_path
292
+ is_duplicate = "1"
293
+ return image_filename, image_path, is_duplicate, duplicate_of_submission, hash_map
294
+
295
+ image_path.write_bytes(encoded_image)
296
+ hash_map[image_hash] = submission_id
297
+ return image_filename, image_path, is_duplicate, duplicate_of_submission, hash_map
298
+
299
+ def _trim_hash_map(self, hash_map: Dict[str, str]) -> Dict[str, str]:
300
+ if len(hash_map) <= self.config.max_hash_index_entries:
301
+ return hash_map
302
+ trimmed_items = list(hash_map.items())[-self.config.max_hash_index_entries:]
303
+ return {k: v for k, v in trimmed_items}
304
+
305
+ def _build_submission_row(
306
+ self,
307
+ submission_id: str,
308
+ image_filename: str,
309
+ image_hash: str,
310
+ is_duplicate: str,
311
+ duplicate_of_submission: str,
312
+ sand: float,
313
+ silt: float,
314
+ clay: float,
315
+ total: float,
316
+ user_class: str,
317
+ weak_label: str,
318
+ strong_label: str,
319
+ prediction: Dict[str, float],
320
+ sample_source: str,
321
+ location: str,
322
+ notes: str,
323
+ ) -> Dict[str, str]:
324
+ return {
325
+ "submission_id": submission_id,
326
+ "timestamp_utc": datetime.now(timezone.utc).isoformat(),
327
+ "image_filename": image_filename,
328
+ "image_sha256": image_hash,
329
+ "is_duplicate": is_duplicate,
330
+ "duplicate_of_submission": duplicate_of_submission,
331
+ "user_sand": f"{sand:.4f}",
332
+ "user_silt": f"{silt:.4f}",
333
+ "user_clay": f"{clay:.4f}",
334
+ "user_total": f"{total:.4f}",
335
+ "user_class": sanitize_text(user_class, max_len=64),
336
+ "weak_label": normalize_optional_label(weak_label),
337
+ "strong_label": normalize_optional_label(strong_label),
338
+ "predicted_class": sanitize_text(str(prediction.get("class", "")), max_len=64),
339
+ "predicted_confidence": f"{float(prediction.get('confidence', 0.0)):.8f}",
340
+ "pred_sand": f"{float(prediction.get('sand', 0.0)):.4f}",
341
+ "pred_silt": f"{float(prediction.get('silt', 0.0)):.4f}",
342
+ "pred_clay": f"{float(prediction.get('clay', 0.0)):.4f}",
343
+ "sample_source": sanitize_text(sample_source),
344
+ "location": sanitize_text(location),
345
+ "notes": sanitize_text(notes, max_len=2000),
346
+ }
347
+
348
+ def _append_submission_row(self, row: Dict[str, str]) -> None:
349
+ with self.config.csv_path.open("a", newline="", encoding="utf-8") as f:
350
+ writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
351
+ writer.writerow({k: row.get(k, "") for k in CONTRIBUTION_FIELDS})
352
+
353
+ def save_submission(
354
+ self,
355
+ image: Image.Image,
356
+ submission_id: str,
357
+ sand: float,
358
+ silt: float,
359
+ clay: float,
360
+ user_class: str,
361
+ weak_label: str,
362
+ strong_label: str,
363
+ prediction: Dict[str, float],
364
+ sample_source: str,
365
+ location: str,
366
+ notes: str,
367
+ total: float,
368
+ ) -> Dict[str, str]:
369
+ cfg = self.config
370
+ self.ensure_storage()
371
+
372
+ encoded_image = encode_jpeg_bytes(image, quality=92)
373
+ image_hash = compute_bytes_sha256(encoded_image)
374
+
375
+ with _file_lock(cfg.lock_path):
376
+ state = self._load_state()
377
+ hash_map = state.get("image_hash_map", {})
378
+ if not isinstance(hash_map, dict):
379
+ hash_map = {}
380
+
381
+ image_filename, image_path, is_duplicate, duplicate_of_submission, hash_map = self._resolve_submission_image(
382
+ submission_id=submission_id,
383
+ encoded_image=encoded_image,
384
+ image_hash=image_hash,
385
+ hash_map=hash_map,
386
+ )
387
+ hash_map = self._trim_hash_map(hash_map)
388
+ state["image_hash_map"] = hash_map
389
+
390
+ row = self._build_submission_row(
391
+ submission_id=submission_id,
392
+ image_filename=image_filename,
393
+ image_hash=image_hash,
394
+ is_duplicate=is_duplicate,
395
+ duplicate_of_submission=duplicate_of_submission,
396
+ sand=sand,
397
+ silt=silt,
398
+ clay=clay,
399
+ total=total,
400
+ user_class=user_class,
401
+ weak_label=weak_label,
402
+ strong_label=strong_label,
403
+ prediction=prediction,
404
+ sample_source=sample_source,
405
+ location=location,
406
+ notes=notes,
407
+ )
408
+ self._append_submission_row(row)
409
+ self._save_state(state)
410
+
411
+ return {
412
+ "image_path": str(image_path),
413
+ "image_filename": image_filename,
414
+ "image_sha256": image_hash,
415
+ "is_duplicate": is_duplicate,
416
+ "duplicate_of_submission": duplicate_of_submission,
417
+ }
418
+
419
+ def maybe_trigger_exports(self) -> List[Path]:
420
+ """Run daily and pressure-based export checks."""
421
+ bundles: List[Path] = []
422
+ bundles.extend(self._maybe_daily_export())
423
+ bundles.extend(self._maybe_pressure_export())
424
+ return bundles
425
+
426
+ def _scheduler_loop(self) -> None:
427
+ self.ensure_storage()
428
+ while not self._stop_event.is_set():
429
+ try:
430
+ bundles = self.maybe_trigger_exports()
431
+ if bundles:
432
+ print(f"[collection] exported {len(bundles)} bundle(s) from scheduler")
433
+ except Exception as exc:
434
+ print(f"[collection] scheduler error: {exc}")
435
+ self._stop_event.wait(self.config.schedule_check_interval_sec)
436
+
437
+ def _maybe_daily_export(self) -> List[Path]:
438
+ now = datetime.now(timezone.utc)
439
+ state = self._load_state()
440
+ last_date = state.get("last_daily_export_date", "")
441
+
442
+ if now.hour < self.config.daily_export_hour_utc:
443
+ return []
444
+ if now.hour == self.config.daily_export_hour_utc and now.minute < self.config.daily_export_minute_utc:
445
+ return []
446
+
447
+ current_date = now.strftime("%Y-%m-%d")
448
+ if last_date == current_date:
449
+ return []
450
+
451
+ bundle = self.export_date_bundle(current_date, reason="daily")
452
+ if bundle:
453
+ state["last_daily_export_date"] = current_date
454
+ self._save_state(state)
455
+ return [bundle]
456
+ return []
457
+
458
+ def _maybe_pressure_export(self) -> List[Path]:
459
+ usage = self.get_storage_usage_percent()
460
+ if usage < self.config.disk_usage_threshold_percent:
461
+ return []
462
+
463
+ now = datetime.now(timezone.utc)
464
+ state = self._load_state()
465
+ last_pressure = state.get("last_pressure_export_at", "")
466
+ if last_pressure:
467
+ try:
468
+ last_dt = datetime.fromisoformat(last_pressure)
469
+ # Avoid repeated exports in short intervals under sustained pressure.
470
+ if (now - last_dt).total_seconds() < 10 * 60:
471
+ return []
472
+ except Exception:
473
+ pass
474
+
475
+ current_date = now.strftime("%Y-%m-%d")
476
+ bundle = self.export_date_bundle(current_date, reason="pressure")
477
+ if bundle:
478
+ state["last_pressure_export_at"] = now.isoformat()
479
+ self._save_state(state)
480
+ return [bundle]
481
+ return []
482
+
483
+ def export_date_bundle(self, target_date: str, reason: str = "daily") -> Optional[Path]:
484
+ """Export one day's submissions to tar.gz and optionally upload to HF dataset."""
485
+ self.ensure_storage()
486
+ rows = self._read_rows_for_date(target_date)
487
+ if not rows:
488
+ return None
489
+
490
+ ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
491
+ bundle_name = f"submissions_{target_date}_{reason}_{ts}.tar.gz"
492
+
493
+ reason_dir = self.config.exports_dir / reason / target_date
494
+ reason_dir.mkdir(parents=True, exist_ok=True)
495
+ bundle_path = reason_dir / bundle_name
496
+
497
+ staging = self.config.root_dir / ".staging" / f"{target_date}_{reason}_{ts}"
498
+ images_staging = staging / "images"
499
+ meta_staging = staging / "metadata"
500
+ images_staging.mkdir(parents=True, exist_ok=True)
501
+ meta_staging.mkdir(parents=True, exist_ok=True)
502
+
503
+ manifest_csv = meta_staging / "submissions.csv"
504
+ exported_rows = []
505
+ with manifest_csv.open("w", newline="", encoding="utf-8") as f:
506
+ writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
507
+ writer.writeheader()
508
+ for row in rows:
509
+ raw_image_name = str(row.get("image_filename", "")).strip()
510
+ src_img = safe_resolve_in_dir(self.config.images_dir, raw_image_name)
511
+ if src_img is None or not src_img.exists():
512
+ continue
513
+
514
+ safe_image_name = Path(raw_image_name).name
515
+ safe_row = {k: row.get(k, "") for k in CONTRIBUTION_FIELDS}
516
+ safe_row["image_filename"] = safe_image_name
517
+ writer.writerow(safe_row)
518
+ exported_rows.append(safe_row)
519
+ shutil.copy2(src_img, images_staging / safe_image_name)
520
+
521
+ if not exported_rows:
522
+ shutil.rmtree(staging, ignore_errors=True)
523
+ return None
524
+
525
+ manifest_json = meta_staging / "manifest.json"
526
+ manifest_json.write_text(
527
+ json.dumps(
528
+ {
529
+ "date": target_date,
530
+ "reason": reason,
531
+ "created_at_utc": datetime.now(timezone.utc).isoformat(),
532
+ "sample_count": len(exported_rows),
533
+ "fields": CONTRIBUTION_FIELDS,
534
+ },
535
+ indent=2,
536
+ ),
537
+ encoding="utf-8",
538
+ )
539
+
540
+ with tarfile.open(bundle_path, "w:gz") as tar:
541
+ tar.add(staging, arcname=f"bundle_{target_date}_{reason}")
542
+
543
+ shutil.rmtree(staging, ignore_errors=True)
544
+
545
+ # Optional upload to HF dataset repo for local download jobs.
546
+ self._upload_bundle_to_hf(bundle_path, reason=reason, target_date=target_date)
547
+ if self.config.prune_after_export:
548
+ self._prune_rows_for_date(target_date)
549
+ return bundle_path
550
+
551
+ def get_storage_usage_percent(self) -> float:
552
+ if self.config.storage_quota_bytes > 0:
553
+ used_bytes = self._get_dir_size_bytes(self.config.root_dir)
554
+ return used_bytes * 100.0 / float(self.config.storage_quota_bytes)
555
+
556
+ usage = shutil.disk_usage(self.config.root_dir)
557
+ if usage.total <= 0:
558
+ return 0.0
559
+ return usage.used * 100.0 / usage.total
560
+
561
+ def _get_dir_size_bytes(self, path: Path) -> int:
562
+ total = 0
563
+ for item in path.rglob("*"):
564
+ if item.is_file():
565
+ try:
566
+ total += item.stat().st_size
567
+ except Exception:
568
+ pass
569
+ return total
570
+
571
+ def _read_rows_for_date(self, target_date: str) -> List[Dict[str, str]]:
572
+ rows: List[Dict[str, str]] = []
573
+ with _file_lock(self.config.lock_path):
574
+ if not self.config.csv_path.exists():
575
+ return []
576
+ with self.config.csv_path.open("r", newline="", encoding="utf-8") as f:
577
+ reader = csv.DictReader(f)
578
+ for row in reader:
579
+ ts = str(row.get("timestamp_utc", ""))
580
+ if ts.startswith(target_date):
581
+ rows.append(row)
582
+ return rows
583
+
584
+ def _load_state(self) -> Dict[str, object]:
585
+ if not self.config.state_path.exists():
586
+ return {}
587
+ try:
588
+ return json.loads(self.config.state_path.read_text(encoding="utf-8"))
589
+ except Exception:
590
+ return {}
591
+
592
+ def _save_state(self, state: Dict[str, object]) -> None:
593
+ self.config.state_path.parent.mkdir(parents=True, exist_ok=True)
594
+ self.config.state_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
595
+
596
+ def _prune_rows_for_date(self, target_date: str) -> None:
597
+ """
598
+ Prune exported date rows/images from hot Space storage.
599
+ Keeps export bundles as durable transfer unit.
600
+ """
601
+ with _file_lock(self.config.lock_path):
602
+ if not self.config.csv_path.exists():
603
+ return
604
+ with self.config.csv_path.open("r", newline="", encoding="utf-8") as f:
605
+ reader = csv.DictReader(f)
606
+ all_rows = list(reader)
607
+
608
+ keep_rows = []
609
+ drop_rows = []
610
+ for row in all_rows:
611
+ ts = str(row.get("timestamp_utc", ""))
612
+ if ts.startswith(target_date):
613
+ drop_rows.append(row)
614
+ else:
615
+ keep_rows.append(row)
616
+
617
+ if not drop_rows:
618
+ return
619
+
620
+ with self.config.csv_path.open("w", newline="", encoding="utf-8") as f:
621
+ writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
622
+ writer.writeheader()
623
+ for row in keep_rows:
624
+ writer.writerow({k: row.get(k, "") for k in CONTRIBUTION_FIELDS})
625
+
626
+ # Remove unreferenced images only.
627
+ still_referenced = set()
628
+ for row in keep_rows:
629
+ image_name = str(row.get("image_filename", "")).strip()
630
+ safe_path = safe_resolve_in_dir(self.config.images_dir, image_name)
631
+ if safe_path is not None:
632
+ still_referenced.add(safe_path.name)
633
+ for row in drop_rows:
634
+ image_filename = str(row.get("image_filename", "")).strip()
635
+ image_path = safe_resolve_in_dir(self.config.images_dir, image_filename)
636
+ if image_path is None:
637
+ continue
638
+ if image_path.name in still_referenced:
639
+ continue
640
+ if image_path.exists():
641
+ try:
642
+ image_path.unlink()
643
+ except Exception:
644
+ pass
645
+
646
+ # Rebuild hash map from kept rows.
647
+ state = self._load_state()
648
+ rebuilt_hash_map = {}
649
+ for row in keep_rows:
650
+ image_hash = str(row.get("image_sha256", "")).strip()
651
+ submission_id = str(row.get("submission_id", "")).strip()
652
+ if image_hash and submission_id:
653
+ rebuilt_hash_map[image_hash] = submission_id
654
+ state["image_hash_map"] = rebuilt_hash_map
655
+ self._save_state(state)
656
+
657
+ def _upload_bundle_to_hf(self, bundle_path: Path, reason: str, target_date: str) -> None:
658
+ repo_id = self.config.hf_dataset_repo
659
+ if not repo_id:
660
+ return
661
+
662
+ try:
663
+ from huggingface_hub import HfApi # type: ignore
664
+ except Exception:
665
+ print("[collection] huggingface_hub is not installed; skip upload.")
666
+ return
667
+
668
+ try:
669
+ api = HfApi(token=os.getenv("HF_TOKEN"))
670
+ path_in_repo = f"{self.config.hf_export_prefix}/{reason}/{target_date}/{bundle_path.name}"
671
+ api.upload_file(
672
+ path_or_fileobj=str(bundle_path),
673
+ path_in_repo=path_in_repo,
674
+ repo_id=repo_id,
675
+ repo_type="dataset",
676
+ )
677
+ state = self._load_state()
678
+ state["last_uploaded_bundle"] = path_in_repo
679
+ self._save_state(state)
680
+ print(f"[collection] uploaded bundle to dataset: {repo_id}/{path_in_repo}")
681
+ except Exception as exc:
682
+ print(f"[collection] failed to upload bundle to dataset: {exc}")
683
+
684
+
685
+ def classify_from_percentages_simple(sand: float, silt: float, clay: float) -> str:
686
+ """Simple USDA class rules to label user-provided composition."""
687
+ total = sand + silt + clay
688
+ if total > 0:
689
+ sand = sand / total * 100
690
+ silt = silt / total * 100
691
+ clay = clay / total * 100
692
+
693
+ if clay >= 40:
694
+ if silt >= 40:
695
+ return "Silty Clay"
696
+ if sand >= 45:
697
+ return "Sandy Clay"
698
+ return "Clay"
699
+ if clay >= 27:
700
+ if silt >= 40:
701
+ return "Silty Clay Loam"
702
+ if sand >= 45:
703
+ return "Sandy Clay Loam"
704
+ return "Clay Loam"
705
+ if clay >= 20:
706
+ if sand >= 45:
707
+ return "Sandy Clay Loam"
708
+ if silt >= 50:
709
+ return "Silty Clay Loam"
710
+ return "Clay Loam"
711
+ if clay >= 7:
712
+ if silt >= 50:
713
+ return "Silt Loam"
714
+ if sand >= 52:
715
+ return "Sandy Loam"
716
+ return "Loam"
717
+
718
+ if silt >= 80:
719
+ return "Silt"
720
+ if sand >= 85:
721
+ return "Sand"
722
+ if sand >= 70:
723
+ return "Loamy Sand"
724
+ if sand >= 52:
725
+ return "Sandy Loam"
726
+ if silt >= 50:
727
+ return "Silt Loam"
728
+ return "Loam"
finetuned_best.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bfc8376159a777bc0c3ce9f30ec7a13c62603e20063e0f3fd5ae9af8f6273bc3
3
- size 90440015
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7baa261829c118d7de0e18ae327cb894cbac51c4c79f960266c5f77819c2d78c
3
+ size 87950355
requirements.txt CHANGED
@@ -1,18 +1,18 @@
1
- # Core ML - Required for model
2
- torch>=2.0.0
3
- torchvision>=0.15.0
4
- timm>=0.9.0
5
-
6
- # Image Processing - Required by app.py
7
- numpy>=1.24.0
8
- Pillow>=10.0.0
9
- opencv-python-headless>=4.8.0
10
-
11
- # Visualization - Required for texture triangle
12
- matplotlib>=3.7.0
13
-
14
- # WebUI - Required for Gradio interface
15
- gradio>=4.0.0
16
-
17
- # Hub sync/export (Space -> Dataset)
18
- huggingface_hub>=0.26.0
 
1
+ # Core ML - Required for model
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ timm>=0.9.0
5
+
6
+ # Image Processing - Required by app.py
7
+ numpy>=1.24.0
8
+ Pillow>=10.0.0
9
+ opencv-python-headless>=4.8.0
10
+
11
+ # Visualization - Required for texture triangle
12
+ matplotlib>=3.7.0
13
+
14
+ # WebUI - Required for Gradio interface
15
+ gradio>=4.0.0
16
+
17
+ # Hub sync/export (Space -> Dataset)
18
+ huggingface_hub>=0.26.0