# G-Paris
# Initialize Space with LFS for images
# 19f31ed
from .schemas import GlobalStore, ObjectState, SelectorInput, ProjectState
from .inference import search_objects, refine_object
from .dataset_manager import DatasetManager
from .view_helpers import draw_candidates
from PIL import Image
import numpy as np
import os
import shutil
import uuid
import cv2
class AppController:
    """Coordinates the image playlist, per-image annotation stores, and
    project persistence for the annotation UI."""

    def __init__(self):
        # Annotation container for the image currently on screen.
        self.store = GlobalStore()
        self.current_image = None       # PIL Image currently displayed
        self.current_image_path = None  # Filesystem path of the current image
        # Project-wide (playlist) state.
        self.project = ProjectState()
        self.global_class_map = {}       # class_name -> int ID
        self.active_project_path = None  # Path of the active project JSON file
def load_playlist(self, file_paths: list[str]):
"""Load a list of image paths."""
# Filter for images
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
playlist = sorted([p for p in file_paths if os.path.splitext(p)[1].lower() in valid_exts])
self.project = ProjectState(playlist=playlist)
self.current_image = None
self.current_image_path = None
self.store = GlobalStore()
if self.project.playlist:
return self.load_image_at_index(0)
return None
def load_image_at_index(self, index: int):
if not self.project.playlist or index < 0 or index >= len(self.project.playlist):
return None
# Save current state if we have an image loaded
if self.current_image_path:
self.project.annotations[self.current_image_path] = self.store
self.project.current_index = index
path = self.project.playlist[index]
try:
image = Image.open(path).convert("RGB")
self.current_image = image
self.current_image_path = path
# Restore store if exists, else new
if path in self.project.annotations:
self.store = self.project.annotations[path]
else:
self.store = GlobalStore(image_path=path)
return image
except Exception as e:
print(f"Error loading image {path}: {e}")
return None
def next_image(self):
return self.load_image_at_index(self.project.current_index + 1)
def prev_image(self):
return self.load_image_at_index(self.project.current_index - 1)
def set_image(self, image: Image.Image):
# Legacy support: treat as single image playlist without path
# This might break if we rely on paths for export.
# Ideally we force file upload.
# For now, let's just set it and reset store, but warn it won't work well with playlist export
self.current_image = image
self.current_image_path = None
self.store = GlobalStore()
self.project = ProjectState()
def reset_project(self):
"""Reset the project state completely."""
self.store = GlobalStore()
self.current_image = None
self.current_image_path = None
self.project = ProjectState()
self.global_class_map = {}
self.active_project_path = None
def auto_save(self):
"""Auto-save the project if an active path is set."""
if self.active_project_path:
print(f"💾 Auto-saving to {self.active_project_path}...")
return self.save_project(self.active_project_path)
return False, "No active project to save."
def update_history(self, prompt: str, class_name: str):
if prompt and prompt not in self.project.prompt_history:
self.project.prompt_history.append(prompt)
if class_name and class_name not in self.project.class_name_history:
self.project.class_name_history.append(class_name)
def update_history(self, prompt: str, class_name: str):
if prompt and prompt not in self.project.prompt_history:
self.project.prompt_history.append(prompt)
if class_name and class_name not in self.project.class_name_history:
self.project.class_name_history.append(class_name)
def search_and_add(self, class_name: str, search_boxes: list[list[int]] = [], search_labels: list[int] = [], class_name_override: str = None, crop_box: list[int] = None):
self.update_history(class_name, class_name_override)
if self.current_image is None: return []
# Create SelectorInput
selector_input = SelectorInput(
image=self.current_image,
text=class_name,
class_name_override=class_name_override,
input_boxes=search_boxes,
input_labels=search_labels,
crop_box=crop_box
)
candidates = search_objects(selector_input)
# We return candidates, but don't add to store yet (UI will decide)
return candidates
def add_candidates_to_store(self, candidates: list[ObjectState], selected_indices: list[int]):
added_ids = []
for idx in selected_indices:
if 0 <= idx < len(candidates):
obj_state = candidates[idx]
self.store.objects[obj_state.object_id] = obj_state
added_ids.append(obj_state.object_id)
return added_ids
def get_candidate_preview(self, candidates: list[ObjectState], selected_index: int | set | list = None):
"""Generate preview image with candidates drawn."""
if self.current_image is None or not candidates:
return self.current_image
return draw_candidates(self.current_image, candidates, selected_index)
def get_candidates_dataframe(self, candidates: list[ObjectState]):
"""Get dataframe for UI list."""
data = []
for i, obj in enumerate(candidates):
# Add ID column (i+1) to match the image labels
data.append([
i + 1, # ID
obj.class_name, # Class
f"{obj.score:.2f}" # Score
])
return data
def refine_object(self, obj_id: str, point: list[int], label: int):
if obj_id not in self.store.objects: return None
if self.current_image is None: return None
obj = self.store.objects[obj_id]
# Update history
obj.input_points.append(point)
obj.input_labels.append(label)
print(f"Refining {obj_id}: Points={obj.input_points}, Labels={obj.input_labels}")
# Run Refiner
new_mask = refine_object(self.current_image, obj)
# Update Mask
obj.binary_mask = new_mask
return new_mask
def undo_last_point(self, obj_id: str):
if obj_id not in self.store.objects: return None
obj = self.store.objects[obj_id]
if not obj.input_points:
return obj.binary_mask # Nothing to undo
# Remove last
obj.input_points.pop()
obj.input_labels.pop()
# If no points left, revert to initial
if not obj.input_points:
obj.binary_mask = obj.initial_mask
return obj.binary_mask
# Otherwise re-run refinement
print(f"Refining (Undo) {obj_id}: Points={obj.input_points}, Labels={obj.input_labels}")
new_mask = refine_object(self.current_image, obj)
obj.binary_mask = new_mask
return new_mask
def remove_object(self, obj_id: str):
if obj_id in self.store.objects:
del self.store.objects[obj_id]
return True
return False
def revert_object(self, obj_id: str):
"""Revert object to its initial state (before refinement)."""
if obj_id not in self.store.objects: return None
obj = self.store.objects[obj_id]
# Reset to initial mask
obj.binary_mask = obj.initial_mask
# Clear points
obj.input_points = []
obj.input_labels = []
return obj.binary_mask
def export_data(self, output_dir: str, purge: bool = False, zip_output: bool = False):
"""Export all images and annotations in playlist to YOLO format."""
# Ensure current state is saved
if self.current_image_path:
self.project.annotations[self.current_image_path] = self.store
if not self.project.annotations:
return None, "No annotations to export."
# Structure:
# output_dir/
# data.yaml
# images/
# train/
# labels/
# train/
images_dir = os.path.join(output_dir, "images", "train")
labels_dir = os.path.join(output_dir, "labels", "train")
if purge:
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(images_dir, exist_ok=True)
os.makedirs(labels_dir, exist_ok=True)
# Collect all unique class names to build map
all_class_names = set()
for store in self.project.annotations.values():
for obj in store.objects.values():
all_class_names.add(obj.class_name)
# Update global map (append new ones)
sorted_classes = sorted(list(all_class_names))
class_list = sorted_classes
class_map = {name: i for i, name in enumerate(class_list)}
exported_count = 0
for path, store in self.project.annotations.items():
if not store.objects:
continue
# Copy image
filename = os.path.basename(path)
dest_img_path = os.path.join(images_dir, filename)
shutil.copy2(path, dest_img_path)
# Generate Label File
label_filename = os.path.splitext(filename)[0] + ".txt"
dest_label_path = os.path.join(labels_dir, label_filename)
# We need image size for normalization.
try:
with Image.open(path) as img:
w, h = img.size
except:
print(f"Could not read image size for {path}")
continue
lines = []
for obj in store.objects.values():
cid = class_map.get(obj.class_name, 0)
mask = obj.binary_mask.astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for cnt in contours:
points = cnt.flatten()
if len(points) < 6: continue # Need at least 3 points
norm_points = []
for i in range(0, len(points), 2):
nx = points[i] / w
ny = points[i+1] / h
# Clip to 0-1
nx = max(0, min(1, nx))
ny = max(0, min(1, ny))
norm_points.extend([f"{nx:.6f}", f"{ny:.6f}"])
line = f"{cid} " + " ".join(norm_points)
lines.append(line)
with open(dest_label_path, "w") as f:
f.write("\n".join(lines))
exported_count += 1
# Create data.yaml
yaml_content = f"""names:
{chr(10).join([f" {i}: {name}" for i, name in enumerate(class_list)])}
path: .
train: images/train
"""
with open(os.path.join(output_dir, "data.yaml"), "w") as f:
f.write(yaml_content)
msg = f"Exported {exported_count} images to {output_dir}"
if zip_output:
# Determine zip name based on project name if available
zip_name = "dataset"
if self.active_project_path:
# Extract project name from path (e.g., "saved_projects/my_project.json" -> "my_project")
zip_name = os.path.splitext(os.path.basename(self.active_project_path))[0]
# Create a temp folder for staging the zip
parent_dir = os.path.dirname(os.path.abspath(output_dir))
temp_dir = os.path.join(parent_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
base_name = os.path.join(temp_dir, zip_name)
# Create zip in temp folder
zip_file = shutil.make_archive(base_name, 'zip', output_dir)
# Clear output_dir
for item in os.listdir(output_dir):
item_path = os.path.join(output_dir, item)
if os.path.isfile(item_path) or os.path.islink(item_path):
os.unlink(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
# Move zip to output_dir
final_name = f"{zip_name}.zip"
final_path = os.path.join(output_dir, final_name)
shutil.move(zip_file, final_path)
# Remove temp folder
shutil.rmtree(temp_dir)
msg += f" and zipped to {final_name} (original files deleted)"
return None, msg
def save_project(self, file_path: str):
"""Save project state to JSON and bundle images."""
import json
import os
import shutil
from .utils import mask_to_polygons
# Ensure current state is saved
if self.current_image_path:
self.project.annotations[self.current_image_path] = self.store
# Create assets directory
base_dir = os.path.dirname(file_path)
project_name = os.path.splitext(os.path.basename(file_path))[0]
assets_dir_name = f"{project_name}_assets"
assets_dir = os.path.join(base_dir, assets_dir_name)
os.makedirs(assets_dir, exist_ok=True)
# Map original paths to relative paths
path_map = {} # original -> relative
new_playlist = []
# Process playlist
for original_path in self.project.playlist:
filename = os.path.basename(original_path)
# Handle duplicate filenames by prepending index if needed?
# For now assume unique filenames or just overwrite (simple)
# Better: check collision
dest_path = os.path.join(assets_dir, filename)
# Copy file if it doesn't exist or if we want to ensure it's there
try:
if not os.path.exists(dest_path) or os.path.abspath(original_path) != os.path.abspath(dest_path):
shutil.copy2(original_path, dest_path)
except Exception as e:
print(f"Warning: Failed to copy {original_path} to {dest_path}: {e}")
# Store relative path
relative_path = os.path.join(assets_dir_name, filename)
path_map[original_path] = relative_path
new_playlist.append(relative_path)
data = {
"playlist": new_playlist,
"current_index": self.project.current_index,
"prompt_history": self.project.prompt_history,
"class_name_history": self.project.class_name_history,
"annotations": {}
}
for path, store in self.project.annotations.items():
# Get the new relative path key
new_key = path_map.get(path)
if not new_key:
# If annotation exists for a file not in playlist (shouldn't happen but safe fallback)
filename = os.path.basename(path)
new_key = os.path.join(assets_dir_name, filename)
objects_data = {}
for obj_id, obj in store.objects.items():
objects_data[obj_id] = {
"object_id": obj.object_id,
"score": obj.score,
"class_name": obj.class_name,
"anchor_box": obj.anchor_box,
"input_points": obj.input_points,
"input_labels": obj.input_labels,
"polygons": mask_to_polygons(obj.binary_mask)
}
data["annotations"][new_key] = objects_data
try:
with open(file_path, 'w') as f:
json.dump(data, f, indent=2)
# Update active project path
self.active_project_path = file_path
return True, f"Project saved to {file_path} (Images bundled in {assets_dir_name})"
except Exception as e:
return False, f"Failed to save project: {e}"
def load_project(self, file_path: str):
"""Load project state from JSON."""
import json
import os
from .utils import polygons_to_mask
try:
with open(file_path, 'r') as f:
data = json.load(f)
except Exception as e:
return False, f"Failed to load file: {e}"
base_dir = os.path.dirname(file_path)
# Reconstruct absolute paths for playlist
loaded_playlist = []
for rel_path in data.get("playlist", []):
abs_path = os.path.abspath(os.path.join(base_dir, rel_path))
loaded_playlist.append(abs_path)
# Restore Project State
self.project = ProjectState(
playlist=loaded_playlist,
current_index=data.get("current_index", -1),
prompt_history=data.get("prompt_history", []),
class_name_history=data.get("class_name_history", [])
)
# Restore Annotations
missing_files = []
for rel_path, objects_data in data.get("annotations", {}).items():
abs_path = os.path.abspath(os.path.join(base_dir, rel_path))
store = GlobalStore(image_path=abs_path)
# Need image size to restore masks
try:
with Image.open(abs_path) as img:
w, h = img.size
except:
print(f"Warning: Could not read image {abs_path} during load. Skipping masks.")
missing_files.append(abs_path)
continue
for obj_id, obj_data in objects_data.items():
# Reconstruct mask
polygons = obj_data.get("polygons", [])
mask = polygons_to_mask(polygons, w, h)
obj = ObjectState(
object_id=obj_data["object_id"],
score=obj_data["score"],
class_name=obj_data["class_name"],
anchor_box=obj_data["anchor_box"],
binary_mask=mask,
initial_mask=mask.copy(), # Assume loaded state is initial
input_points=obj_data.get("input_points", []),
input_labels=obj_data.get("input_labels", [])
)
store.objects[obj_id] = obj
self.project.annotations[abs_path] = store
# Load current image
if self.project.current_index >= 0:
self.load_image_at_index(self.project.current_index)
msg = f"Project loaded from {file_path}"
if missing_files:
msg += f". Warning: {len(missing_files)} images not found (annotations skipped)."
# Update active project path
self.active_project_path = file_path
return True, msg
def get_all_masks(self):
return [(obj.binary_mask, f"{obj.class_name}") for obj in self.store.objects.values()]
def get_object_mask(self, obj_id):
if obj_id in self.store.objects:
return self.store.objects[obj_id].binary_mask
return None
def clean_and_export_dataset(self, dataset_path, tolerance_ratio=0.000805, min_area_ratio=0.000219):
"""Clean, validate, and zip a YOLO dataset."""
manager = DatasetManager(dataset_path)
# 1. Remove Zone.Identifier files
manager.remove_zone_identifiers()
# 2. Clean dataset (in-place)
print(f"Cleaning dataset at {dataset_path}...")
stats = manager.cleanup_dataset(tolerance_ratio, min_area_ratio)
# 3. Finalize (Validation folders + Zip)
print("Finalizing dataset...")
zip_path = manager.finalize_dataset(create_zip=True)
return stats, zip_path
# Global Controller
# Module-level singleton shared by the UI layer; importing this module
# constructs it once.
controller = AppController()