Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import logging | |
| import sqlite3 | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable, Optional | |
| logger = logging.getLogger("labeler") | |
@dataclass
class LabelInput:
    """One image/item pair handed to the labeler.

    Instances are built with keyword arguments by the manifest and database
    loaders in this module, so the class must be a dataclass to get the
    generated ``__init__``.
    """

    # Image path exactly as given by the manifest record or the DB row
    # (empty string when a DB row has no local path).
    image_path: str
    # Image path resolved against the manifest/DB directory; None when a DB
    # row has no local path.
    image_abspath: Optional[Path]
    # Source URL of the image, when known.
    image_url: Optional[str]
    # SHA-256 value carried over from the record/row, when known.
    image_sha256: Optional[str]
    # Item name; loaders skip records where this is empty.
    item_name: str
    item_description: Optional[str]
    # Equipment part, possibly combined with the slot as "part/slot".
    item_part: Optional[str]
    # Origin of the record: the loaders emit "equipment_shape" or "cash".
    source_type: str
    # Opaque identifier copied from the record/row (meaning not shown here).
    ocid: Optional[str]
    # Only populated from manifest records; DB rows always set None.
    ranking: Optional[int]
def iter_inputs(
    *,
    input_path: Optional[Path],
    db_path: Optional[Path],
    only_source: str,
    max_samples: Optional[int],
    run_id: Optional[str],
) -> Iterable[LabelInput]:
    """Yield label inputs from a manifest file or, failing that, a database.

    A manifest (``input_path``) takes precedence over ``db_path``. The
    manifest reader is chosen by file extension: ``.jsonl``/``.json`` or
    ``.parquet``. ``run_id`` is only used for the database source.

    Raises:
        ValueError: if the manifest extension is unsupported, or if neither
            ``input_path`` nor ``db_path`` is provided. Because this is a
            generator, the error surfaces on first iteration.
    """
    if not input_path:
        # No manifest given: the database is the only remaining source.
        if not db_path:
            raise ValueError("Provide --input or --db")
        yield from _iter_db(db_path, only_source, max_samples, run_id)
        return

    extension = input_path.suffix.lower()
    if extension == ".parquet":
        reader = _iter_manifest_parquet
    elif extension in {".jsonl", ".json"}:
        reader = _iter_manifest_jsonl
    else:
        raise ValueError(f"Unsupported input format: {input_path}")
    yield from reader(input_path, only_source, max_samples)
| def _iter_manifest_jsonl( | |
| path: Path, | |
| only_source: str, | |
| max_samples: Optional[int], | |
| ) -> Iterable[LabelInput]: | |
| base_dir = path.parent | |
| count = 0 | |
| with path.open("r", encoding="utf-8") as handle: | |
| for line in handle: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| record = json.loads(line) | |
| except json.JSONDecodeError: | |
| logger.warning("Skipping invalid JSON line in %s", path) | |
| continue | |
| sample = _build_from_record(record, base_dir, only_source) | |
| if not sample: | |
| continue | |
| yield sample | |
| count += 1 | |
| if max_samples and count >= max_samples: | |
| break | |
def _iter_manifest_parquet(
    path: Path,
    only_source: str,
    max_samples: Optional[int],
) -> Iterable[LabelInput]:
    """Yield label inputs from a Parquet manifest.

    The whole table is read and converted to a list of row dicts before any
    sample is produced. Relative image paths are resolved against the
    manifest's directory. A falsy ``max_samples`` means no limit.

    Raises:
        RuntimeError: if ``pyarrow`` is not installed (raised on first
            iteration, since this is a generator).
    """
    try:
        import pyarrow.parquet as pq
    except ImportError as exc:
        raise RuntimeError("pyarrow is required for parquet input") from exc

    manifest_dir = path.parent
    emitted = 0
    for row in pq.read_table(path).to_pylist():
        candidate = _build_from_record(row, manifest_dir, only_source)
        if candidate:
            yield candidate
            emitted += 1
            if max_samples and emitted >= max_samples:
                return
def _build_from_record(
    record: dict[str, object],
    base_dir: Path,
    only_source: str,
) -> Optional[LabelInput]:
    """Turn one manifest record into a ``LabelInput``, or ``None`` to skip it.

    A record is skipped when its ``source_type`` is filtered out by
    ``only_source``, or when ``image_path`` or ``item_name`` is missing or
    blank (the latter two log a warning). A relative ``image_path`` is
    resolved against ``base_dir``.
    """
    origin = str(record.get("source_type") or "")
    if not _source_allowed(origin, only_source):
        return None

    raw_path = str(record.get("image_path") or "").strip()
    if not raw_path:
        logger.warning("Missing image_path in manifest record")
        return None

    resolved = Path(raw_path)
    if not resolved.is_absolute():
        resolved = (base_dir / resolved).resolve()

    name = str(record.get("item_name") or "").strip()
    if not name:
        logger.warning("Missing item_name in manifest record")
        return None

    # Optional columns are normalised so blank strings become None.
    url = _optional_str(record.get("image_url"))
    digest = _optional_str(record.get("image_sha256"))
    description = _optional_str(record.get("item_description"))
    part = _optional_str(record.get("item_part"))
    ocid = _optional_str(record.get("ocid"))
    rank = _optional_int(record.get("ranking"))

    return LabelInput(
        image_path=raw_path,
        image_abspath=resolved,
        image_url=url,
        image_sha256=digest,
        item_name=name,
        item_description=description,
        item_part=part,
        source_type=origin,
        ocid=ocid,
        ranking=rank,
    )
