Spaces:
Running
Running
File size: 20,967 Bytes
19f31ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 |
from .schemas import GlobalStore, ObjectState, SelectorInput, ProjectState
from .inference import search_objects, refine_object
from .dataset_manager import DatasetManager
from .view_helpers import draw_candidates
from PIL import Image
import numpy as np
import os
import shutil
import uuid
import cv2
class AppController:
def __init__(self):
self.store = GlobalStore()
self.current_image = None # PIL Image
self.current_image_path = None # Path to current image
# Playlist state
self.project = ProjectState()
self.global_class_map = {} # Map class_name -> int ID
self.active_project_path = None # Path to the current project JSON file
def load_playlist(self, file_paths: list[str]):
"""Load a list of image paths."""
# Filter for images
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
playlist = sorted([p for p in file_paths if os.path.splitext(p)[1].lower() in valid_exts])
self.project = ProjectState(playlist=playlist)
self.current_image = None
self.current_image_path = None
self.store = GlobalStore()
if self.project.playlist:
return self.load_image_at_index(0)
return None
def load_image_at_index(self, index: int):
if not self.project.playlist or index < 0 or index >= len(self.project.playlist):
return None
# Save current state if we have an image loaded
if self.current_image_path:
self.project.annotations[self.current_image_path] = self.store
self.project.current_index = index
path = self.project.playlist[index]
try:
image = Image.open(path).convert("RGB")
self.current_image = image
self.current_image_path = path
# Restore store if exists, else new
if path in self.project.annotations:
self.store = self.project.annotations[path]
else:
self.store = GlobalStore(image_path=path)
return image
except Exception as e:
print(f"Error loading image {path}: {e}")
return None
def next_image(self):
return self.load_image_at_index(self.project.current_index + 1)
def prev_image(self):
return self.load_image_at_index(self.project.current_index - 1)
def set_image(self, image: Image.Image):
# Legacy support: treat as single image playlist without path
# This might break if we rely on paths for export.
# Ideally we force file upload.
# For now, let's just set it and reset store, but warn it won't work well with playlist export
self.current_image = image
self.current_image_path = None
self.store = GlobalStore()
self.project = ProjectState()
def reset_project(self):
"""Reset the project state completely."""
self.store = GlobalStore()
self.current_image = None
self.current_image_path = None
self.project = ProjectState()
self.global_class_map = {}
self.active_project_path = None
def auto_save(self):
"""Auto-save the project if an active path is set."""
if self.active_project_path:
print(f"💾 Auto-saving to {self.active_project_path}...")
return self.save_project(self.active_project_path)
return False, "No active project to save."
def update_history(self, prompt: str, class_name: str):
if prompt and prompt not in self.project.prompt_history:
self.project.prompt_history.append(prompt)
if class_name and class_name not in self.project.class_name_history:
self.project.class_name_history.append(class_name)
def update_history(self, prompt: str, class_name: str):
if prompt and prompt not in self.project.prompt_history:
self.project.prompt_history.append(prompt)
if class_name and class_name not in self.project.class_name_history:
self.project.class_name_history.append(class_name)
def search_and_add(self, class_name: str, search_boxes: list[list[int]] = [], search_labels: list[int] = [], class_name_override: str = None, crop_box: list[int] = None):
self.update_history(class_name, class_name_override)
if self.current_image is None: return []
# Create SelectorInput
selector_input = SelectorInput(
image=self.current_image,
text=class_name,
class_name_override=class_name_override,
input_boxes=search_boxes,
input_labels=search_labels,
crop_box=crop_box
)
candidates = search_objects(selector_input)
# We return candidates, but don't add to store yet (UI will decide)
return candidates
def add_candidates_to_store(self, candidates: list[ObjectState], selected_indices: list[int]):
added_ids = []
for idx in selected_indices:
if 0 <= idx < len(candidates):
obj_state = candidates[idx]
self.store.objects[obj_state.object_id] = obj_state
added_ids.append(obj_state.object_id)
return added_ids
def get_candidate_preview(self, candidates: list[ObjectState], selected_index: int | set | list = None):
"""Generate preview image with candidates drawn."""
if self.current_image is None or not candidates:
return self.current_image
return draw_candidates(self.current_image, candidates, selected_index)
def get_candidates_dataframe(self, candidates: list[ObjectState]):
"""Get dataframe for UI list."""
data = []
for i, obj in enumerate(candidates):
# Add ID column (i+1) to match the image labels
data.append([
i + 1, # ID
obj.class_name, # Class
f"{obj.score:.2f}" # Score
])
return data
def refine_object(self, obj_id: str, point: list[int], label: int):
if obj_id not in self.store.objects: return None
if self.current_image is None: return None
obj = self.store.objects[obj_id]
# Update history
obj.input_points.append(point)
obj.input_labels.append(label)
print(f"Refining {obj_id}: Points={obj.input_points}, Labels={obj.input_labels}")
# Run Refiner
new_mask = refine_object(self.current_image, obj)
# Update Mask
obj.binary_mask = new_mask
return new_mask
def undo_last_point(self, obj_id: str):
if obj_id not in self.store.objects: return None
obj = self.store.objects[obj_id]
if not obj.input_points:
return obj.binary_mask # Nothing to undo
# Remove last
obj.input_points.pop()
obj.input_labels.pop()
# If no points left, revert to initial
if not obj.input_points:
obj.binary_mask = obj.initial_mask
return obj.binary_mask
# Otherwise re-run refinement
print(f"Refining (Undo) {obj_id}: Points={obj.input_points}, Labels={obj.input_labels}")
new_mask = refine_object(self.current_image, obj)
obj.binary_mask = new_mask
return new_mask
def remove_object(self, obj_id: str):
if obj_id in self.store.objects:
del self.store.objects[obj_id]
return True
return False
def revert_object(self, obj_id: str):
"""Revert object to its initial state (before refinement)."""
if obj_id not in self.store.objects: return None
obj = self.store.objects[obj_id]
# Reset to initial mask
obj.binary_mask = obj.initial_mask
# Clear points
obj.input_points = []
obj.input_labels = []
return obj.binary_mask
def export_data(self, output_dir: str, purge: bool = False, zip_output: bool = False):
"""Export all images and annotations in playlist to YOLO format."""
# Ensure current state is saved
if self.current_image_path:
self.project.annotations[self.current_image_path] = self.store
if not self.project.annotations:
return None, "No annotations to export."
# Structure:
# output_dir/
# data.yaml
# images/
# train/
# labels/
# train/
images_dir = os.path.join(output_dir, "images", "train")
labels_dir = os.path.join(output_dir, "labels", "train")
if purge:
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(images_dir, exist_ok=True)
os.makedirs(labels_dir, exist_ok=True)
# Collect all unique class names to build map
all_class_names = set()
for store in self.project.annotations.values():
for obj in store.objects.values():
all_class_names.add(obj.class_name)
# Update global map (append new ones)
sorted_classes = sorted(list(all_class_names))
class_list = sorted_classes
class_map = {name: i for i, name in enumerate(class_list)}
exported_count = 0
for path, store in self.project.annotations.items():
if not store.objects:
continue
# Copy image
filename = os.path.basename(path)
dest_img_path = os.path.join(images_dir, filename)
shutil.copy2(path, dest_img_path)
# Generate Label File
label_filename = os.path.splitext(filename)[0] + ".txt"
dest_label_path = os.path.join(labels_dir, label_filename)
# We need image size for normalization.
try:
with Image.open(path) as img:
w, h = img.size
except:
print(f"Could not read image size for {path}")
continue
lines = []
for obj in store.objects.values():
cid = class_map.get(obj.class_name, 0)
mask = obj.binary_mask.astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for cnt in contours:
points = cnt.flatten()
if len(points) < 6: continue # Need at least 3 points
norm_points = []
for i in range(0, len(points), 2):
nx = points[i] / w
ny = points[i+1] / h
# Clip to 0-1
nx = max(0, min(1, nx))
ny = max(0, min(1, ny))
norm_points.extend([f"{nx:.6f}", f"{ny:.6f}"])
line = f"{cid} " + " ".join(norm_points)
lines.append(line)
with open(dest_label_path, "w") as f:
f.write("\n".join(lines))
exported_count += 1
# Create data.yaml
yaml_content = f"""names:
{chr(10).join([f" {i}: {name}" for i, name in enumerate(class_list)])}
path: .
train: images/train
"""
with open(os.path.join(output_dir, "data.yaml"), "w") as f:
f.write(yaml_content)
msg = f"Exported {exported_count} images to {output_dir}"
if zip_output:
# Determine zip name based on project name if available
zip_name = "dataset"
if self.active_project_path:
# Extract project name from path (e.g., "saved_projects/my_project.json" -> "my_project")
zip_name = os.path.splitext(os.path.basename(self.active_project_path))[0]
# Create a temp folder for staging the zip
parent_dir = os.path.dirname(os.path.abspath(output_dir))
temp_dir = os.path.join(parent_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
base_name = os.path.join(temp_dir, zip_name)
# Create zip in temp folder
zip_file = shutil.make_archive(base_name, 'zip', output_dir)
# Clear output_dir
for item in os.listdir(output_dir):
item_path = os.path.join(output_dir, item)
if os.path.isfile(item_path) or os.path.islink(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# Move zip to output_dir
final_name = f"{zip_name}.zip"
final_path = os.path.join(output_dir, final_name)
shutil.move(zip_file, final_path)
# Remove temp folder
shutil.rmtree(temp_dir)
msg += f" and zipped to {final_name} (original files deleted)"
return None, msg
def save_project(self, file_path: str):
"""Save project state to JSON and bundle images."""
import json
import os
import shutil
from .utils import mask_to_polygons
# Ensure current state is saved
if self.current_image_path:
self.project.annotations[self.current_image_path] = self.store
# Create assets directory
base_dir = os.path.dirname(file_path)
project_name = os.path.splitext(os.path.basename(file_path))[0]
assets_dir_name = f"{project_name}_assets"
assets_dir = os.path.join(base_dir, assets_dir_name)
os.makedirs(assets_dir, exist_ok=True)
# Map original paths to relative paths
path_map = {} # original -> relative
new_playlist = []
# Process playlist
for original_path in self.project.playlist:
filename = os.path.basename(original_path)
# Handle duplicate filenames by prepending index if needed?
# For now assume unique filenames or just overwrite (simple)
# Better: check collision
dest_path = os.path.join(assets_dir, filename)
# Copy file if it doesn't exist or if we want to ensure it's there
try:
if not os.path.exists(dest_path) or os.path.abspath(original_path) != os.path.abspath(dest_path):
shutil.copy2(original_path, dest_path)
except Exception as e:
print(f"Warning: Failed to copy {original_path} to {dest_path}: {e}")
# Store relative path
relative_path = os.path.join(assets_dir_name, filename)
path_map[original_path] = relative_path
new_playlist.append(relative_path)
data = {
"playlist": new_playlist,
"current_index": self.project.current_index,
"prompt_history": self.project.prompt_history,
"class_name_history": self.project.class_name_history,
"annotations": {}
}
for path, store in self.project.annotations.items():
# Get the new relative path key
new_key = path_map.get(path)
if not new_key:
# If annotation exists for a file not in playlist (shouldn't happen but safe fallback)
filename = os.path.basename(path)
new_key = os.path.join(assets_dir_name, filename)
objects_data = {}
for obj_id, obj in store.objects.items():
objects_data[obj_id] = {
"object_id": obj.object_id,
"score": obj.score,
"class_name": obj.class_name,
"anchor_box": obj.anchor_box,
"input_points": obj.input_points,
"input_labels": obj.input_labels,
"polygons": mask_to_polygons(obj.binary_mask)
}
data["annotations"][new_key] = objects_data
try:
with open(file_path, 'w') as f:
json.dump(data, f, indent=2)
# Update active project path
self.active_project_path = file_path
return True, f"Project saved to {file_path} (Images bundled in {assets_dir_name})"
except Exception as e:
return False, f"Failed to save project: {e}"
def load_project(self, file_path: str):
"""Load project state from JSON."""
import json
import os
from .utils import polygons_to_mask
try:
with open(file_path, 'r') as f:
data = json.load(f)
except Exception as e:
return False, f"Failed to load file: {e}"
base_dir = os.path.dirname(file_path)
# Reconstruct absolute paths for playlist
loaded_playlist = []
for rel_path in data.get("playlist", []):
abs_path = os.path.abspath(os.path.join(base_dir, rel_path))
loaded_playlist.append(abs_path)
# Restore Project State
self.project = ProjectState(
playlist=loaded_playlist,
current_index=data.get("current_index", -1),
prompt_history=data.get("prompt_history", []),
class_name_history=data.get("class_name_history", [])
)
# Restore Annotations
missing_files = []
for rel_path, objects_data in data.get("annotations", {}).items():
abs_path = os.path.abspath(os.path.join(base_dir, rel_path))
store = GlobalStore(image_path=abs_path)
# Need image size to restore masks
try:
with Image.open(abs_path) as img:
w, h = img.size
except:
print(f"Warning: Could not read image {abs_path} during load. Skipping masks.")
missing_files.append(abs_path)
continue
for obj_id, obj_data in objects_data.items():
# Reconstruct mask
polygons = obj_data.get("polygons", [])
mask = polygons_to_mask(polygons, w, h)
obj = ObjectState(
object_id=obj_data["object_id"],
score=obj_data["score"],
class_name=obj_data["class_name"],
anchor_box=obj_data["anchor_box"],
binary_mask=mask,
initial_mask=mask.copy(), # Assume loaded state is initial
input_points=obj_data.get("input_points", []),
input_labels=obj_data.get("input_labels", [])
)
store.objects[obj_id] = obj
self.project.annotations[abs_path] = store
# Load current image
if self.project.current_index >= 0:
self.load_image_at_index(self.project.current_index)
msg = f"Project loaded from {file_path}"
if missing_files:
msg += f". Warning: {len(missing_files)} images not found (annotations skipped)."
# Update active project path
self.active_project_path = file_path
return True, msg
def get_all_masks(self):
return [(obj.binary_mask, f"{obj.class_name}") for obj in self.store.objects.values()]
def get_object_mask(self, obj_id):
if obj_id in self.store.objects:
return self.store.objects[obj_id].binary_mask
return None
def clean_and_export_dataset(self, dataset_path, tolerance_ratio=0.000805, min_area_ratio=0.000219):
"""Clean, validate, and zip a YOLO dataset."""
manager = DatasetManager(dataset_path)
# 1. Remove Zone.Identifier files
manager.remove_zone_identifiers()
# 2. Clean dataset (in-place)
print(f"Cleaning dataset at {dataset_path}...")
stats = manager.cleanup_dataset(tolerance_ratio, min_area_ratio)
# 3. Finalize (Validation folders + Zip)
print("Finalizing dataset...")
zip_path = manager.finalize_dataset(create_zip=True)
return stats, zip_path
# Global Controller
# Single shared controller instance, created once when this module is imported.
controller: AppController = AppController()
|