File size: 6,478 Bytes
7591256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python3
"""Shared utilities for 3DReflecNet HF release apps."""
from __future__ import annotations

import logging
from typing import Any

import pandas as pd

logger = logging.getLogger("hf_release")

FILTER_ALL = "ALL"
BOOL_FILTER_CHOICES = [FILTER_ALL, "True", "False"]


def setup_logging(level: int = logging.INFO) -> None:
    """Initialise root logging for hf_release modules via ``basicConfig``."""
    log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(level=level, format=log_format, datefmt=date_format)


def require_columns(df: pd.DataFrame, columns: list[str], context: str) -> None:
    """Raise ``KeyError`` naming every column of *columns* absent from *df*."""
    present = set(df.columns)
    missing = [name for name in columns if name not in present]
    if missing:
        joined = ", ".join(missing)
        raise KeyError(f"Missing required column(s) in {context}: {joined}")


def require_bool_columns(df: pd.DataFrame, columns: list[str], context: str) -> None:
    """Validate that *columns* exist in *df* and hold strictly boolean data.

    Raises:
        KeyError: if any column is missing (via ``require_columns``).
        TypeError: if a column is not of a boolean dtype.
        ValueError: if a boolean column contains null values (possible with
            pandas' nullable "boolean" extension dtype).
    """
    require_columns(df, columns, context)
    for column in columns:
        # Check dtype before nulls: a non-boolean column with NaN values
        # previously raised a misleading "Boolean column ... null values"
        # ValueError instead of the accurate TypeError below.
        if not pd.api.types.is_bool_dtype(df[column]):
            raise TypeError(f"Expected boolean dtype for column {column!r} in {context}, got {df[column].dtype}.")
        if df[column].isna().any():
            raise ValueError(f"Boolean column {column!r} contains null values in {context}.")


def require_text_columns(df: pd.DataFrame, columns: list[str], context: str) -> None:
    """Validate that *columns* exist, are null-free, and contain only ``str`` values."""
    require_columns(df, columns, context)
    for column in columns:
        series = df[column]
        if series.isna().any():
            raise ValueError(f"Text column {column!r} contains null values in {context}.")
        non_str_mask = series.map(lambda value: not isinstance(value, str))
        if non_str_mask.any():
            # Report the concrete offending type to make the failure actionable.
            offender = series.loc[non_str_mask].iloc[0]
            bad_type = type(offender).__name__
            raise TypeError(f"Expected string values for column {column!r} in {context}, got {bad_type}.")


def parse_bool_filter_value(selected_value: str) -> bool:
    """Translate a ``"True"``/``"False"`` dropdown string into a Python bool."""
    mapping = {"True": True, "False": False}
    if selected_value in mapping:
        return mapping[selected_value]
    raise ValueError(f"Unsupported boolean filter value: {selected_value!r}")


def apply_bool_filter(df: pd.DataFrame, column: str, selected_value: str) -> pd.DataFrame:
    """Filter *df* on a boolean *column* using a tri-state selection.

    ``FILTER_ALL`` returns the frame untouched; "True"/"False" keep only the
    rows whose column value matches.
    """
    if selected_value == FILTER_ALL:
        return df
    if column not in df.columns:
        raise KeyError(f"Missing required boolean filter column: {column}")
    series = df[column]
    if not pd.api.types.is_bool_dtype(series):
        raise TypeError(f"Expected boolean dtype for column {column!r}, got {series.dtype}.")
    wanted = parse_bool_filter_value(selected_value)
    return df[series == wanted]


def get_distinct_text_choices(df: pd.DataFrame, column: str, all_label: str = FILTER_ALL) -> list[str]:
    """Build dropdown choices: the all-label followed by sorted distinct texts."""
    if column not in df.columns:
        raise KeyError(f"Missing required text choice column: {column}")
    distinct: set[str] = set()
    for raw in df[column].dropna():
        text = str(raw).strip()
        if text:
            distinct.add(text)
    if not distinct:
        raise ValueError(f"Column {column!r} has no non-empty values.")
    return [all_label, *sorted(distinct)]


