Spaces:
Running
Running
| from .schemas import GlobalStore, ObjectState, SelectorInput, ProjectState | |
| from .inference import search_objects, refine_object | |
| from .dataset_manager import DatasetManager | |
| from .view_helpers import draw_candidates | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| import shutil | |
| import uuid | |
| import cv2 | |
class AppController:
    """Central application controller.

    Ties together playlist navigation, object search/refinement,
    project persistence (save/load JSON), and YOLO dataset export.
    A single module-level instance is shared by the UI layer.
    """
    def __init__(self):
        # Annotation store for the image currently on screen.
        self.store = GlobalStore()
        self.current_image = None  # PIL Image currently displayed (None until loaded)
        self.current_image_path = None  # Path to current image (None for in-memory images)
        # Playlist state: ordered image paths plus per-image annotation stores.
        self.project = ProjectState()
        self.global_class_map = {}  # Map class_name -> int ID
        self.active_project_path = None  # Path to the current project JSON file (auto-save target)
| def load_playlist(self, file_paths: list[str]): | |
| """Load a list of image paths.""" | |
| # Filter for images | |
| valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'} | |
| playlist = sorted([p for p in file_paths if os.path.splitext(p)[1].lower() in valid_exts]) | |
| self.project = ProjectState(playlist=playlist) | |
| self.current_image = None | |
| self.current_image_path = None | |
| self.store = GlobalStore() | |
| if self.project.playlist: | |
| return self.load_image_at_index(0) | |
| return None | |
| def load_image_at_index(self, index: int): | |
| if not self.project.playlist or index < 0 or index >= len(self.project.playlist): | |
| return None | |
| # Save current state if we have an image loaded | |
| if self.current_image_path: | |
| self.project.annotations[self.current_image_path] = self.store | |
| self.project.current_index = index | |
| path = self.project.playlist[index] | |
| try: | |
| image = Image.open(path).convert("RGB") | |
| self.current_image = image | |
| self.current_image_path = path | |
| # Restore store if exists, else new | |
| if path in self.project.annotations: | |
| self.store = self.project.annotations[path] | |
| else: | |
| self.store = GlobalStore(image_path=path) | |
| return image | |
| except Exception as e: | |
| print(f"Error loading image {path}: {e}") | |
| return None | |
| def next_image(self): | |
| return self.load_image_at_index(self.project.current_index + 1) | |
| def prev_image(self): | |
| return self.load_image_at_index(self.project.current_index - 1) | |
| def set_image(self, image: Image.Image): | |
| # Legacy support: treat as single image playlist without path | |
| # This might break if we rely on paths for export. | |
| # Ideally we force file upload. | |
| # For now, let's just set it and reset store, but warn it won't work well with playlist export | |
| self.current_image = image | |
| self.current_image_path = None | |
| self.store = GlobalStore() | |
| self.project = ProjectState() | |
| def reset_project(self): | |
| """Reset the project state completely.""" | |
| self.store = GlobalStore() | |
| self.current_image = None | |
| self.current_image_path = None | |
| self.project = ProjectState() | |
| self.global_class_map = {} | |
| self.active_project_path = None | |
| def auto_save(self): | |
| """Auto-save the project if an active path is set.""" | |
| if self.active_project_path: | |
| print(f"💾 Auto-saving to {self.active_project_path}...") | |
| return self.save_project(self.active_project_path) | |
| return False, "No active project to save." | |
| def update_history(self, prompt: str, class_name: str): | |
| if prompt and prompt not in self.project.prompt_history: | |
| self.project.prompt_history.append(prompt) | |
| if class_name and class_name not in self.project.class_name_history: | |
| self.project.class_name_history.append(class_name) | |
| def update_history(self, prompt: str, class_name: str): | |
| if prompt and prompt not in self.project.prompt_history: | |
| self.project.prompt_history.append(prompt) | |
| if class_name and class_name not in self.project.class_name_history: | |
| self.project.class_name_history.append(class_name) | |
| def search_and_add(self, class_name: str, search_boxes: list[list[int]] = [], search_labels: list[int] = [], class_name_override: str = None, crop_box: list[int] = None): | |
| self.update_history(class_name, class_name_override) | |
| if self.current_image is None: return [] | |
| # Create SelectorInput | |
| selector_input = SelectorInput( | |
| image=self.current_image, | |
| text=class_name, | |
| class_name_override=class_name_override, | |
| input_boxes=search_boxes, | |
| input_labels=search_labels, | |
| crop_box=crop_box | |
| ) | |
| candidates = search_objects(selector_input) | |
| # We return candidates, but don't add to store yet (UI will decide) | |
| return candidates | |
| def add_candidates_to_store(self, candidates: list[ObjectState], selected_indices: list[int]): | |
| added_ids = [] | |
| for idx in selected_indices: | |
| if 0 <= idx < len(candidates): | |
| obj_state = candidates[idx] | |
| self.store.objects[obj_state.object_id] = obj_state | |
| added_ids.append(obj_state.object_id) | |
| return added_ids | |
| def get_candidate_preview(self, candidates: list[ObjectState], selected_index: int | set | list = None): | |
| """Generate preview image with candidates drawn.""" | |
| if self.current_image is None or not candidates: | |
| return self.current_image | |
| return draw_candidates(self.current_image, candidates, selected_index) | |
| def get_candidates_dataframe(self, candidates: list[ObjectState]): | |
| """Get dataframe for UI list.""" | |
| data = [] | |
| for i, obj in enumerate(candidates): | |
| # Add ID column (i+1) to match the image labels | |
| data.append([ | |
| i + 1, # ID | |
| obj.class_name, # Class | |
| f"{obj.score:.2f}" # Score | |
| ]) | |
| return data | |
| def refine_object(self, obj_id: str, point: list[int], label: int): | |
| if obj_id not in self.store.objects: return None | |
| if self.current_image is None: return None | |
| obj = self.store.objects[obj_id] | |
| # Update history | |
| obj.input_points.append(point) | |
| obj.input_labels.append(label) | |
| print(f"Refining {obj_id}: Points={obj.input_points}, Labels={obj.input_labels}") | |
| # Run Refiner | |
| new_mask = refine_object(self.current_image, obj) | |
| # Update Mask | |
| obj.binary_mask = new_mask | |
| return new_mask | |
| def undo_last_point(self, obj_id: str): | |
| if obj_id not in self.store.objects: return None | |
| obj = self.store.objects[obj_id] | |
| if not obj.input_points: | |
| return obj.binary_mask # Nothing to undo | |
| # Remove last | |
| obj.input_points.pop() | |
| obj.input_labels.pop() | |
| # If no points left, revert to initial | |
| if not obj.input_points: | |
| obj.binary_mask = obj.initial_mask | |
| return obj.binary_mask | |
| # Otherwise re-run refinement | |
| print(f"Refining (Undo) {obj_id}: Points={obj.input_points}, Labels={obj.input_labels}") | |
| new_mask = refine_object(self.current_image, obj) | |
| obj.binary_mask = new_mask | |
| return new_mask | |
| def remove_object(self, obj_id: str): | |
| if obj_id in self.store.objects: | |
| del self.store.objects[obj_id] | |
| return True | |
| return False | |
| def revert_object(self, obj_id: str): | |
| """Revert object to its initial state (before refinement).""" | |
| if obj_id not in self.store.objects: return None | |
| obj = self.store.objects[obj_id] | |
| # Reset to initial mask | |
| obj.binary_mask = obj.initial_mask | |
| # Clear points | |
| obj.input_points = [] | |
| obj.input_labels = [] | |
| return obj.binary_mask | |
| def export_data(self, output_dir: str, purge: bool = False, zip_output: bool = False): | |
| """Export all images and annotations in playlist to YOLO format.""" | |
| # Ensure current state is saved | |
| if self.current_image_path: | |
| self.project.annotations[self.current_image_path] = self.store | |
| if not self.project.annotations: | |
| return None, "No annotations to export." | |
| # Structure: | |
| # output_dir/ | |
| # data.yaml | |
| # images/ | |
| # train/ | |
| # labels/ | |
| # train/ | |
| images_dir = os.path.join(output_dir, "images", "train") | |
| labels_dir = os.path.join(output_dir, "labels", "train") | |
| if purge: | |
| if os.path.exists(output_dir): | |
| shutil.rmtree(output_dir) | |
| os.makedirs(images_dir, exist_ok=True) | |
| os.makedirs(labels_dir, exist_ok=True) | |
| # Collect all unique class names to build map | |
| all_class_names = set() | |
| for store in self.project.annotations.values(): | |
| for obj in store.objects.values(): | |
| all_class_names.add(obj.class_name) | |
| # Update global map (append new ones) | |
| sorted_classes = sorted(list(all_class_names)) | |
| class_list = sorted_classes | |
| class_map = {name: i for i, name in enumerate(class_list)} | |
| exported_count = 0 | |
| for path, store in self.project.annotations.items(): | |
| if not store.objects: | |
| continue | |
| # Copy image | |
| filename = os.path.basename(path) | |
| dest_img_path = os.path.join(images_dir, filename) | |
| shutil.copy2(path, dest_img_path) | |
| # Generate Label File | |
| label_filename = os.path.splitext(filename)[0] + ".txt" | |
| dest_label_path = os.path.join(labels_dir, label_filename) | |
| # We need image size for normalization. | |
| try: | |
| with Image.open(path) as img: | |
| w, h = img.size | |
| except: | |
| print(f"Could not read image size for {path}") | |
| continue | |
| lines = [] | |
| for obj in store.objects.values(): | |
| cid = class_map.get(obj.class_name, 0) | |
| mask = obj.binary_mask.astype(np.uint8) | |
| contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| for cnt in contours: | |
| points = cnt.flatten() | |
| if len(points) < 6: continue # Need at least 3 points | |
| norm_points = [] | |
| for i in range(0, len(points), 2): | |
| nx = points[i] / w | |
| ny = points[i+1] / h | |
| # Clip to 0-1 | |
| nx = max(0, min(1, nx)) | |
| ny = max(0, min(1, ny)) | |
| norm_points.extend([f"{nx:.6f}", f"{ny:.6f}"]) | |
| line = f"{cid} " + " ".join(norm_points) | |
| lines.append(line) | |
| with open(dest_label_path, "w") as f: | |
| f.write("\n".join(lines)) | |
| exported_count += 1 | |
| # Create data.yaml | |
| yaml_content = f"""names: | |
| {chr(10).join([f" {i}: {name}" for i, name in enumerate(class_list)])} | |
| path: . | |
| train: images/train | |
| """ | |
| with open(os.path.join(output_dir, "data.yaml"), "w") as f: | |
| f.write(yaml_content) | |
| msg = f"Exported {exported_count} images to {output_dir}" | |
| if zip_output: | |
| # Determine zip name based on project name if available | |
| zip_name = "dataset" | |
| if self.active_project_path: | |
| # Extract project name from path (e.g., "saved_projects/my_project.json" -> "my_project") | |
| zip_name = os.path.splitext(os.path.basename(self.active_project_path))[0] | |
| # Create a temp folder for staging the zip | |
| parent_dir = os.path.dirname(os.path.abspath(output_dir)) | |
| temp_dir = os.path.join(parent_dir, "temp") | |
| os.makedirs(temp_dir, exist_ok=True) | |
| base_name = os.path.join(temp_dir, zip_name) | |
| # Create zip in temp folder | |
| zip_file = shutil.make_archive(base_name, 'zip', output_dir) | |
| # Clear output_dir | |
| for item in os.listdir(output_dir): | |
| item_path = os.path.join(output_dir, item) | |
| if os.path.isfile(item_path) or os.path.islink(item_path): | |
| os.unlink(item_path) | |
| elif os.path.isdir(item_path): | |
| shutil.rmtree(item_path) | |
| # Move zip to output_dir | |
| final_name = f"{zip_name}.zip" | |
| final_path = os.path.join(output_dir, final_name) | |
| shutil.move(zip_file, final_path) | |
| # Remove temp folder | |
| shutil.rmtree(temp_dir) | |
| msg += f" and zipped to {final_name} (original files deleted)" | |
| return None, msg | |
    def save_project(self, file_path: str):
        """Save project state to JSON and bundle images.

        Copies every playlist image into a sibling "<project>_assets"
        directory and writes a JSON file whose paths are relative to the
        JSON's location, so the saved folder is relocatable. Masks are
        serialised as polygons (via mask_to_polygons) to keep the JSON small.

        Returns:
            (success: bool, message: str)
        """
        import json
        import os
        import shutil
        from .utils import mask_to_polygons
        # Ensure the on-screen image's annotations are captured first.
        if self.current_image_path:
            self.project.annotations[self.current_image_path] = self.store
        # Create assets directory next to the JSON file.
        base_dir = os.path.dirname(file_path)
        project_name = os.path.splitext(os.path.basename(file_path))[0]
        assets_dir_name = f"{project_name}_assets"
        assets_dir = os.path.join(base_dir, assets_dir_name)
        os.makedirs(assets_dir, exist_ok=True)
        # Map original absolute paths to paths relative to the JSON.
        path_map = {}  # original -> relative
        new_playlist = []
        # Process playlist
        for original_path in self.project.playlist:
            filename = os.path.basename(original_path)
            # NOTE(review): duplicate basenames in the playlist overwrite
            # each other in the assets dir — confirm this is acceptable.
            dest_path = os.path.join(assets_dir, filename)
            # Copy file if it doesn't exist or if we want to ensure it's there
            try:
                if not os.path.exists(dest_path) or os.path.abspath(original_path) != os.path.abspath(dest_path):
                    shutil.copy2(original_path, dest_path)
            except Exception as e:
                # Best-effort: a failed copy is warned about, not fatal.
                print(f"Warning: Failed to copy {original_path} to {dest_path}: {e}")
            # Store relative path
            relative_path = os.path.join(assets_dir_name, filename)
            path_map[original_path] = relative_path
            new_playlist.append(relative_path)
        data = {
            "playlist": new_playlist,
            "current_index": self.project.current_index,
            "prompt_history": self.project.prompt_history,
            "class_name_history": self.project.class_name_history,
            "annotations": {}
        }
        for path, store in self.project.annotations.items():
            # Get the new relative path key
            new_key = path_map.get(path)
            if not new_key:
                # If annotation exists for a file not in playlist (shouldn't happen but safe fallback)
                filename = os.path.basename(path)
                new_key = os.path.join(assets_dir_name, filename)
            objects_data = {}
            for obj_id, obj in store.objects.items():
                # Serialise each object; the binary mask becomes polygons.
                objects_data[obj_id] = {
                    "object_id": obj.object_id,
                    "score": obj.score,
                    "class_name": obj.class_name,
                    "anchor_box": obj.anchor_box,
                    "input_points": obj.input_points,
                    "input_labels": obj.input_labels,
                    "polygons": mask_to_polygons(obj.binary_mask)
                }
            data["annotations"][new_key] = objects_data
        try:
            with open(file_path, 'w') as f:
                json.dump(data, f, indent=2)
            # Remember this file so auto_save targets it from now on.
            self.active_project_path = file_path
            return True, f"Project saved to {file_path} (Images bundled in {assets_dir_name})"
        except Exception as e:
            return False, f"Failed to save project: {e}"
