# SoilTextureClassification / data_collection.py
# Hugging Face page header captured with the file (author avatar: Iridium-193)
# Commit message: "Upload folder using huggingface_hub"
# Revision: 49dd243 (verified)
"""
Data Collection Pipeline
------------------------
Collection-only module for Space uploads.
Keeps collection logic separated from model training code.
"""
from __future__ import annotations
import csv
import hashlib
import io
import json
import os
import shutil
import tarfile
import threading
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from PIL import Image
from collection_common import safe_resolve_in_dir
# Canonical spellings of the twelve texture class names used for labels.
# normalize_optional_label() maps case/underscore variants back onto these.
USDA_CLASSES = [
    # Sand-dominated
    "Sand", "Loamy Sand", "Sandy Loam",
    # Loam / silt
    "Loam", "Silt Loam", "Silt",
    # Clay loams
    "Sandy Clay Loam", "Clay Loam", "Silty Clay Loam",
    # Clays
    "Sandy Clay", "Silty Clay", "Clay",
]
# Column order of submissions.csv and of the manifest CSV inside each export
# bundle. Every csv.DictWriter in this module is built from this list.
CONTRIBUTION_FIELDS = [
    # Identity and timing
    "submission_id", "timestamp_utc",
    # Stored image and deduplication bookkeeping
    "image_filename", "image_sha256", "is_duplicate", "duplicate_of_submission",
    # Composition and labels supplied by the contributor
    "user_sand", "user_silt", "user_clay", "user_total", "user_class",
    "weak_label", "strong_label",
    # Prediction recorded alongside the submission
    "predicted_class", "predicted_confidence",
    "pred_sand", "pred_silt", "pred_clay",
    # Free-form context
    "sample_source", "location", "notes",
]
@contextmanager
def _file_lock(lock_path: Path):
    """Hold an exclusive advisory lock on ``lock_path`` for the ``with`` body.

    Best-effort cross-process lock for unix-like environments: when the
    ``fcntl`` module cannot be imported the body still runs, just without
    any locking, instead of failing with ImportError.

    The lock file (and its parent directory) is created on demand. Each
    acquisition opens its own file handle, so treat the lock as
    non-re-entrant: do not nest ``_file_lock`` calls on the same path.

    Args:
        lock_path: File used purely as the lock target; its content is
            never read or written.
    """
    lock_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        import fcntl  # type: ignore
    except ImportError:
        # Platform without fcntl: degrade to "no lock" rather than crash.
        fcntl = None
    with lock_path.open("a+") as lock_file:
        if fcntl is None:
            yield
            return
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
        try:
            yield
        finally:
            try:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
            except OSError:
                # Closing the handle right after releases the lock anyway.
                pass
def sanitize_text(value: Optional[str], max_len: int = 500) -> str:
    """Flatten free-form user text to one line and defuse CSV formulas.

    Args:
        value: Raw user input; ``None`` is treated as empty. Non-string
            values are converted with ``str()``.
        max_len: Maximum length of the returned string (the protective
            quote, when added, counts towards it).

    Returns:
        The text with every whitespace run (including newlines) collapsed to
        a single space and no leading/trailing whitespace. If it would start
        with ``=``, ``+``, ``-`` or ``@`` an apostrophe is prepended so
        spreadsheet tools do not evaluate it as a formula.
    """
    if value is None:
        return ""
    # split() without arguments splits on any whitespace run (CR/LF included)
    # and ignores leading/trailing whitespace, so one pass does it all.
    text = " ".join(str(value).split())
    if text.startswith(("=", "+", "-", "@")):
        text = f"'{text}"
    return text[:max_len]
def normalize_optional_label(label: Optional[str]) -> str:
    """Normalize an optional weak/strong label to a consistent spelling.

    The label is sanitized (max 64 characters), lower-cased and has
    underscores replaced by spaces. If the result matches an entry of
    ``USDA_CLASSES`` case-insensitively, that canonical entry is returned;
    otherwise each word is capitalized.

    Returns:
        The normalized label, or ``""`` when the label is missing or blank.
    """
    cleaned = sanitize_text(label, max_len=64)
    if not cleaned:
        return ""

    key = cleaned.lower().replace("_", " ")
    canonical_by_key = dict(zip((name.lower() for name in USDA_CLASSES), USDA_CLASSES))
    canonical = canonical_by_key.get(key)
    if canonical is not None:
        return canonical

    # Unknown label: keep it, but in a predictable "Title Case" form.
    return " ".join(part.capitalize() for part in key.split())
