# sample/face_labeler.py
# Source: uploaded by Silly98 ("Upload face_labeler.py"), revision bf9994d (verified), 12.7 kB.
import sys
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple
# Qt binding selection + pythonocc backend init.
# Prefer PyQt5 and fall back to PySide2; the name of whichever binding was
# imported is passed to pythonocc's load_backend. This runs before
# OCC.Display.qtDisplay is imported further down.
try:
    from PyQt5 import QtCore, QtWidgets
except ImportError:  # pragma: no cover
    from PySide2 import QtCore, QtWidgets
    _qt_backend = "qt-pyside2"
else:
    _qt_backend = "qt-pyqt5"
from OCC.Display.backend import load_backend
load_backend(_qt_backend)
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopoDS import topods
from OCC.Core.Quantity import Quantity_Color, Quantity_TOC_RGB
from OCC.Display.qtDisplay import qtViewer3d
# Feature classes a face can be labeled with. A class's position in this list
# is the integer label stored for a face and written to the exported .seg file.
CLASS_NAMES = [
    "ExtrudeSide",
    "ExtrudeEnd",
    "CutSide",
    "CutEnd",
    "Fillet",
    "Chamfer",
    "RevolveSide",
    "RevolveEnd",
]
# Display color of each class, index-aligned with CLASS_NAMES.
CLASS_COLORS_HEX = [
    "#e41a1c",  # red (ExtrudeSide)
    "#377eb8",  # blue (ExtrudeEnd)
    "#4daf4a",  # green (CutSide)
    "#984ea3",  # purple (CutEnd)
    "#0e2579",  # dark navy blue (Fillet)
    "#a65628",  # brown (Chamfer)
    "#f781bf",  # pink (RevolveSide)
    "#00a7a7",  # teal (RevolveEnd)
]
# The two lists are indexed together (one button/color per class), so a
# mismatch is a configuration error; fail loudly at import time.
if len(CLASS_COLORS_HEX) != len(CLASS_NAMES):
    raise ValueError(
        f"CLASS_COLORS_HEX has {len(CLASS_COLORS_HEX)} entries but "
        f"CLASS_NAMES has {len(CLASS_NAMES)}; they must match one-to-one"
    )
# Color of faces that have not been assigned a class yet.
UNLABELED_COLOR_HEX = "#d0d0d0"
# Color marking the face currently selected for labeling; it is the same
# regardless of the face's label.
HIGHLIGHT_COLOR_HEX = "#FFD400"
def hex_to_rgb01(color_hex: str) -> Tuple[float, float, float]:
    """Convert a hex color string to an ``(r, g, b)`` tuple of floats in [0, 1].

    Accepted forms (leading ``#`` optional, digits case-insensitive):
    ``#RRGGBB``, the CSS shorthand ``#RGB`` (each digit doubled), and
    ``#RRGGBBAA`` (alpha is ignored).

    Raises:
        ValueError: if *color_hex* is not one of the accepted forms.
    """
    digits = color_hex.strip().lstrip("#")
    if len(digits) == 3:
        # Shorthand: "#abc" means "#aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: int(..., 16) on raw slices would silently accept
    # short input (e.g. "#12345" parsed the lone "5" as blue) as well as
    # signs, whitespace and underscores.
    if len(digits) not in (6, 8) or any(
        ch not in "0123456789abcdefABCDEF" for ch in digits
    ):
        raise ValueError(
            f"Invalid hex color {color_hex!r}; expected #RRGGBB, #RGB or #RRGGBBAA"
        )
    r = int(digits[0:2], 16) / 255.0
    g = int(digits[2:4], 16) / 255.0
    b = int(digits[4:6], 16) / 255.0
    return r, g, b
