maple-data / src /labeler /adapters.py
gyubin02's picture
Initial commit
da3fe02
from __future__ import annotations
import json
import logging
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional
# Module logger, registered under the fixed name "labeler" (not __name__).
# The manifest readers below use it to warn about records they skip.
logger = logging.getLogger("labeler")
@dataclass
class LabelInput:
    """A single sample handed to the labeler: one icon plus its item metadata.

    Instances are produced either from manifest records or from database
    rows by the iterator functions in this module.
    """

    # Image path as recorded by the source; "" when a DB row has no local path.
    image_path: str
    # Resolved filesystem location; None when a DB row has no local path.
    image_abspath: Optional[Path]
    # Remote location of the icon, when the source provides one.
    image_url: Optional[str]
    # SHA-256 value recorded by the source for the image, if any.
    image_sha256: Optional[str]
    # Item name; the producers in this module never emit an empty name.
    item_name: str
    # Free-text item description, if any.
    item_description: Optional[str]
    # Equipment part and/or slot; DB rows combine both as "part/slot".
    item_part: Optional[str]
    # Origin of the sample; DB rows use "equipment_shape" or "cash".
    source_type: str
    # Identifier copied from the source's "ocid" field/column, if present.
    ocid: Optional[str]
    # Ranking value from a manifest record; always None for DB rows.
    ranking: Optional[int]
def iter_inputs(
    *,
    input_path: Optional[Path],
    db_path: Optional[Path],
    only_source: str,
    max_samples: Optional[int],
    run_id: Optional[str],
) -> Iterable[LabelInput]:
    """Yield label inputs from a manifest file or, failing that, a database.

    A manifest given via ``input_path`` takes precedence over ``db_path``.

    Args:
        input_path: Manifest file; ``.jsonl``/``.json`` and ``.parquet``
            suffixes (case-insensitive) are supported.
        db_path: SQLite database, used only when ``input_path`` is not given.
        only_source: Source filter forwarded to the selected reader.
        max_samples: Sample cap forwarded to the selected reader.
        run_id: Run filter; only the database reader receives it.

    Raises:
        ValueError: If the manifest suffix is unsupported, or if neither
            ``input_path`` nor ``db_path`` is provided. Because this is a
            generator, the error surfaces on first iteration.
    """
    if not input_path:
        if not db_path:
            raise ValueError("Provide --input or --db")
        yield from _iter_db(db_path, only_source, max_samples, run_id)
        return

    extension = input_path.suffix.lower()
    if extension in {".jsonl", ".json"}:
        reader = _iter_manifest_jsonl
    elif extension == ".parquet":
        reader = _iter_manifest_parquet
    else:
        raise ValueError(f"Unsupported input format: {input_path}")
    yield from reader(input_path, only_source, max_samples)