def encode_jpeg_bytes(image: Image.Image, quality: int = 92) -> bytes:
    """Encode image to JPEG bytes once for deterministic hashing and persistence.

    Images whose mode the JPEG writer cannot store (for example ``RGBA``,
    ``LA`` or palette ``P`` images, typical for PNG uploads) are converted to
    ``RGB`` first instead of letting ``save`` fail. Modes that JPEG already
    accepts are passed through untouched, so their encoded bytes (and hence
    their hashes) are the same as before.

    Args:
        image: Source image.
        quality: JPEG quality setting forwarded to ``Image.save``.

    Returns:
        The encoded JPEG file content.
    """
    # Modes Pillow's JPEG encoder can write directly.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=quality)
    return buffer.getvalue()
def compute_bytes_sha256(content: bytes) -> str:
    """Return the SHA-256 digest of ``content`` as a lowercase hex string."""
    digest = hashlib.sha256()
    digest.update(content)
    return digest.hexdigest()
@dataclass
class SubmissionValidationResult:
    """Outcome of ``DataCollectionManager.validate_submission``."""

    # True when the submission may be stored.
    ok: bool
    # User-facing rejection reason; empty string on success.
    message: str
    # sand + silt + clay as entered by the user.
    total: float
@dataclass
class DataCollectionConfig:
    """Storage layout and tuning knobs for the collection pipeline."""

    # --- storage layout ---
    root_dir: Path
    images_dir: Path
    csv_path: Path
    lock_path: Path
    state_path: Path
    exports_dir: Path

    # --- limits and throttling ---
    # Usage percentage at or above which a pressure export is attempted.
    disk_usage_threshold_percent: float
    # Largest accepted width * height of a submitted image.
    max_image_pixels: int
    # Minimum number of seconds between two accepted submissions.
    min_submit_interval_sec: float

    # --- export scheduling (UTC) ---
    daily_export_hour_utc: int
    daily_export_minute_utc: int
    schedule_check_interval_sec: int

    # --- Hugging Face upload target ---
    # Dataset repo id; an empty string skips the upload.
    hf_dataset_repo: str
    hf_export_prefix: str

    # --- housekeeping ---
    # When > 0, usage is measured against this quota instead of the disk.
    storage_quota_bytes: int
    deduplicate_images: bool
    prune_after_export: bool
    max_hash_index_entries: int

    @staticmethod
    def from_env() -> "DataCollectionConfig":
        """Build a configuration from environment variables.

        Every variable is optional and falls back to the default shown here.
        All paths are derived from ``CONTRIBUTION_DATA_DIR``.

        Raises:
            ValueError: If a numeric variable holds a non-numeric value.
        """
        env = os.environ.get
        base = Path(env("CONTRIBUTION_DATA_DIR", "data/community_submissions"))

        # A blank prefix falls back to the default folder name.
        export_prefix = env("HF_CONTRIB_EXPORT_PREFIX", "space_exports").strip()
        if not export_prefix:
            export_prefix = "space_exports"

        # Deduplication is on unless explicitly set to "0"; pruning is off
        # unless explicitly set to "1".
        dedup_flag = env("CONTRIBUTION_DEDUPLICATE_IMAGES", "1").strip()
        prune_flag = env("CONTRIBUTION_PRUNE_AFTER_EXPORT", "0").strip()

        return DataCollectionConfig(
            root_dir=base,
            images_dir=base / "images",
            csv_path=base / "submissions.csv",
            lock_path=base / ".submission.lock",
            state_path=base / "collection_state.json",
            exports_dir=base / "exports",
            disk_usage_threshold_percent=float(env("CONTRIBUTION_MAX_USAGE_PERCENT", "90")),
            max_image_pixels=int(env("CONTRIBUTION_MAX_IMAGE_PIXELS", "20000000")),
            min_submit_interval_sec=float(env("CONTRIBUTION_MIN_SUBMIT_INTERVAL_SEC", "0.5")),
            daily_export_hour_utc=int(env("CONTRIBUTION_DAILY_EXPORT_HOUR_UTC", "23")),
            daily_export_minute_utc=int(env("CONTRIBUTION_DAILY_EXPORT_MINUTE_UTC", "50")),
            schedule_check_interval_sec=int(env("CONTRIBUTION_SCHEDULE_CHECK_SEC", "60")),
            hf_dataset_repo=env("HF_CONTRIB_DATASET_REPO", "").strip(),
            hf_export_prefix=export_prefix,
            storage_quota_bytes=int(env("CONTRIBUTION_STORAGE_QUOTA_BYTES", "0")),
            deduplicate_images=dedup_flag != "0",
            prune_after_export=prune_flag == "1",
            max_hash_index_entries=int(env("CONTRIBUTION_MAX_HASH_INDEX_ENTRIES", "50000")),
        )