def rgb01_to_quantity(rgb: Tuple[float, float, float]) -> Quantity_Color:
    """Wrap an ``(r, g, b)`` triple as an OCC ``Quantity_Color``.

    The components are interpreted as RGB (``Quantity_TOC_RGB``) and are
    expected in [0, 1], e.g. as produced by ``hex_to_rgb01``.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return Quantity_Color(red, green, blue, Quantity_TOC_RGB)
def text_color_for_bg(color_hex: str) -> str:
    """Return black or white text color for legible contrast on *color_hex*.

    The background's luma is computed with the Rec. 601 weights
    (0.299, 0.587, 0.114) applied directly to the hex RGB components.
    Backgrounds with luma above 0.6 get black text; all others get white.
    """
    weights = (0.299, 0.587, 0.114)
    luma = sum(weight * channel for weight, channel in zip(weights, hex_to_rgb01(color_hex)))
    if luma > 0.6:
        return "#000000"
    return "#ffffff"
@dataclass
class FaceItem:
    """One face of the loaded model paired with the object that displays it."""

    # The face shape; FaceLabeler.load_step stores a topods.Face(...) result here.
    face: object
    # Interactive object returned by display.DisplayShape for this face.
    # FaceLabeler.set_face_color also tolerates a list of such objects here.
    ais: object
class FaceLabeler(QtWidgets.QMainWindow):
    """Main window for assigning a feature class to every face of a STEP model.

    Faces are enumerated with ``TopExp_Explorer`` and each is displayed as its
    own interactive object. ``labels[i]`` holds the class index of face ``i``
    (``None`` until labeled). The exported ``.seg`` file contains one label per
    line in that same face order.
    """

    def __init__(self, step_path: Optional[str] = None):
        super().__init__()
        self.setWindowTitle("BRep Face Labeler")
        self.resize(1600, 1000)
        self.setMinimumSize(1200, 800)
        self.class_names = CLASS_NAMES
        self.class_colors_rgb = [hex_to_rgb01(c) for c in CLASS_COLORS_HEX]
        self.class_colors = [rgb01_to_quantity(c) for c in self.class_colors_rgb]
        self.unlabeled_color = rgb01_to_quantity(hex_to_rgb01(UNLABELED_COLOR_HEX))
        # Parallel lists: face_items[i] is the displayed face, labels[i] its class.
        self.face_items: List[FaceItem] = []
        self.labels: List[Optional[int]] = []
        # Index of the face being labeled; None while no model is loaded.
        self.current_index: Optional[int] = None
        # When False (set by on_review once every face is labeled), the current
        # face is drawn in its label color instead of the highlight color.
        self.highlight_enabled = True
        self.highlight_color = rgb01_to_quantity(hex_to_rgb01(HIGHLIGHT_COLOR_HEX))
        self._build_ui()
        if step_path:
            self.load_step(step_path)

    def _build_ui(self) -> None:
        """Create the 3D viewer and the side panel (file actions, navigation, labels)."""
        central = QtWidgets.QWidget(self)
        root_layout = QtWidgets.QHBoxLayout(central)
        root_layout.setContentsMargins(8, 8, 8, 8)
        root_layout.setSpacing(8)

        # 3D view on the left; it takes all extra horizontal space (stretch 1).
        self.viewer = qtViewer3d(central)
        self.viewer.InitDriver()
        self.display = self.viewer._display
        try:
            # Turn off the context's automatic highlighting of detected objects
            # so only the explicit highlight color marks the current face.
            # Best effort: the call may be unavailable in some pythonocc builds.
            self.display.Context.SetAutomaticHilight(False)
        except Exception:
            pass
        root_layout.addWidget(self.viewer, 1)

        # Side panel on the right (stretch 0 keeps it at its preferred width).
        panel = QtWidgets.QWidget(central)
        panel_layout = QtWidgets.QVBoxLayout(panel)
        panel_layout.setContentsMargins(0, 0, 0, 0)
        panel_layout.setSpacing(6)
        root_layout.addWidget(panel, 0)

        self.btn_import_step = QtWidgets.QPushButton("Import STEP")
        self.btn_import_step.clicked.connect(self.on_import_step)
        panel_layout.addWidget(self.btn_import_step)
        self.btn_export_seg = QtWidgets.QPushButton("Export .seg")
        self.btn_export_seg.clicked.connect(self.on_export_seg)
        panel_layout.addWidget(self.btn_export_seg)
        self.btn_review = QtWidgets.QPushButton("Review")
        self.btn_review.clicked.connect(self.on_review)
        panel_layout.addWidget(self.btn_review)
        panel_layout.addSpacing(8)

        nav_layout = QtWidgets.QHBoxLayout()
        self.btn_prev = QtWidgets.QPushButton("<< Prev")
        self.btn_prev.clicked.connect(self.on_prev)
        self.btn_next = QtWidgets.QPushButton("Next >>")
        self.btn_next.clicked.connect(self.on_next)
        nav_layout.addWidget(self.btn_prev)
        nav_layout.addWidget(self.btn_next)
        panel_layout.addLayout(nav_layout)

        self.info_label = QtWidgets.QLabel("No STEP loaded")
        self.info_label.setWordWrap(True)
        panel_layout.addWidget(self.info_label)
        panel_layout.addSpacing(8)

        legend_label = QtWidgets.QLabel("Assign Label")
        legend_label.setStyleSheet("font-weight: bold;")
        panel_layout.addWidget(legend_label)
        # One button per class, colored like the faces it labels.
        grid = QtWidgets.QGridLayout()
        grid.setSpacing(6)
        for idx, name in enumerate(self.class_names):
            btn = QtWidgets.QPushButton(f"{idx}: {name}")
            bg = CLASS_COLORS_HEX[idx]
            fg = text_color_for_bg(bg)
            btn.setStyleSheet(f"background-color: {bg}; color: {fg};")
            # Bind idx as a default argument so each button keeps its own index
            # (a plain closure would see the loop's final value); `checked`
            # absorbs the bool that clicked may pass.
            btn.clicked.connect(lambda checked=False, i=idx: self.assign_label(i))
            grid.addWidget(btn, idx, 0)
        panel_layout.addLayout(grid)
        panel_layout.addStretch(1)
        self.setCentralWidget(central)

    def keyPressEvent(self, event) -> None:  # pragma: no cover - UI only
        """Navigate faces: Right/D goes to the next face, Left/A to the previous."""
        if event.key() in (QtCore.Qt.Key_Right, QtCore.Qt.Key_D):
            self.on_next()
            return
        if event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_A):
            self.on_prev()
            return
        super().keyPressEvent(event)

    def on_import_step(self) -> None:
        """Ask for a STEP file and load it (no-op if the dialog is cancelled)."""
        path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open STEP", "", "STEP Files (*.stp *.step)"
        )
        if path:
            self.load_step(path)

    def on_export_seg(self) -> None:
        """Export labels to a .seg file; refuses while any face is unlabeled."""
        if not self.labels:
            QtWidgets.QMessageBox.warning(self, "Export", "No STEP loaded.")
            return
        if any(label is None for label in self.labels):
            QtWidgets.QMessageBox.warning(
                self,
                "Export",
                "Unlabeled faces remain. Label all faces before exporting.",
            )
            return
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Export .seg", "", "SEG Files (*.seg)"
        )
        if path:
            self.save_seg(path)

    def on_review(self) -> None:
        """Show per-class label counts and offer to jump to the first unlabeled face.

        If every face is labeled, highlighting is switched off so the current
        face also shows its class color, and the summary is shown for
        information only.
        """
        if not self.labels:
            QtWidgets.QMessageBox.information(self, "Review", "No STEP loaded.")
            return
        counts = [0 for _ in self.class_names]
        unlabeled = []
        for idx, label in enumerate(self.labels):
            if label is None:
                unlabeled.append(idx)
            else:
                counts[label] += 1
        lines = [
            f"Total faces: {len(self.labels)}",
            f"Unlabeled: {len(unlabeled)}",
            "",
        ]
        for idx, name in enumerate(self.class_names):
            lines.append(f"{idx} {name}: {counts[idx]}")
        if unlabeled:
            # Report 1-based face numbers so they match the "Face N/M" counter
            # shown in the side panel (see update_info).
            preview = ", ".join(str(i + 1) for i in unlabeled[:20])
            if len(unlabeled) > 20:
                preview += ", ..."
            lines.append("")
            lines.append(f"Unlabeled faces: {preview}")
            lines.append("")
            lines.append("Jump to first unlabeled?")
            res = QtWidgets.QMessageBox.question(
                self,
                "Review",
                "\n".join(lines),
                QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
            )
            if res == QtWidgets.QMessageBox.Yes:
                self.set_current_index(unlabeled[0])
        else:
            # Everything is labeled: drop the highlight so the current face
            # is shown in its final class color as well.
            self.highlight_enabled = False
            if self.current_index is not None:
                self.set_face_color(
                    self.current_index, self.get_base_color(self.current_index)
                )
            QtWidgets.QMessageBox.information(self, "Review", "\n".join(lines))

    def on_prev(self) -> None:
        """Select the previous face, if there is one."""
        if self.current_index is None:
            return
        if self.current_index <= 0:
            return
        self.set_current_index(self.current_index - 1)

    def on_next(self) -> None:
        """Select the next face, if there is one."""
        if self.current_index is None:
            return
        if self.current_index >= len(self.face_items) - 1:
            return
        self.set_current_index(self.current_index + 1)

    def load_step(self, path: str) -> None:
        """Load *path* and display each of its faces as a separate unlabeled object.

        If the file cannot be read, a warning is shown and the current model is
        kept. Otherwise all previous faces and labels are discarded and, if the
        model has faces, the first one becomes current.
        """
        reader = STEPControl_Reader()
        status = reader.ReadFile(path)
        if status != IFSelect_RetDone:
            QtWidgets.QMessageBox.warning(
                self, "Load STEP", f"Failed to read STEP file: {path}"
            )
            return
        reader.TransferRoots()
        shape = reader.OneShape()

        self.display.EraseAll()
        self.face_items.clear()
        self.labels.clear()
        self.highlight_enabled = True
        # Reset before set_current_index so it does not try to recolor a face
        # that belonged to the previous model.
        self.current_index = None

        explorer = TopExp_Explorer(shape, TopAbs_FACE)
        while explorer.More():
            face = topods.Face(explorer.Current())
            ais = self.display.DisplayShape(face, update=False, color=self.unlabeled_color)
            # DisplayShape may return a list of AIS objects; keep the first.
            if isinstance(ais, list):
                ais = ais[0]
            try:
                # Display mode 1 is shaded; best effort, failures are ignored.
                self.display.Context.SetDisplayMode(ais, 1, False)
            except Exception:
                pass
            self.face_items.append(FaceItem(face=face, ais=ais))
            self.labels.append(None)
            explorer.Next()

        if not self.face_items:
            # Refresh the side panel; otherwise it keeps describing the
            # previously loaded model.
            self.update_info()
            QtWidgets.QMessageBox.warning(
                self, "Load STEP", "No faces found in STEP file."
            )
            self.display.Repaint()
            return
        self.display.FitAll()
        self.set_current_index(0)

    def save_seg(self, path: str) -> None:
        """Write one label per line, in face order, to *path* and report the result.

        Write errors are shown in a dialog instead of being raised. This method
        runs from a button slot, and PyQt5 (>= 5.5) terminates the application
        on an exception left unhandled in a slot.
        """
        try:
            with open(path, "w", encoding="utf-8") as handle:
                for label in self.labels:
                    handle.write(f"{label}\n")
        except OSError as err:
            QtWidgets.QMessageBox.warning(
                self, "Export", f"Failed to save {path}: {err}"
            )
            return
        QtWidgets.QMessageBox.information(self, "Export", f"Saved: {path}")

    def assign_label(self, label_index: int) -> None:
        """Give the current face class *label_index* and refresh its display."""
        if self.current_index is None:
            return
        self.labels[self.current_index] = label_index
        self.apply_current_highlight(self.current_index)
        self.update_info()

    def update_info(self) -> None:
        """Show the current face number (1-based) and its label in the side panel."""
        if self.current_index is None:
            self.info_label.setText("No STEP loaded")
            return
        label = self.labels[self.current_index]
        label_text = "Unlabeled" if label is None else f"{label}: {self.class_names[label]}"
        self.info_label.setText(
            f"Face {self.current_index + 1}/{len(self.face_items)}\n"
            f"Label: {label_text}"
        )

    def get_base_color(self, index: int) -> Quantity_Color:
        """Return the color of face *index*: its class color, or the unlabeled gray."""
        label = self.labels[index]
        return self.unlabeled_color if label is None else self.class_colors[label]

    def get_highlight_color(self, index: int) -> Quantity_Color:
        """Return the highlight color for face *index* (currently the same for all faces)."""
        return self.highlight_color

    def set_current_index(self, index: int) -> None:
        """Make face *index* (clamped to the valid range) the current face.

        The previously current face is restored to its base color first.
        """
        if not self.face_items:
            return
        index = max(0, min(index, len(self.face_items) - 1))
        if self.current_index is not None:
            self.set_face_color(self.current_index, self.get_base_color(self.current_index))
        self.current_index = index
        self.apply_current_highlight(self.current_index)
        self.update_info()

    def apply_current_highlight(self, index: int) -> None:
        """Color face *index* as the current face (highlight or base color, per mode)."""
        if self.highlight_enabled:
            self.set_face_color(index, self.get_highlight_color(index))
        else:
            self.set_face_color(index, self.get_base_color(index))

    def set_face_color(self, index: int, color: Quantity_Color) -> None:
        """Recolor face *index* (one AIS object or a list of them) and repaint."""
        ais = self.face_items[index].ais
        if isinstance(ais, list):
            for item in ais:
                self._set_ais_color(item, color)
        else:
            self._set_ais_color(ais, color)
        self.display.Repaint()

    def _set_ais_color(self, ais, color: Quantity_Color) -> None:
        """Set *color* on one AIS object and redisplay it without updating the view."""
        try:
            ais.SetColor(color)
        except Exception:
            # Fall back to setting the color through the interactive context.
            self.display.Context.SetColor(ais, color, False)
        self.display.Context.Redisplay(ais, False)
def main() -> int:
    """Launch the labeler GUI and return the Qt event loop's exit code.

    An optional first command-line argument names a STEP file to open at
    start-up. If it is not an existing file, it is ignored and a note is
    printed to stderr instead of silently dropping it.
    """
    app = QtWidgets.QApplication(sys.argv)
    step_path: Optional[str] = None
    if len(sys.argv) > 1:
        candidate = sys.argv[1]
        if os.path.isfile(candidate):
            step_path = candidate
        else:
            print(
                f"Ignoring STEP path (not an existing file): {candidate}",
                file=sys.stderr,
            )
    window = FaceLabeler(step_path=step_path)
    window.show()
    return app.exec_()
if __name__ == "__main__":
    # Use the Qt event loop's return value as the process exit status.
    sys.exit(main())