| def load_project(self, file_path: str): | |
| """Load project state from JSON.""" | |
| import json | |
| import os | |
| from .utils import polygons_to_mask | |
| try: | |
| with open(file_path, 'r') as f: | |
| data = json.load(f) | |
| except Exception as e: | |
| return False, f"Failed to load file: {e}" | |
| base_dir = os.path.dirname(file_path) | |
| # Reconstruct absolute paths for playlist | |
| loaded_playlist = [] | |
| for rel_path in data.get("playlist", []): | |
| abs_path = os.path.abspath(os.path.join(base_dir, rel_path)) | |
| loaded_playlist.append(abs_path) | |
| # Restore Project State | |
| self.project = ProjectState( | |
| playlist=loaded_playlist, | |
| current_index=data.get("current_index", -1), | |
| prompt_history=data.get("prompt_history", []), | |
| class_name_history=data.get("class_name_history", []) | |
| ) | |
| # Restore Annotations | |
| missing_files = [] | |
| for rel_path, objects_data in data.get("annotations", {}).items(): | |
| abs_path = os.path.abspath(os.path.join(base_dir, rel_path)) | |
| store = GlobalStore(image_path=abs_path) | |
| # Need image size to restore masks | |
| try: | |
| with Image.open(abs_path) as img: | |
| w, h = img.size | |
| except: | |
| print(f"Warning: Could not read image {abs_path} during load. Skipping masks.") | |
| missing_files.append(abs_path) | |
| continue | |
| for obj_id, obj_data in objects_data.items(): | |
| # Reconstruct mask | |
| polygons = obj_data.get("polygons", []) | |
| mask = polygons_to_mask(polygons, w, h) | |
| obj = ObjectState( | |
| object_id=obj_data["object_id"], | |
| score=obj_data["score"], | |
| class_name=obj_data["class_name"], | |
| anchor_box=obj_data["anchor_box"], | |
| binary_mask=mask, | |
| initial_mask=mask.copy(), # Assume loaded state is initial | |
| input_points=obj_data.get("input_points", []), | |
| input_labels=obj_data.get("input_labels", []) | |
| ) | |
| store.objects[obj_id] = obj | |
| self.project.annotations[abs_path] = store | |
| # Load current image | |
| if self.project.current_index >= 0: | |
| self.load_image_at_index(self.project.current_index) | |
| msg = f"Project loaded from {file_path}" | |
| if missing_files: | |
| msg += f". Warning: {len(missing_files)} images not found (annotations skipped)." | |
| # Update active project path | |
| self.active_project_path = file_path | |
| return True, msg | |
| def get_all_masks(self): | |
| return [(obj.binary_mask, f"{obj.class_name}") for obj in self.store.objects.values()] | |
| def get_object_mask(self, obj_id): | |
| if obj_id in self.store.objects: | |
| return self.store.objects[obj_id].binary_mask | |
| return None | |
| def clean_and_export_dataset(self, dataset_path, tolerance_ratio=0.000805, min_area_ratio=0.000219): | |
| """Clean, validate, and zip a YOLO dataset.""" | |
| manager = DatasetManager(dataset_path) | |
| # 1. Remove Zone.Identifier files | |
| manager.remove_zone_identifiers() | |
| # 2. Clean dataset (in-place) | |
| print(f"Cleaning dataset at {dataset_path}...") | |
| stats = manager.cleanup_dataset(tolerance_ratio, min_area_ratio) | |
| # 3. Finalize (Validation folders + Zip) | |
| print("Finalizing dataset...") | |
| zip_path = manager.finalize_dataset(create_zip=True) | |
| return stats, zip_path | |
# Global Controller
# Module-level singleton shared by the UI layer; import this instance
# rather than constructing a new AppController.
controller = AppController()