| def _iter_db( | |
| db_path: Path, | |
| only_source: str, | |
| max_samples: Optional[int], | |
| run_id: Optional[str], | |
| ) -> Iterable[LabelInput]: | |
| conn = sqlite3.connect(db_path) | |
| conn.row_factory = sqlite3.Row | |
| base_dir = db_path.parent | |
| def stream(query: str, params: tuple[object, ...], source_type: str) -> Iterable[LabelInput]: | |
| cursor = conn.execute(query, params) | |
| for row in cursor: | |
| local_path = row["local_path"] | |
| image_path = local_path or "" | |
| image_abspath = (base_dir / local_path).resolve() if local_path else None | |
| item_name = row["item_name"] or "" | |
| if not item_name: | |
| continue | |
| yield LabelInput( | |
| image_path=image_path, | |
| image_abspath=image_abspath, | |
| image_url=row["image_url"], | |
| image_sha256=row["sha256"], | |
| item_name=item_name, | |
| item_description=row["item_description"], | |
| item_part=_build_item_part(row["item_part"], row["item_slot"]), | |
| source_type=source_type, | |
| ocid=row["ocid"], | |
| ranking=None, | |
| ) | |
| count = 0 | |
| if only_source in ("equipment_shape", "all"): | |
| query, params = _equipment_query(run_id) | |
| for sample in stream(query, params, "equipment_shape"): | |
| yield sample | |
| count += 1 | |
| if max_samples and count >= max_samples: | |
| conn.close() | |
| return | |
| if only_source in ("cash", "all"): | |
| query, params = _cash_query(run_id) | |
| for sample in stream(query, params, "cash"): | |
| yield sample | |
| count += 1 | |
| if max_samples and count >= max_samples: | |
| conn.close() | |
| return | |
| conn.close() | |
| def _equipment_query(run_id: Optional[str]) -> tuple[str, tuple[object, ...]]: | |
| query = ( | |
| "SELECT e.item_shape_icon_url AS image_url, a.sha256 AS sha256, a.local_path AS local_path, " | |
| "e.item_name AS item_name, e.item_description AS item_description, " | |
| "e.item_equipment_part AS item_part, e.equipment_slot AS item_slot, e.ocid AS ocid " | |
| "FROM equipment_shape_items e " | |
| "LEFT JOIN icon_assets a ON a.url = e.item_shape_icon_url " | |
| "WHERE e.item_shape_icon_url IS NOT NULL AND e.item_shape_icon_url != ''" | |
| ) | |
| if run_id: | |
| query += " AND e.run_id = ?" | |
| return query, (run_id,) | |
| return query, () | |
| def _cash_query(run_id: Optional[str]) -> tuple[str, tuple[object, ...]]: | |
| query = ( | |
| "SELECT c.cash_item_icon_url AS image_url, a.sha256 AS sha256, a.local_path AS local_path, " | |
| "c.cash_item_name AS item_name, c.cash_item_description AS item_description, " | |
| "c.cash_item_equipment_part AS item_part, c.cash_item_equipment_slot AS item_slot, c.ocid AS ocid " | |
| "FROM cash_items c " | |
| "LEFT JOIN icon_assets a ON a.url = c.cash_item_icon_url " | |
| "WHERE c.cash_item_icon_url IS NOT NULL AND c.cash_item_icon_url != ''" | |
| ) | |
| if run_id: | |
| query += " AND c.run_id = ?" | |
| return query, (run_id,) | |
| return query, () | |
| def _source_allowed(source_type: str, only_source: str) -> bool: | |
| if only_source == "all": | |
| return source_type in {"equipment_shape", "cash"} | |
| return source_type == only_source | |
| def _build_item_part(part: Optional[str], slot: Optional[str]) -> Optional[str]: | |
| part = (part or "").strip() | |
| slot = (slot or "").strip() | |
| if part and slot and part != slot: | |
| return f"{part}/{slot}" | |
| return part or slot or None | |
| def _optional_str(value: object) -> Optional[str]: | |
| if value is None: | |
| return None | |
| text = str(value).strip() | |
| return text or None | |
| def _optional_int(value: object) -> Optional[int]: | |
| try: | |
| return int(value) if value is not None else None | |
| except (TypeError, ValueError): | |
| return None | |