class DataCollectionManager:
    """Manage submission persistence and export scheduling in Space.

    Each submission is stored as a JPEG under ``config.images_dir`` plus one
    row in ``config.csv_path``. A JSON state file (``config.state_path``)
    keeps export bookkeeping and the image-hash -> submission-id index used
    for deduplication.

    Locking: CSV and state mutations run under ``_file_lock``. That lock must
    be treated as non-re-entrant, so methods documented as "caller holds the
    file lock" never acquire it themselves, and ``_update_state`` must never
    be called while the lock is held.
    """

    def __init__(self, config: Optional[DataCollectionConfig] = None):
        """Create a manager; without ``config`` one is read from the environment."""
        self.config = config or DataCollectionConfig.from_env()
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        # Guards the in-process submission rate limiter.
        self._mem_lock = threading.Lock()
        self._last_submit_ts = 0.0

    def ensure_storage(self) -> None:
        """Create directories, the CSV header and the initial state file if missing."""
        cfg = self.config
        cfg.images_dir.mkdir(parents=True, exist_ok=True)
        cfg.exports_dir.mkdir(parents=True, exist_ok=True)
        if cfg.csv_path.exists() and cfg.state_path.exists():
            return
        with _file_lock(cfg.lock_path):
            # Re-check under the lock: another worker may have won the race.
            if not cfg.csv_path.exists():
                with cfg.csv_path.open("w", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
                    writer.writeheader()
            if not cfg.state_path.exists():
                self._save_state({
                    "last_daily_export_date": "",
                    "last_pressure_export_at": "",
                    "last_uploaded_bundle": "",
                    "image_hash_map": {},
                })

    def start_scheduler(self) -> None:
        """Start background scheduler for timed export checks.

        Does nothing when the scheduler thread is already running.
        """
        if self._thread and self._thread.is_alive():
            return
        # A previous stop_scheduler() leaves the event set; without clearing
        # it a restarted loop would exit on its first check.
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._scheduler_loop, name="collection-scheduler", daemon=True)
        self._thread.start()

    def stop_scheduler(self) -> None:
        """Ask the scheduler thread to stop and wait up to two seconds for it."""
        self._stop_event.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=2)

    def validate_submission(
        self,
        sand: float,
        silt: float,
        clay: float,
        consent: bool,
        image: Image.Image,
    ) -> SubmissionValidationResult:
        """Check one submission before it is stored.

        Checks run in this order: image size, consent, each fraction within
        [0, 100], fractions summing to 100 +/- 1, and finally the
        process-wide minimum interval between accepted submissions.

        Returns:
            A result whose ``ok`` is True only if every check passed;
            ``message`` carries the rejection reason otherwise.
        """
        if image.width * image.height > self.config.max_image_pixels:
            return SubmissionValidationResult(
                ok=False,
                message=f"Image too large. Max pixels: {self.config.max_image_pixels}.",
                total=sand + silt + clay,
            )
        if not consent:
            return SubmissionValidationResult(ok=False, message="Consent is required.", total=sand + silt + clay)

        total = sand + silt + clay
        # `not (0 <= v <= 100)` instead of `v < 0 or v > 100`: every
        # comparison with NaN is False, so the latter would let NaN through
        # (and NaN would then also slip past the sum check below).
        if any(not (0 <= v <= 100) for v in (sand, silt, clay)):
            return SubmissionValidationResult(ok=False, message="Sand/Silt/Clay must be in [0, 100].", total=total)
        if abs(total - 100.0) > 1.0:
            return SubmissionValidationResult(
                ok=False,
                message=f"Sand + Silt + Clay should be close to 100 (current: {total:.2f}).",
                total=total,
            )

        with self._mem_lock:
            now_ts = time.time()
            if now_ts - self._last_submit_ts < self.config.min_submit_interval_sec:
                return SubmissionValidationResult(
                    ok=False,
                    message="Submission too fast. Please wait a moment and retry.",
                    total=total,
                )
            # Only accepted submissions move the rate-limit window.
            self._last_submit_ts = now_ts
        return SubmissionValidationResult(ok=True, message="", total=total)

    def create_submission_id(self) -> str:
        """Return a new id of the form ``sub_<UTC timestamp>_<8 hex chars>``."""
        return f"sub_{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')}_{uuid.uuid4().hex[:8]}"

    def _resolve_submission_image(
        self,
        submission_id: str,
        encoded_image: bytes,
        image_hash: str,
        hash_map: Dict[str, str],
    ) -> Tuple[str, Path, str, str, Dict[str, str]]:
        """
        Resolve image storage path with optional hash-based deduplication.

        When deduplication is enabled and ``image_hash`` is indexed to a
        submission whose image file still exists, that file is reused and
        nothing is written. Otherwise the bytes are written as
        ``<submission_id>.jpg`` and the index entry is (re)pointed at this
        submission.

        Returns:
            ``(image_filename, image_path, is_duplicate,
            duplicate_of_submission, hash_map)`` where ``is_duplicate`` is
            ``"1"``/``"0"`` and ``duplicate_of_submission`` is non-empty only
            for duplicates. ``hash_map`` is the same dict, updated in place.
        """
        cfg = self.config
        if cfg.deduplicate_images and image_hash in hash_map:
            original_id = str(hash_map[image_hash]).strip()
            original_name = f"{original_id}.jpg"
            original_path = cfg.images_dir / original_name
            if original_id and original_path.exists():
                return original_name, original_path, "1", original_id, hash_map
            # Stale index entry (file pruned or removed): fall through and
            # store this image as a fresh, non-duplicate submission.

        image_filename = f"{submission_id}.jpg"
        image_path = cfg.images_dir / image_filename
        image_path.write_bytes(encoded_image)
        hash_map[image_hash] = submission_id
        return image_filename, image_path, "0", "", hash_map

    def _trim_hash_map(self, hash_map: Dict[str, str]) -> Dict[str, str]:
        """Keep only the most recently inserted index entries.

        A non-positive ``max_hash_index_entries`` disables trimming.
        """
        limit = self.config.max_hash_index_entries
        if limit <= 0 or len(hash_map) <= limit:
            return hash_map
        # Dicts preserve insertion order, so the tail holds the newest entries.
        return dict(list(hash_map.items())[-limit:])

    def _build_submission_row(
        self,
        submission_id: str,
        image_filename: str,
        image_hash: str,
        is_duplicate: str,
        duplicate_of_submission: str,
        sand: float,
        silt: float,
        clay: float,
        total: float,
        user_class: str,
        weak_label: str,
        strong_label: str,
        prediction: Dict[str, float],
        sample_source: str,
        location: str,
        notes: str,
    ) -> Dict[str, str]:
        """Assemble one CSV row (all values as strings) for a submission.

        Free-text fields are sanitized; numeric fields are rendered with a
        fixed number of decimals. Missing prediction entries default to an
        empty class and 0.0 values.
        """
        return {
            "submission_id": submission_id,
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "image_filename": image_filename,
            "image_sha256": image_hash,
            "is_duplicate": is_duplicate,
            "duplicate_of_submission": duplicate_of_submission,
            "user_sand": f"{sand:.4f}",
            "user_silt": f"{silt:.4f}",
            "user_clay": f"{clay:.4f}",
            "user_total": f"{total:.4f}",
            "user_class": sanitize_text(user_class, max_len=64),
            "weak_label": normalize_optional_label(weak_label),
            "strong_label": normalize_optional_label(strong_label),
            "predicted_class": sanitize_text(str(prediction.get("class", "")), max_len=64),
            "predicted_confidence": f"{float(prediction.get('confidence', 0.0)):.8f}",
            "pred_sand": f"{float(prediction.get('sand', 0.0)):.4f}",
            "pred_silt": f"{float(prediction.get('silt', 0.0)):.4f}",
            "pred_clay": f"{float(prediction.get('clay', 0.0)):.4f}",
            "sample_source": sanitize_text(sample_source),
            "location": sanitize_text(location),
            "notes": sanitize_text(notes, max_len=2000),
        }

    def _append_submission_row(self, row: Dict[str, str]) -> None:
        """Append ``row`` to the submissions CSV. Caller holds the file lock."""
        with self.config.csv_path.open("a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
            writer.writerow({k: row.get(k, "") for k in CONTRIBUTION_FIELDS})

    def save_submission(
        self,
        image: Image.Image,
        submission_id: str,
        sand: float,
        silt: float,
        clay: float,
        user_class: str,
        weak_label: str,
        strong_label: str,
        prediction: Dict[str, float],
        sample_source: str,
        location: str,
        notes: str,
        total: float,
    ) -> Dict[str, str]:
        """Persist one validated submission (image, CSV row, hash index).

        The image is encoded once outside the lock; the image write, the CSV
        append and the state update then happen together under the file lock.

        Returns:
            A dict with ``image_path``, ``image_filename``, ``image_sha256``,
            ``is_duplicate`` and ``duplicate_of_submission``.
        """
        cfg = self.config
        self.ensure_storage()
        encoded_image = encode_jpeg_bytes(image, quality=92)
        image_hash = compute_bytes_sha256(encoded_image)
        with _file_lock(cfg.lock_path):
            state = self._load_state()
            hash_map = state.get("image_hash_map", {})
            if not isinstance(hash_map, dict):
                hash_map = {}
            image_filename, image_path, is_duplicate, duplicate_of_submission, hash_map = self._resolve_submission_image(
                submission_id=submission_id,
                encoded_image=encoded_image,
                image_hash=image_hash,
                hash_map=hash_map,
            )
            hash_map = self._trim_hash_map(hash_map)
            state["image_hash_map"] = hash_map
            row = self._build_submission_row(
                submission_id=submission_id,
                image_filename=image_filename,
                image_hash=image_hash,
                is_duplicate=is_duplicate,
                duplicate_of_submission=duplicate_of_submission,
                sand=sand,
                silt=silt,
                clay=clay,
                total=total,
                user_class=user_class,
                weak_label=weak_label,
                strong_label=strong_label,
                prediction=prediction,
                sample_source=sample_source,
                location=location,
                notes=notes,
            )
            self._append_submission_row(row)
            self._save_state(state)
        return {
            "image_path": str(image_path),
            "image_filename": image_filename,
            "image_sha256": image_hash,
            "is_duplicate": is_duplicate,
            "duplicate_of_submission": duplicate_of_submission,
        }

    def maybe_trigger_exports(self) -> List[Path]:
        """Run daily and pressure-based export checks.

        Returns:
            Paths of the bundles created by this call (possibly empty).
        """
        bundles: List[Path] = []
        bundles.extend(self._maybe_daily_export())
        bundles.extend(self._maybe_pressure_export())
        return bundles

    def _scheduler_loop(self) -> None:
        """Body of the scheduler thread: check exports until asked to stop."""
        while not self._stop_event.is_set():
            try:
                # Inside the try so a transient storage error is reported
                # and retried on the next tick instead of killing the thread.
                self.ensure_storage()
                bundles = self.maybe_trigger_exports()
                if bundles:
                    print(f"[collection] exported {len(bundles)} bundle(s) from scheduler")
            except Exception as exc:
                print(f"[collection] scheduler error: {exc}")
            self._stop_event.wait(self.config.schedule_check_interval_sec)

    def _update_state(self, changes: Dict[str, object]) -> None:
        """Merge ``changes`` into the current on-disk state under the file lock.

        The state is re-read inside the lock so keys written by other code
        paths in the meantime (hash index, upload marker) are preserved.
        Must not be called while the file lock is already held.
        """
        with _file_lock(self.config.lock_path):
            state = self._load_state()
            state.update(changes)
            self._save_state(state)

    def _maybe_daily_export(self) -> List[Path]:
        """Export today's rows once per UTC day after the configured time.

        NOTE(review): rows submitted after the daily export has run are not
        picked up by a later daily export of the same date - confirm whether
        that is intended.
        """
        cfg = self.config
        now = datetime.now(timezone.utc)
        if (now.hour, now.minute) < (cfg.daily_export_hour_utc, cfg.daily_export_minute_utc):
            return []
        current_date = now.strftime("%Y-%m-%d")
        if self._load_state().get("last_daily_export_date", "") == current_date:
            return []
        bundle = self.export_date_bundle(current_date, reason="daily")
        if not bundle:
            return []
        # The export itself rewrites the state file (upload marker, pruned
        # hash index); merge into the fresh state instead of writing back a
        # snapshot taken before the export.
        self._update_state({"last_daily_export_date": current_date})
        return [bundle]

    def _maybe_pressure_export(self) -> List[Path]:
        """Export today's rows when storage usage reaches the threshold.

        At most one pressure export is started per ten minutes.
        """
        usage = self.get_storage_usage_percent()
        if usage < self.config.disk_usage_threshold_percent:
            return []
        now = datetime.now(timezone.utc)
        last_pressure = self._load_state().get("last_pressure_export_at", "")
        if last_pressure:
            try:
                last_dt = datetime.fromisoformat(str(last_pressure))
                # Avoid repeated exports in short intervals under sustained pressure.
                if (now - last_dt).total_seconds() < 10 * 60:
                    return []
            except (ValueError, TypeError):
                # Unparseable or timezone-naive marker: treat as "never ran".
                pass
        current_date = now.strftime("%Y-%m-%d")
        bundle = self.export_date_bundle(current_date, reason="pressure")
        if not bundle:
            return []
        # Merge into the fresh state; see _maybe_daily_export.
        self._update_state({"last_pressure_export_at": now.isoformat()})
        return [bundle]

    def export_date_bundle(self, target_date: str, reason: str = "daily") -> Optional[Path]:
        """Export one day's submissions to tar.gz and optionally upload to HF dataset.

        Rows whose image cannot be resolved inside ``images_dir`` or no
        longer exists are left out of the bundle.

        Args:
            target_date: Date prefix (``YYYY-MM-DD``) matched against each
                row's ``timestamp_utc``.
            reason: Label used in the export folder, bundle name and manifest.

        Returns:
            Path of the created bundle, or ``None`` when nothing was exported.
        """
        self.ensure_storage()
        rows = self._read_rows_for_date(target_date)
        if not rows:
            return None
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        bundle_name = f"submissions_{target_date}_{reason}_{ts}.tar.gz"
        reason_dir = self.config.exports_dir / reason / target_date
        reason_dir.mkdir(parents=True, exist_ok=True)
        bundle_path = reason_dir / bundle_name
        staging = self.config.root_dir / ".staging" / f"{target_date}_{reason}_{ts}"
        images_staging = staging / "images"
        meta_staging = staging / "metadata"
        exported_rows = []
        try:
            images_staging.mkdir(parents=True, exist_ok=True)
            meta_staging.mkdir(parents=True, exist_ok=True)
            manifest_csv = meta_staging / "submissions.csv"
            with manifest_csv.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
                writer.writeheader()
                for row in rows:
                    raw_image_name = str(row.get("image_filename", "")).strip()
                    src_img = safe_resolve_in_dir(self.config.images_dir, raw_image_name)
                    if src_img is None or not src_img.exists():
                        continue
                    safe_image_name = Path(raw_image_name).name
                    safe_row = {k: row.get(k, "") for k in CONTRIBUTION_FIELDS}
                    safe_row["image_filename"] = safe_image_name
                    writer.writerow(safe_row)
                    exported_rows.append(safe_row)
                    shutil.copy2(src_img, images_staging / safe_image_name)
            if not exported_rows:
                return None
            manifest_json = meta_staging / "manifest.json"
            manifest_json.write_text(
                json.dumps(
                    {
                        "date": target_date,
                        "reason": reason,
                        "created_at_utc": datetime.now(timezone.utc).isoformat(),
                        "sample_count": len(exported_rows),
                        "fields": CONTRIBUTION_FIELDS,
                    },
                    indent=2,
                ),
                encoding="utf-8",
            )
            with tarfile.open(bundle_path, "w:gz") as tar:
                tar.add(staging, arcname=f"bundle_{target_date}_{reason}")
        except Exception:
            # Do not leave a truncated archive behind for later consumers.
            try:
                bundle_path.unlink()
            except OSError:
                pass
            raise
        finally:
            # Always reclaim the staging copy, also when the export failed.
            shutil.rmtree(staging, ignore_errors=True)

        # Optional upload to HF dataset repo for local download jobs.
        self._upload_bundle_to_hf(bundle_path, reason=reason, target_date=target_date)
        if self.config.prune_after_export:
            # Prune only what is actually inside the bundle: rows appended
            # after the snapshot above (or skipped from it) must survive.
            exported_ids = [str(r.get("submission_id", "")).strip() for r in exported_rows]
            self._prune_rows_for_date(target_date, only_submission_ids=exported_ids)
        return bundle_path

    def get_storage_usage_percent(self) -> float:
        """Return storage usage as a percentage.

        With a positive ``storage_quota_bytes`` the size of ``root_dir`` is
        measured against that quota; otherwise the usage of the disk holding
        ``root_dir`` is reported.
        """
        if self.config.storage_quota_bytes > 0:
            used_bytes = self._get_dir_size_bytes(self.config.root_dir)
            return used_bytes * 100.0 / float(self.config.storage_quota_bytes)
        usage = shutil.disk_usage(self.config.root_dir)
        if usage.total <= 0:
            return 0.0
        return usage.used * 100.0 / usage.total

    def _get_dir_size_bytes(self, path: Path) -> int:
        """Sum the sizes of all regular files below ``path``."""
        total = 0
        for item in path.rglob("*"):
            try:
                if item.is_file():
                    total += item.stat().st_size
            except OSError:
                # File vanished or is unreadable mid-scan; skip it.
                pass
        return total

    def _read_rows_for_date(self, target_date: str) -> List[Dict[str, str]]:
        """Return the CSV rows whose ``timestamp_utc`` starts with ``target_date``."""
        rows: List[Dict[str, str]] = []
        with _file_lock(self.config.lock_path):
            if not self.config.csv_path.exists():
                return []
            with self.config.csv_path.open("r", newline="", encoding="utf-8") as f:
                reader = csv.DictReader(f)
                for row in reader:
                    ts = str(row.get("timestamp_utc", ""))
                    if ts.startswith(target_date):
                        rows.append(row)
        return rows

    def _load_state(self) -> Dict[str, object]:
        """Read the state file; a missing, unreadable or malformed file yields ``{}``."""
        try:
            loaded = json.loads(self.config.state_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            return {}
        # Callers use dict methods on the result, so reject other JSON types.
        return loaded if isinstance(loaded, dict) else {}

    def _save_state(self, state: Dict[str, object]) -> None:
        """Write ``state`` to the state file atomically.

        The JSON is written to a uniquely named temporary file next to the
        target and then moved over it, so a reader that does not hold the
        file lock never sees a half-written file.
        """
        state_path = self.config.state_path
        state_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = state_path.with_name(f".{state_path.name}.{uuid.uuid4().hex}.tmp")
        try:
            tmp_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
            os.replace(tmp_path, state_path)
        except Exception:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _prune_rows_for_date(self, target_date: str, only_submission_ids: Optional[List[str]] = None) -> None:
        """
        Prune exported date rows/images from hot Space storage.
        Keeps export bundles as durable transfer unit.

        Args:
            target_date: Date prefix matched against ``timestamp_utc``.
            only_submission_ids: When given, only rows of that date whose
                ``submission_id`` is listed are removed; ``None`` removes
                every row of the date.

        Image files are deleted only if no remaining row references them.
        Afterwards the hash index is rebuilt from the remaining rows.
        """
        drop_ids = None if only_submission_ids is None else set(only_submission_ids)
        with _file_lock(self.config.lock_path):
            if not self.config.csv_path.exists():
                return
            with self.config.csv_path.open("r", newline="", encoding="utf-8") as f:
                all_rows = list(csv.DictReader(f))

            keep_rows = []
            drop_rows = []
            for row in all_rows:
                ts = str(row.get("timestamp_utc", ""))
                row_id = str(row.get("submission_id", "")).strip()
                if ts.startswith(target_date) and (drop_ids is None or row_id in drop_ids):
                    drop_rows.append(row)
                else:
                    keep_rows.append(row)
            if not drop_rows:
                return

            with self.config.csv_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=CONTRIBUTION_FIELDS)
                writer.writeheader()
                for row in keep_rows:
                    writer.writerow({k: row.get(k, "") for k in CONTRIBUTION_FIELDS})

            # Remove unreferenced images only.
            still_referenced = set()
            for row in keep_rows:
                image_name = str(row.get("image_filename", "")).strip()
                safe_path = safe_resolve_in_dir(self.config.images_dir, image_name)
                if safe_path is not None:
                    still_referenced.add(safe_path.name)
            for row in drop_rows:
                image_filename = str(row.get("image_filename", "")).strip()
                image_path = safe_resolve_in_dir(self.config.images_dir, image_filename)
                if image_path is None or image_path.name in still_referenced:
                    continue
                try:
                    image_path.unlink()
                except OSError:
                    # Already gone or not removable; pruning stays best-effort.
                    pass

            # Rebuild hash map from kept rows.
            rebuilt_hash_map: Dict[str, str] = {}
            for row in keep_rows:
                image_hash = str(row.get("image_sha256", "")).strip()
                image_name = str(row.get("image_filename", "")).strip()
                # The index must name the submission that owns the stored
                # file (<id>.jpg). For a duplicate row that is the original
                # submission, not the row's own id.
                owner_id = Path(image_name).stem if image_name else str(row.get("submission_id", "")).strip()
                if image_hash and owner_id:
                    rebuilt_hash_map[image_hash] = owner_id
            state = self._load_state()
            state["image_hash_map"] = self._trim_hash_map(rebuilt_hash_map)
            self._save_state(state)

    def _upload_bundle_to_hf(self, bundle_path: Path, reason: str, target_date: str) -> None:
        """Upload a bundle to the configured HF dataset repo, if any.

        Best-effort: a missing ``huggingface_hub`` package or a failed upload
        is reported on stdout and otherwise ignored. On success the uploaded
        path is recorded as ``last_uploaded_bundle`` in the state file.
        """
        repo_id = self.config.hf_dataset_repo
        if not repo_id:
            return
        try:
            from huggingface_hub import HfApi  # type: ignore
        except Exception:
            print("[collection] huggingface_hub is not installed; skip upload.")
            return
        try:
            api = HfApi(token=os.getenv("HF_TOKEN"))
            path_in_repo = f"{self.config.hf_export_prefix}/{reason}/{target_date}/{bundle_path.name}"
            api.upload_file(
                path_or_fileobj=str(bundle_path),
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type="dataset",
            )
            self._update_state({"last_uploaded_bundle": path_in_repo})
            print(f"[collection] uploaded bundle to dataset: {repo_id}/{path_in_repo}")
        except Exception as exc:
            print(f"[collection] failed to upload bundle to dataset: {exc}")
def classify_from_percentages_simple(sand: float, silt: float, clay: float) -> str:
    """Simple USDA class rules to label user-provided composition.

    The three fractions are rescaled to sum to 100 when their total is
    positive (otherwise they are used as given). The clay percentage selects
    a band; within the band the first matching sand/silt threshold names the
    class, else the band's default applies.

    NOTE(review): in the 20-27 % clay band these rules return
    "Silty Clay Loam" / "Clay Loam", whereas the USDA texture triangle
    appears to place such samples in "Silt Loam" / "Loam" - confirm that
    this simplification is intended.

    Returns:
        One of the class names listed in ``USDA_CLASSES``.
    """
    total = sand + silt + clay
    if total > 0:
        sand, silt, clay = (part / total * 100 for part in (sand, silt, clay))

    # (minimum clay %, ordered (fraction, minimum %, class) rules, default class)
    clay_bands = (
        (40, ((silt, 40, "Silty Clay"), (sand, 45, "Sandy Clay")), "Clay"),
        (27, ((silt, 40, "Silty Clay Loam"), (sand, 45, "Sandy Clay Loam")), "Clay Loam"),
        (20, ((sand, 45, "Sandy Clay Loam"), (silt, 50, "Silty Clay Loam")), "Clay Loam"),
        (7, ((silt, 50, "Silt Loam"), (sand, 52, "Sandy Loam")), "Loam"),
    )

    # Rules for samples below every clay band.
    rules = (
        (silt, 80, "Silt"),
        (sand, 85, "Sand"),
        (sand, 70, "Loamy Sand"),
        (sand, 52, "Sandy Loam"),
        (silt, 50, "Silt Loam"),
    )
    default = "Loam"

    for clay_floor, band_rules, band_default in clay_bands:
        if clay >= clay_floor:
            rules, default = band_rules, band_default
            break

    for fraction, threshold, label in rules:
        if fraction >= threshold:
            return label
    return default