def _apply_text_equals(df: pd.DataFrame, column: str, selected_value: str, all_label: str = FILTER_ALL) -> pd.DataFrame:
    """Keep rows whose stripped string value in *column* equals the selection.

    An empty/None selection or the all-label leaves the frame unfiltered.
    """
    if column not in df.columns:
        raise KeyError(f"Missing required text filter column: {column}")
    wanted = (selected_value or "").strip()
    if wanted in ("", all_label):
        return df
    normalized = df[column].astype(str).str.strip()
    return df[normalized == wanted]


def filter_dataframe_advanced(
    df: pd.DataFrame,
    model_name: str = FILTER_ALL,
    material_name: str = FILTER_ALL,
    env_name: str = FILTER_ALL,
    has_glass: str = FILTER_ALL,
    is_generated: str = FILTER_ALL,
    transparent: str = FILTER_ALL,
    near_light: str = FILTER_ALL,
) -> pd.DataFrame:
    """Filter by model/material/environment exact selection and four tri-state bool fields."""
    text_filters = [
        ("model_name", model_name),
        ("material_name", material_name),
        ("env_name", env_name),
    ]
    bool_filters = [
        ("hasGlass", has_glass),
        ("isGenerated", is_generated),
        ("transparent", transparent),
        ("near_light", near_light),
    ]
    result = df
    for column, value in text_filters:
        result = _apply_text_equals(result, column, value)
    for column, value in bool_filters:
        result = apply_bool_filter(result, column, value)
    # Fresh 0..n-1 index so downstream choice formatting can use positions.
    return result.reset_index(drop=True)


def aggregate_by_model(
    df: pd.DataFrame,
    extra_columns: list[str] | None = None,
) -> pd.DataFrame:
    """Aggregate instance rows per ``model_name``.

    Each output row carries the group's first-seen main/sub category, the
    count of distinct instance IDs, the IDs joined with newlines, and — for
    every requested extra column — the first non-empty value in the group.
    """
    extras = list(extra_columns) if extra_columns else []
    columns_out = ["model_name", "main_category", "sub_category", "instance_count", "instance_ids"] + extras

    if df.empty:
        return pd.DataFrame(columns=columns_out)

    require_columns(df, ["model_name", "main_category", "sub_category", "instance_id"] + extras, "model aggregation")

    records: list[dict[str, Any]] = []
    for name, group in df.groupby("model_name", dropna=False, sort=True):
        ids = sorted({str(v) for v in group["instance_id"].dropna() if str(v).strip()})
        record: dict[str, Any] = {
            "model_name": str(name),
            "main_category": str(group["main_category"].iloc[0]),
            "sub_category": str(group["sub_category"].iloc[0]),
            "instance_count": len(ids),
            "instance_ids": "\n".join(ids),
        }
        for extra in extras:
            non_empty = [str(v) for v in group[extra].dropna() if str(v).strip()]
            record[extra] = non_empty[0] if non_empty else ""
        records.append(record)
    return pd.DataFrame(records)


def format_model_choice(index: int, row: dict[str, Any]) -> str:
    """Render a model dropdown label: zero-padded index, name, instance count."""
    name = row["model_name"]
    count = row["instance_count"]
    return f"{index:04d} | {name} | instances {count}"


def format_instance_choice(index: int, row: dict[str, Any]) -> str:
    """Render an instance dropdown label: zero-padded index, instance ID, model name."""
    instance_id = row["instance_id"]
    model = row["model_name"]
    return f"{index:04d} | {instance_id} | {model}"


def parse_choice_index(choice: str, length: int) -> int | None:
    """Extract the numeric index from a formatted choice string.

    Args:
        choice: A string like ``"0003 | name | ..."`` as produced by the
            format_* helpers. ``None`` or empty input is tolerated (UI
            callbacks can deliver an empty selection).
        length: Exclusive upper bound for a valid index.

    Returns:
        The parsed index when it lies in ``[0, length)``, otherwise ``None``.
    """
    # Guard against None/empty selections instead of raising AttributeError
    # on .split when a dropdown has no value yet.
    if not choice:
        return None
    index_str = choice.split("|", 1)[0].strip()
    try:
        idx = int(index_str)
    except ValueError:
        return None
    return idx if 0 <= idx < length else None