Iridium-193 commited on
Commit
58969d7
·
verified ·
1 Parent(s): 3655df7

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +162 -2
app.py CHANGED
@@ -1,4 +1,8 @@
1
  import argparse
 
 
 
 
2
  from pathlib import Path
3
  from typing import Tuple, Dict
4
  import numpy as np
@@ -17,7 +21,7 @@ try:
17
  except ImportError:
18
  import sys
19
 
20
- sys.path.append(str(Path(__file__).parent / "src"))
21
  from data_collection import DataCollectionManager, classify_from_percentages_simple
22
 
23
 
@@ -698,13 +702,136 @@ def create_demo(
698
  f"{export_note}"
699
  )
700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
  # Create interface
702
  with gr.Blocks(title="Soil Texture Classifier") as demo:
703
  gr.Markdown("""
704
  # Soil Texture Classification
705
 
706
- 1. Use **Inference** to predict texture class and composition from image.
707
  2. Use **Contribute Data** to upload image + measured Sand/Silt/Clay for future training.
 
708
  """)
709
 
710
  with gr.Tabs():
@@ -770,6 +897,27 @@ def create_demo(
770
  submit_btn = gr.Button("Submit Contribution", variant="primary")
771
  contribution_status = gr.Markdown(label="Submission Status")
772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
773
  # Event handlers
774
  predict_btn.click(
775
  fn=predict_fn,
@@ -800,6 +948,18 @@ def create_demo(
800
  outputs=[contribution_status]
801
  )
802
 
 
 
 
 
 
 
 
 
 
 
 
 
803
  return demo
804
 
805
 
 
1
  import argparse
2
+ import csv
3
+ import io
4
+ import os
5
+ import zipfile
6
  from pathlib import Path
7
  from typing import Tuple, Dict
8
  import numpy as np
 
21
  except ImportError:
22
  import sys
23
 
24
+ sys.path.insert(0, str(Path(__file__).resolve().parent / "src"))
25
  from data_collection import DataCollectionManager, classify_from_percentages_simple
26
 
27
 
 
702
  f"{export_note}"
703
  )
704
 
705
+ def get_dataset_stats_fn():
706
+ """Get statistics about the current dataset."""
707
+ cfg = collection_manager.config
708
+ num_submissions = 0
709
+ if cfg.csv_path.exists():
710
+ with cfg.csv_path.open("r", encoding="utf-8") as f:
711
+ reader = csv.reader(f)
712
+ next(reader, None)
713
+ num_submissions = sum(1 for _ in reader)
714
+ num_images = 0
715
+ total_size_bytes = 0
716
+ if cfg.images_dir.exists():
717
+ for p in cfg.images_dir.iterdir():
718
+ if p.is_file():
719
+ num_images += 1
720
+ total_size_bytes += p.stat().st_size
721
+ total_size_mb = total_size_bytes / (1024 * 1024)
722
+ return (
723
+ f"### Dataset Statistics\n"
724
+ f"- **Total submissions:** {num_submissions}\n"
725
+ f"- **Total images:** {num_images}\n"
726
+ f"- **Total image size:** {total_size_mb:.1f} MB\n"
727
+ )
728
+
729
+ def upload_dataset_fn(zip_file, upload_consent):
730
+ """Process uploaded ZIP dataset with images and CSV."""
731
+ if zip_file is None:
732
+ return "Please upload a ZIP file."
733
+ if not upload_consent:
734
+ return "Please confirm consent before uploading."
735
+ zip_path = zip_file if isinstance(zip_file, str) else zip_file.name
736
+ if not zipfile.is_zipfile(zip_path):
737
+ return "Invalid ZIP file."
738
+ max_entries = 10000
739
+ max_total_size = 500 * 1024 * 1024
740
+ results = {"added": 0, "skipped": 0, "errors": []}
741
+ try:
742
+ with zipfile.ZipFile(zip_path, "r") as zf:
743
+ entries = zf.infolist()
744
+ if len(entries) > max_entries:
745
+ return f"ZIP has too many entries ({len(entries)}). Max: {max_entries}."
746
+ total_size = sum(e.file_size for e in entries)
747
+ if total_size > max_total_size:
748
+ return f"ZIP too large ({total_size / 1024 / 1024:.0f} MB). Max: {max_total_size // (1024 * 1024)} MB."
749
+ csv_entries = [
750
+ e for e in entries
751
+ if e.filename.endswith(".csv") and not e.filename.startswith("__")
752
+ ]
753
+ if not csv_entries:
754
+ return "No CSV found in ZIP. Expected CSV with columns: filename, sand, silt, clay."
755
+ with zf.open(csv_entries[0]) as csv_file:
756
+ content = csv_file.read().decode("utf-8")
757
+ reader = csv.DictReader(io.StringIO(content))
758
+ headers = set(reader.fieldnames or [])
759
+ required = {"filename", "sand", "silt", "clay"}
760
+ if not required.issubset(headers):
761
+ return (
762
+ f"CSV must have columns: {', '.join(sorted(required))}. "
763
+ f"Found: {', '.join(sorted(headers))}"
764
+ )
765
+ for row in reader:
766
+ try:
767
+ fname = row["filename"].strip()
768
+ sand = float(row["sand"])
769
+ silt = float(row["silt"])
770
+ clay = float(row["clay"])
771
+ vals = [sand, silt, clay]
772
+ if any(v < 0 or v > 100 for v in vals):
773
+ results["errors"].append(f"{fname}: values out of range")
774
+ results["skipped"] += 1
775
+ continue
776
+ total = sand + silt + clay
777
+ if abs(total - 100.0) > 1.0:
778
+ results["errors"].append(f"{fname}: sum={total:.1f}, must be ~100")
779
+ results["skipped"] += 1
780
+ continue
781
+ matches = [e for e in entries if Path(e.filename).name == fname]
782
+ if not matches:
783
+ results["errors"].append(f"Image not found in ZIP: {fname}")
784
+ results["skipped"] += 1
785
+ continue
786
+ with zf.open(matches[0]) as img_bytes:
787
+ image = Image.open(img_bytes).convert("RGB")
788
+ if image.width * image.height > collection_manager.config.max_image_pixels:
789
+ results["errors"].append(f"{fname}: image too large")
790
+ results["skipped"] += 1
791
+ continue
792
+ prediction = predictor.predict(image)
793
+ user_class = classify_from_percentages_simple(sand, silt, clay)
794
+ submission_id = collection_manager.create_submission_id()
795
+ collection_manager.save_submission(
796
+ image=image,
797
+ submission_id=submission_id,
798
+ sand=sand, silt=silt, clay=clay,
799
+ user_class=user_class,
800
+ weak_label=row.get("weak_label", ""),
801
+ strong_label=row.get("strong_label", ""),
802
+ prediction=prediction,
803
+ sample_source=row.get("source", ""),
804
+ location=row.get("location", ""),
805
+ notes=row.get("notes", ""),
806
+ total=total,
807
+ )
808
+ results["added"] += 1
809
+ except Exception as e:
810
+ results["errors"].append(f"{row.get('filename', '?')}: {e}")
811
+ results["skipped"] += 1
812
+ except Exception as e:
813
+ return f"Failed to process ZIP: {e}"
814
+ error_summary = ""
815
+ if results["errors"]:
816
+ shown = results["errors"][:20]
817
+ error_summary = "\n\n**Errors:**\n" + "\n".join(f"- {e}" for e in shown)
818
+ if len(results["errors"]) > 20:
819
+ error_summary += f"\n- ... and {len(results['errors']) - 20} more"
820
+ return (
821
+ f"### Upload Complete\n"
822
+ f"- **Added:** {results['added']} submissions\n"
823
+ f"- **Skipped:** {results['skipped']}\n"
824
+ f"{error_summary}"
825
+ )
826
+
827
  # Create interface
828
  with gr.Blocks(title="Soil Texture Classifier") as demo:
829
  gr.Markdown("""
830
  # Soil Texture Classification
831
 
832
+ 1. Use **Inference** to predict texture class and composition from image.
833
  2. Use **Contribute Data** to upload image + measured Sand/Silt/Clay for future training.
834
+ 3. Use **Dataset Management** to bulk-upload a ZIP dataset for model improvement.
835
  """)
836
 
837
  with gr.Tabs():
 
897
  submit_btn = gr.Button("Submit Contribution", variant="primary")
898
  contribution_status = gr.Markdown(label="Submission Status")
899
 
900
+ with gr.Tab("Dataset Management"):
901
+ gr.Markdown("""
902
+ **Upload** a dataset (ZIP) to contribute bulk data for model improvement.
903
+
904
+ **Upload format:** ZIP containing a CSV file and image files.
905
+ CSV columns: `filename`, `sand`, `silt`, `clay` (required).
906
+ Optional: `weak_label`, `strong_label`, `source`, `location`, `notes`.
907
+ """)
908
+ with gr.Row():
909
+ with gr.Column():
910
+ upload_file = gr.File(label="ZIP Dataset", file_types=[".zip"])
911
+ upload_consent = gr.Checkbox(
912
+ label="I confirm these images and labels can be used for model improvement.",
913
+ value=False,
914
+ )
915
+ upload_btn = gr.Button("Upload Dataset", variant="primary")
916
+ upload_status = gr.Markdown(label="Upload Status")
917
+ with gr.Column():
918
+ stats_btn = gr.Button("Refresh Statistics")
919
+ stats_display = gr.Markdown(label="Statistics")
920
+
921
  # Event handlers
922
  predict_btn.click(
923
  fn=predict_fn,
 
948
  outputs=[contribution_status]
949
  )
950
 
951
+ upload_btn.click(
952
+ fn=upload_dataset_fn,
953
+ inputs=[upload_file, upload_consent],
954
+ outputs=[upload_status],
955
+ )
956
+
957
+ stats_btn.click(
958
+ fn=get_dataset_stats_fn,
959
+ inputs=[],
960
+ outputs=[stats_display],
961
+ )
962
+
963
  return demo
964
 
965