def _iter_manifest_jsonl(
    path: Path,
    only_source: str,
    max_samples: Optional[int],
) -> Iterable[LabelInput]:
    """Yield label inputs from a manifest with one JSON object per line.

    Blank lines, lines that are not valid JSON, and lines whose JSON value
    is not an object are skipped (the latter two with a warning) so that a
    single bad line does not abort the whole run.

    Args:
        path: Manifest file, read as UTF-8. Relative image paths in the
            records are resolved against this file's directory.
        only_source: Source filter passed to ``_build_from_record``.
        max_samples: Stop after this many yielded samples; ``None`` or ``0``
            means no limit.
    """
    base_dir = path.parent
    count = 0
    with path.open("r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                logger.warning("Skipping invalid JSON line in %s", path)
                continue
            # Valid JSON that is not an object (list, number, string, null)
            # has no ``.get`` and would otherwise crash record building.
            if not isinstance(record, dict):
                logger.warning(
                    "Skipping non-object JSON line %d in %s", line_number, path
                )
                continue
            sample = _build_from_record(record, base_dir, only_source)
            if not sample:
                continue
            yield sample
            count += 1
            if max_samples and count >= max_samples:
                break
def _iter_manifest_parquet(
    path: Path,
    only_source: str,
    max_samples: Optional[int],
) -> Iterable[LabelInput]:
    """Yield label inputs from a Parquet manifest.

    The whole table is loaded and converted to a list of row dicts before
    any sample is produced.

    Args:
        path: Parquet file. Relative image paths in the rows are resolved
            against this file's directory.
        only_source: Source filter passed to ``_build_from_record``.
        max_samples: Stop after this many yielded samples; ``None`` or ``0``
            means no limit.

    Raises:
        RuntimeError: If ``pyarrow`` cannot be imported.
    """
    # Imported lazily so pyarrow is only needed for parquet input.
    try:
        import pyarrow.parquet as pq
    except ImportError as exc:
        raise RuntimeError("pyarrow is required for parquet input") from exc

    manifest_dir = path.parent
    emitted = 0
    for row in pq.read_table(path).to_pylist():
        sample = _build_from_record(row, manifest_dir, only_source)
        if sample:
            yield sample
            emitted += 1
            if max_samples and emitted >= max_samples:
                return
def _build_from_record(
    record: dict[str, object],
    base_dir: Path,
    only_source: str,
) -> Optional[LabelInput]:
    """Convert one manifest record into a ``LabelInput``.

    Args:
        record: Manifest row; read via the keys ``source_type``,
            ``image_path``, ``item_name``, ``image_url``, ``image_sha256``,
            ``item_description``, ``item_part``, ``ocid`` and ``ranking``.
        base_dir: Directory against which a relative ``image_path`` is
            resolved.
        only_source: Source filter checked through ``_source_allowed``.

    Returns:
        The built sample, or ``None`` when the record is filtered out by
        source type or lacks an image path or item name (the latter two
        are logged as warnings).
    """
    source_type = str(record.get("source_type") or "")
    if not _source_allowed(source_type, only_source):
        return None

    raw_path = str(record.get("image_path") or "").strip()
    if not raw_path:
        logger.warning("Missing image_path in manifest record")
        return None
    resolved = Path(raw_path)
    if not resolved.is_absolute():
        resolved = (base_dir / resolved).resolve()

    name = str(record.get("item_name") or "").strip()
    if not name:
        logger.warning("Missing item_name in manifest record")
        return None

    # These fields share the same "strip, empty becomes None" normalisation.
    optional_text = {
        key: _optional_str(record.get(key))
        for key in (
            "image_url",
            "image_sha256",
            "item_description",
            "item_part",
            "ocid",
        )
    }
    return LabelInput(
        image_path=raw_path,
        image_abspath=resolved,
        item_name=name,
        source_type=source_type,
        ranking=_optional_int(record.get("ranking")),
        **optional_text,
    )
def _iter_db(
    db_path: Path,
    only_source: str,
    max_samples: Optional[int],
    run_id: Optional[str],
) -> Iterable[LabelInput]:
    """Yield label inputs read from the SQLite database at ``db_path``.

    Equipment-shape rows are streamed first, then cash rows; ``max_samples``
    caps the combined total. Rows without an item name are skipped. Local
    image paths are resolved against the database file's directory.

    Args:
        db_path: SQLite database file.
        only_source: ``"equipment_shape"``, ``"cash"`` or ``"all"``; any
            other value selects no rows.
        max_samples: Stop after this many yielded samples; ``None`` or ``0``
            means no limit.
        run_id: When truthy, restricts both queries to that run.
    """
    conn = sqlite3.connect(db_path)
    # try/finally guarantees the connection is closed on every exit path:
    # normal completion, the max_samples early return, a failing query, and
    # the consumer abandoning the generator (GeneratorExit at a yield).
    try:
        conn.row_factory = sqlite3.Row
        base_dir = db_path.parent

        def stream(
            query: str, params: tuple[object, ...], source_type: str
        ) -> Iterable[LabelInput]:
            """Run one query and convert each usable row to a LabelInput."""
            cursor = conn.execute(query, params)
            for row in cursor:
                local_path = row["local_path"]
                image_path = local_path or ""
                image_abspath = (
                    (base_dir / local_path).resolve() if local_path else None
                )
                item_name = row["item_name"] or ""
                if not item_name:
                    continue
                yield LabelInput(
                    image_path=image_path,
                    image_abspath=image_abspath,
                    image_url=row["image_url"],
                    image_sha256=row["sha256"],
                    item_name=item_name,
                    item_description=row["item_description"],
                    item_part=_build_item_part(row["item_part"], row["item_slot"]),
                    source_type=source_type,
                    ocid=row["ocid"],
                    ranking=None,
                )

        count = 0
        if only_source in ("equipment_shape", "all"):
            query, params = _equipment_query(run_id)
            for sample in stream(query, params, "equipment_shape"):
                yield sample
                count += 1
                if max_samples and count >= max_samples:
                    return
        if only_source in ("cash", "all"):
            query, params = _cash_query(run_id)
            for sample in stream(query, params, "cash"):
                yield sample
                count += 1
                if max_samples and count >= max_samples:
                    return
    finally:
        conn.close()
def _equipment_query(run_id: Optional[str]) -> tuple[str, tuple[object, ...]]:
    """Build the SELECT for equipment-shape items and its bound parameters.

    Rows come from ``equipment_shape_items`` left-joined to ``icon_assets``
    on the icon URL; items with a NULL or empty icon URL are excluded.

    Args:
        run_id: When truthy, adds an ``e.run_id = ?`` filter.

    Returns:
        ``(sql, params)`` where ``params`` is ``(run_id,)`` if the filter
        was added and ``()`` otherwise.
    """
    columns = ", ".join(
        [
            "e.item_shape_icon_url AS image_url",
            "a.sha256 AS sha256",
            "a.local_path AS local_path",
            "e.item_name AS item_name",
            "e.item_description AS item_description",
            "e.item_equipment_part AS item_part",
            "e.equipment_slot AS item_slot",
            "e.ocid AS ocid",
        ]
    )
    sql = " ".join(
        [
            "SELECT " + columns,
            "FROM equipment_shape_items e",
            "LEFT JOIN icon_assets a ON a.url = e.item_shape_icon_url",
            "WHERE e.item_shape_icon_url IS NOT NULL AND e.item_shape_icon_url != ''",
        ]
    )
    if not run_id:
        return sql, ()
    return sql + " AND e.run_id = ?", (run_id,)
def _cash_query(run_id: Optional[str]) -> tuple[str, tuple[object, ...]]:
    """Build the SELECT for cash items and its bound parameters.

    Rows come from ``cash_items`` left-joined to ``icon_assets`` on the icon
    URL; items with a NULL or empty icon URL are excluded.

    Args:
        run_id: When truthy, adds a ``c.run_id = ?`` filter.

    Returns:
        ``(sql, params)`` where ``params`` is ``(run_id,)`` if the filter
        was added and ``()`` otherwise.
    """
    columns = ", ".join(
        [
            "c.cash_item_icon_url AS image_url",
            "a.sha256 AS sha256",
            "a.local_path AS local_path",
            "c.cash_item_name AS item_name",
            "c.cash_item_description AS item_description",
            "c.cash_item_equipment_part AS item_part",
            "c.cash_item_equipment_slot AS item_slot",
            "c.ocid AS ocid",
        ]
    )
    sql = " ".join(
        [
            "SELECT " + columns,
            "FROM cash_items c",
            "LEFT JOIN icon_assets a ON a.url = c.cash_item_icon_url",
            "WHERE c.cash_item_icon_url IS NOT NULL AND c.cash_item_icon_url != ''",
        ]
    )
    if not run_id:
        return sql, ()
    return sql + " AND c.run_id = ?", (run_id,)
def _source_allowed(source_type: str, only_source: str) -> bool:
    """Return whether a record's source type passes the ``only_source`` filter.

    With ``only_source == "all"`` only the two known source types
    (``"equipment_shape"`` and ``"cash"``) are accepted; otherwise the
    source type must equal ``only_source`` exactly.
    """
    if only_source != "all":
        return source_type == only_source
    known_sources = {"equipment_shape", "cash"}
    return source_type in known_sources
def _build_item_part(part: Optional[str], slot: Optional[str]) -> Optional[str]:
    """Combine an equipment part and slot into a single label.

    Both values are stripped first. Returns ``"part/slot"`` when both are
    present and differ, the single non-empty value when only one is present
    or both are equal, and ``None`` when neither is present.
    """
    part_text = part.strip() if part else ""
    slot_text = slot.strip() if slot else ""
    if not part_text:
        return slot_text or None
    if not slot_text or slot_text == part_text:
        return part_text
    return f"{part_text}/{slot_text}"
def _optional_str(value: object) -> Optional[str]:
    """Return ``str(value)`` stripped, or ``None`` if it is ``None`` or blank."""
    cleaned = "" if value is None else str(value).strip()
    return cleaned if cleaned else None
def _optional_int(value: object) -> Optional[int]:
    """Convert ``value`` to ``int``, returning ``None`` when that is impossible.

    ``None`` maps to ``None``. Values ``int()`` rejects (wrong type,
    non-numeric text, NaN) and infinite floats also map to ``None`` instead
    of raising.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")). json.loads accepts "Infinity",
        # so such a value can arrive from a manifest record.
        return None