Spaces:
Running
Running
| from fastapi import APIRouter, HTTPException, UploadFile, File, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import FileResponse | |
| from .ws_manager import manager | |
| from pydantic import BaseModel, field_validator | |
| from typing import List | |
| from PIL import Image, UnidentifiedImageError | |
| import os | |
| import base64 | |
| from io import BytesIO | |
| import shutil | |
| from typing import List, Optional, Union, Dict, Any | |
| from . import utils | |
| import copy | |
| import traceback | |
| import asyncio | |
| import sys, signal | |
| import psutil | |
| import subprocess | |
| from . import common | |
| import fcntl | |
| from .config import Config, load_config, update_toml_key | |
# Router that collects every annotation and training endpoint defined in this module.
app = APIRouter()
async def websocket_endpoint(websocket: WebSocket):
    """Register a WebSocket client and keep it connected until it disconnects.

    The connection is handed to the shared ``manager`` so broadcasts reach
    it. Incoming text frames are read and discarded; reading is what lets a
    disconnect surface as ``WebSocketDisconnect``.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Messages from the client are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        print("Client disconnected:", websocket.client)
        manager.disconnect(websocket)
| # === Configuration === | |
# Settings are loaded once, at import time.
config = load_config()

# Directory holding the YOLO label files (one .txt per image).
IMAGE_LABEL_ROOT = os.path.join(config.current_path, "image_labels")

# Class id given to annotations that do not specify one.
CLASS_ID = 0
| # === Pydantic Models === | |
class Point(BaseModel):
    # One polygon vertex of a segmentation annotation.
    x: float  # horizontal coordinate
    y: float  # vertical coordinate
class Box(BaseModel):
    # A single annotation: either a rectangle or a polygon.

    # Shape discriminator: "bbox" or "segmentation".
    type: str = "bbox"

    # Rectangle geometry, only meaningful when type == "bbox".
    left: Optional[int] = None
    top: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None

    # Polygon vertices, only meaningful when type == "segmentation".
    points: Optional[List[Point]] = None

    # Class label plus the drawing attributes shared by both shapes.
    classId: int = CLASS_ID
    stroke: str = "#00ff00"
    strokeWidth: int = 3
    fill: str = "rgba(0, 255, 0, 0.2)"
    saved: bool = True

    # NOTE(review): this has the shape of a pydantic field validator (takes
    # ``cls`` and a value) but no decorator is visible here - confirm how it
    # is wired up.
    def round_floats(cls, v):
        # Pass None through untouched; round everything else to the nearest int.
        if v is None:
            return None
        return round(v)
class SaveAnnotationsRequest(BaseModel):
    # Request body for saving the annotations of one image.
    annotations: List[Box]  # every annotation drawn on the image
    image_name: str  # image the annotations belong to
    original_width: int  # image width used to normalise coordinates
    original_height: int  # image height used to normalise coordinates
class ImageInfo(BaseModel):
    # Summary of one image as returned by the image listing.
    name: str  # path relative to the image root, e.g. train/image1.jpg
    width: int
    height: int
    has_annotations: bool  # True when a label file exists for the image
class TrainConfig(BaseModel):
    # Training settings submitted by the client; each field is persisted to
    # the matching TOML key by save_config.
    epoch: int  # stored as EPOCH
    batch: int  # stored as BATCH
    imgsz: int  # stored as DEFAULT_IMAGE_SIZE
    recreate_dataset: bool  # stored as RECREATE_DATASET
    resume_train: bool  # stored as RESUME_TRAIN
| # === Helpers === | |
def get_image_path(image_name: str) -> str:
    """Return the on-disk path of ``image_name`` inside the image root.

    ``image_name`` comes from API clients, so it is checked to stay inside
    ``config.IMAGE_SOURCE_PATH``: ``..`` segments or an absolute path would
    otherwise let a request reach arbitrary files.

    Raises:
        HTTPException: 400 when the name resolves outside the image root.
    """
    image_path = os.path.join(config.IMAGE_SOURCE_PATH, image_name)
    root = os.path.abspath(config.IMAGE_SOURCE_PATH)
    try:
        inside_root = os.path.commonpath([root, os.path.abspath(image_path)]) == root
    except ValueError:
        # Raised e.g. for paths on different drives; treat as outside the root.
        inside_root = False
    if not inside_root:
        raise HTTPException(status_code=400, detail="Invalid image name")
    return image_path
def get_label_path(image_name: str) -> str:
    """Return the label-file path for ``image_name``.

    The image extension is replaced by ``.txt`` and the result is placed
    under ``IMAGE_LABEL_ROOT``. ``image_name`` comes from API clients, so it
    is checked to stay inside that directory.

    Raises:
        HTTPException: 400 when the name resolves outside the label root.
    """
    label_path = os.path.join(IMAGE_LABEL_ROOT, os.path.splitext(image_name)[0] + ".txt")
    root = os.path.abspath(IMAGE_LABEL_ROOT)
    try:
        inside_root = os.path.commonpath([root, os.path.abspath(label_path)]) == root
    except ValueError:
        # Raised e.g. for paths on different drives; treat as outside the root.
        inside_root = False
    if not inside_root:
        raise HTTPException(status_code=400, detail="Invalid image name")
    return label_path
| # === Core Functions === | |
def load_yolo_annotations(image_path: str, label_path: str, detect: bool = False):
    """Load bbox and segmentation annotations for one image from a YOLO label file.

    Args:
        image_path: Image whose size is used to scale the normalised
            coordinates in the label file.
        label_path: YOLO label file to read.
        detect: When true and no label file exists yet, run the trained
            YOLO model on the image first so a label file is produced.

    Returns:
        ``(annotations, (width, height))`` where ``annotations`` is a list of
        annotation dicts and the tuple is the image size.

    Raises:
        HTTPException: 500 wrapping any error raised while loading.
    """
    try:
        # Only the size is needed; the context manager releases the file handle.
        with Image.open(image_path) as img:
            w, h = img.size

        annotations = []

        # Auto-detect if needed
        normalise = False
        if detect and not os.path.exists(label_path):
            from .yolo_manager import YOLOManager
            with YOLOManager() as yolo_manager:
                yolo_manager.load_model(config.yolo_trained_model_path)
                yolo_manager.annotate_images(
                    image_paths=[image_path],
                    output_dir=IMAGE_LABEL_ROOT,
                    save_image=False,
                    label_path=label_path
                )
            normalise = True

        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    # Shared parser; returns None for lines that are not a
                    # valid bbox or segmentation record.
                    annotation = parse_yolo_line(line, w, h)
                    if annotation is not None:
                        annotations.append(annotation)

        if normalise:
            # Freshly detected labels are normalised and written back so the
            # file on disk matches what is returned.
            annotations = utils.normalize_segmentation(annotations)
            save_yolo_annotations(
                copy.deepcopy(annotations),
                (w, h),
                label_path
            )

        return annotations, (w, h)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading annotations: {str(e)} {traceback.format_exc()}")
def normalize_annotations(annotations: List[Union[Box, dict]]) -> List[Box]:
    """Return the annotations as ``Box`` instances, building them from dicts where needed.

    Raises:
        TypeError: if an item is neither a ``Box`` nor a ``dict``.
    """
    def _as_box(item):
        # Box instances pass through; dicts are expanded into a new Box.
        if isinstance(item, Box):
            return item
        if isinstance(item, dict):
            return Box(**item)
        raise TypeError(f"Unsupported annotation type: {type(item)}")

    return [_as_box(item) for item in annotations]
def save_yolo_annotations(annotations: List[Box], original_size: tuple, label_path: str):
    """Write annotations to ``label_path`` in YOLO format (bbox and segmentation).

    Coordinates are divided by the image size so the file holds normalised
    values with six decimals. A copy of the file is also placed directly
    under ``IMAGE_LABEL_ROOT``.

    Args:
        annotations: ``Box`` objects or equivalent dicts.
        original_size: ``(width, height)`` of the image.
        label_path: Destination label file; parent directories are created.

    Returns:
        ``True`` on success.

    Raises:
        HTTPException: 500 wrapping any error raised while writing.
    """
    annotations = normalize_annotations(annotations)
    os.makedirs(os.path.dirname(label_path), exist_ok=True)
    w, h = original_size
    try:
        with open(label_path, "w") as f:
            for annotation in annotations:
                if annotation.type == "bbox":
                    left, top, width, height = annotation.left, annotation.top, annotation.width, annotation.height
                    # YOLO stores the box centre and size, all normalised.
                    xc = (left + width / 2) / w
                    yc = (top + height / 2) / h
                    bw = width / w
                    bh = height / h
                    f.write(f"{annotation.classId} {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")
                elif annotation.type == "segmentation" and annotation.points:
                    normalized_points = []
                    for point in annotation.points:
                        normalized_points.extend([point.x / w, point.y / h])
                    coords_str = " ".join(f"{coord:.6f}" for coord in normalized_points)
                    f.write(f"{annotation.classId} {coords_str}\n")

        # Copy only after the file is closed so the copy sees the full content.
        try:
            shutil.copy2(label_path, f"{IMAGE_LABEL_ROOT}/{os.path.basename(label_path)}")
        except shutil.SameFileError:
            # The label already lives directly in IMAGE_LABEL_ROOT; nothing to copy.
            pass
        return True
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving annotations: {str(e)} {traceback.format_exc()}")
def parse_yolo_line(line: str, image_width: int, image_height: int) -> Optional[Dict[str, Any]]:
    """Parse one YOLO label line into an annotation dict.

    A line with exactly five numbers is a bounding box
    (``class xc yc w h``); a line with a class id followed by three or more
    ``x y`` pairs is a segmentation polygon. Normalised values are scaled by
    the image size.

    Bounding-box values are rounded to the nearest integer instead of being
    truncated: label files store six decimals, so truncation would shrink a
    box on every save/load round trip (``0.333333 * 300`` is ``99.9999``).

    Args:
        line: One line of a YOLO label file.
        image_width: Image width used to scale x values.
        image_height: Image height used to scale y values.

    Returns:
        The annotation dict, or ``None`` when the line is neither a valid
        bbox nor a valid segmentation record.

    Raises:
        ValueError: if a field in the line is not numeric.
    """
    parts = list(map(float, line.strip().split()))
    if len(parts) < 5:
        return None

    # Class id plus the drawing attributes shared by both shapes.
    style = {
        "classId": int(parts[0]),
        "stroke": "#00ff00",
        "strokeWidth": 3,
        "fill": "rgba(0, 255, 0, 0.2)",
        "saved": True
    }

    if len(parts) == 5:  # Bounding box
        _, xc, yc, bw, bh = parts
        return {
            "type": "bbox",
            "left": round((xc - bw / 2) * image_width),
            "top": round((yc - bh / 2) * image_height),
            "width": round(bw * image_width),
            "height": round(bh * image_height),
            **style
        }

    if len(parts) % 2 == 1:  # Segmentation: class id + at least 3 (x, y) pairs
        coords = parts[1:]
        points = [
            {"x": x * image_width, "y": y * image_height}
            for x, y in zip(coords[0::2], coords[1::2])
        ]
        return {"type": "segmentation", "points": points, **style}

    # Even number of fields above five: a dangling coordinate, not a valid record.
    return None
| # === API Routes === | |
async def list_all_images():
    """List every .jpg/.jpeg/.png file under the image root.

    Returns:
        A list of ``ImageInfo`` with the path relative to the image root
        (forward slashes), the image size, and whether a label file exists.
        Files that cannot be identified as images are reported and skipped.
    """
    image_info_list = []
    for root, _, files in os.walk(config.IMAGE_SOURCE_PATH):
        for file in sorted(files):
            if not file.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, file)
            rel_path = os.path.relpath(image_path, config.IMAGE_SOURCE_PATH)
            try:
                # Close each image right after reading its size so a large
                # dataset does not accumulate open file handles.
                with Image.open(image_path) as img:
                    width, height = img.size
            except UnidentifiedImageError:
                print(f"Cannot identify image file: {image_path}")
                continue
            image_info_list.append(ImageInfo(
                name=rel_path.replace("\\", "/"),
                width=width,
                height=height,
                has_annotations=os.path.exists(get_label_path(rel_path))
            ))
    return image_info_list
async def get_image(image_name: str):
    """Return the image as a base64 JPEG data URL together with its size.

    Raises:
        HTTPException: 404 when the image file does not exist.
    """
    image_path = get_image_path(image_name)
    if not os.path.exists(image_path):
        raise HTTPException(status_code=404, detail="Image not found")

    with Image.open(image_path) as img:
        # JPEG output needs RGB; convert anything else first.
        if img.mode != "RGB":
            img = img.convert("RGB")
        jpeg_bytes = BytesIO()
        img.save(jpeg_bytes, format="JPEG")
        encoded = base64.b64encode(jpeg_bytes.getvalue())
        img_data = encoded.decode()
        payload = {
            "image_data": f"data:image/jpeg;base64,{img_data}",
            "width": img.width,
            "height": img.height
        }
        return payload
async def get_annotations(image_name: str):
    """Return the stored annotations of an image plus the image size.

    Raises:
        HTTPException: 404 when the image file does not exist.
    """
    source_image = get_image_path(image_name)
    label_file = get_label_path(image_name)
    if not os.path.exists(source_image):
        raise HTTPException(status_code=404, detail="Image not found")

    loaded, (img_w, img_h) = load_yolo_annotations(source_image, label_file)
    response = {"annotations": loaded}
    response["original_width"] = img_w
    response["original_height"] = img_h
    return response
async def get_detected_annotations(image_name: str):
    """Return an image's annotations, running detection when none are stored yet.

    Raises:
        HTTPException: 404 when the image file does not exist.
    """
    source_image = get_image_path(image_name)
    label_file = get_label_path(image_name)
    if not os.path.exists(source_image):
        raise HTTPException(status_code=404, detail="Image not found")

    # The third positional argument switches on auto-detection.
    loaded, (img_w, img_h) = load_yolo_annotations(source_image, label_file, True)
    response = {"annotations": loaded}
    response["original_width"] = img_w
    response["original_height"] = img_h
    return response
async def save_annotations(request: SaveAnnotationsRequest):
    """Persist the annotations sent by the client for one image.

    Failures surface as the ``HTTPException`` raised by
    ``save_yolo_annotations``, so its return value is not inspected.
    """
    save_yolo_annotations(
        request.annotations,
        (request.original_width, request.original_height),
        get_label_path(request.image_name)
    )
    return {"message": f"Saved {len(request.annotations)} annotations successfully"}
async def delete_annotations(image_name: str):
    """Delete the label file of an image, if there is one.

    The removal is attempted directly rather than checking for existence
    first, so a file that disappears between check and removal cannot turn
    into an unhandled error.
    """
    label_path = get_label_path(image_name)
    try:
        os.remove(label_path)
    except FileNotFoundError:
        return {"message": "No annotations to delete"}
    return {"message": "Annotations deleted"}
async def download_annotations(image_name: str):
    """Send the label file of an image as a plain-text download.

    Raises:
        HTTPException: 404 when no label file exists for the image.
    """
    label_path = get_label_path(image_name)
    if os.path.exists(label_path):
        download_name = os.path.basename(label_path)
        return FileResponse(label_path, media_type="text/plain", filename=download_name)
    raise HTTPException(status_code=404, detail="Annotations not found")
async def upload_image(file: UploadFile = File(...)):
    """Store an uploaded image under the image root.

    Raises:
        HTTPException: 400 when the upload is not declared as an image, or
            when its file name is missing or points outside the image root.
    """
    # content_type may be absent; treat that as "not an image" instead of crashing.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    if not file.filename:
        raise HTTPException(status_code=400, detail="Invalid file name")

    file_path = os.path.join(config.IMAGE_SOURCE_PATH, file.filename)

    # The file name is client-controlled: refuse anything that would be
    # written outside the image root ("..", absolute paths).
    root = os.path.abspath(config.IMAGE_SOURCE_PATH)
    target = os.path.abspath(file_path)
    try:
        inside_root = target != root and os.path.commonpath([root, target]) == root
    except ValueError:
        inside_root = False
    if not inside_root:
        raise HTTPException(status_code=400, detail="Invalid file name")

    with open(file_path, "wb") as f:
        f.write(await file.read())
    return {"message": f"Uploaded {file.filename} to train set"}
| ####################### ----train---- ############################# | |
# Holds the handle of the running training subprocess, if any.
current_process = {}


def reset_current_process():
    """Rebind ``current_process`` to a fresh dict with no process recorded."""
    global current_process
    current_process = dict(process=None)


# Start with a well-formed, empty slot.
reset_current_process()
| # Define a function to handle cleanup | |
def handle_exit(signal_received, frame):
    """SIGINT handler: kill the training process group, then exit.

    Args:
        signal_received: Signal number passed by the ``signal`` module.
        frame: Current stack frame passed by the ``signal`` module.
    """
    process = current_process["process"]
    if process:
        try:
            os.killpg(os.getpgid(process.pid), signal.SIGKILL)
        except ProcessLookupError:
            # The process group is already gone; still proceed to exit.
            pass
    sys.exit(0)


# Register the signal handler for SIGINT
signal.signal(signal.SIGINT, handle_exit)
async def get_config():
    """Return the training settings currently held by the loaded config."""
    settings = dict(
        epoch=config.EPOCH,
        imgsz=config.DEFAULT_IMAGE_SIZE,
        batch=config.BATCH,
        resume_train=config.RESUME_TRAIN,
        recreate_dataset=config.RECREATE_DATASET,
    )
    return settings
async def save_config(request: TrainConfig):
    """Write each submitted training setting to its TOML key."""
    toml_updates = (
        ("EPOCH", request.epoch),
        ("BATCH", request.batch),
        ("DEFAULT_IMAGE_SIZE", request.imgsz),
        ("RECREATE_DATASET", request.recreate_dataset),
        ("RESUME_TRAIN", request.resume_train),
    )
    for toml_key, new_value in toml_updates:
        update_toml_key(toml_key, new_value)
    return {'message': 'Config update successfully.', 'status': 'success'}
async def reset_model():
    """Delete the trained model weights file, if it exists.

    The file is removed directly rather than after an existence check, so a
    file that vanishes in between cannot raise an unhandled error.
    """
    from pathlib import Path
    try:
        Path(config.yolo_trained_model_path).unlink()
    except FileNotFoundError:
        # Nothing to reset.
        pass
    return {'message': 'Model Reseted', 'status': 'success'}
async def deploy_model(app_name: str):
    """Deploy the model through ``YOLOManager``.

    ``app_name`` is accepted but not used by this function.
    """
    from .yolo_manager import YOLOManager

    with YOLOManager() as deployer:
        deployer.deploy()

    result = {'message': 'Model Deployed', 'status': 'success'}
    return result
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks = set()


def _spawn_background(coro):
    """Schedule ``coro`` as a task and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# NOTE(review): this function shares its name with the image-upload handler
# defined earlier in the file and rebinds that module-level name.
async def upload_image():
    """Start training in a subprocess and stream its output to WebSocket clients.

    Returns immediately. The command runs in a background task that
    broadcasts every output line as ``command_output`` and finishes with
    ``command_finished`` (or ``command_error``). Nothing is started when a
    training process is already running.

    Returns:
        A status dict: ``ignored`` when training is already running,
        otherwise ``started``.
    """
    os.environ['PYTHONUNBUFFERED'] = "1"

    # Skip if the training process is already running
    if is_process_running("comic_panel_extractor.train"):
        return {"status": "ignored", "message": "Training already in progress."}

    reset_current_process()

    cmd_to_run = ""
    if config.RECREATE_DATASET:
        cmd_to_run = "python -m comic_panel_extractor.create_dataset && "
    cmd_to_run += "python -m comic_panel_extractor.train"

    def publish_line(raw_line):
        # Echo one output line locally and broadcast it to connected clients.
        text = raw_line.strip()
        print(text)
        print("Active connections:", len(manager.active_connections))
        _spawn_background(manager.broadcast({
            'type': 'command_output',
            'data': text
        }))
        sys.stdout.flush()

    async def run_and_stream_output():
        process = None
        try:
            process = subprocess.Popen(
                cmd_to_run,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                bufsize=1,
                universal_newlines=True,
                # Own session/process group so the whole tree can be signalled;
                # replaces preexec_fn=os.setsid.
                start_new_session=True,
                env={**os.environ, 'PYTHONUNBUFFERED': '1', 'CUDA_LAUNCH_BLOCKING': '1', 'USE_CPU_IF_POSSIBLE': str(common.get_device() == "cpu")}
            )

            # Non-blocking reads so polling the pipe never stalls the event loop.
            fd = process.stdout.fileno()
            fl = fcntl.fcntl(fd, fcntl.F_GETFL)
            fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

            current_process['process'] = process

            while True:
                try:
                    output = process.stdout.readline()
                    if output:
                        publish_line(output)

                    if process.poll() is not None:
                        # The process has exited: flush whatever is still
                        # buffered so the tail of the output is not lost.
                        for leftover in iter(process.stdout.readline, ""):
                            publish_line(leftover)
                        break

                    # Small delay to prevent CPU spinning
                    await asyncio.sleep(0.01)
                except Exception as e:
                    print(f"Error reading process output: {e}")
                    break

            return_code = process.returncode if process else -1
            _spawn_background(manager.broadcast({
                'type': 'command_finished',
                'return_code': return_code
            }))
        except Exception as e:
            print(f"Error in run_and_stream_output: {e}")
            _spawn_background(manager.broadcast({
                'type': 'command_error',
                'error': str(e)
            }))
        finally:
            # Release the pipe so repeated runs do not leak file descriptors.
            if process is not None and process.stdout is not None:
                process.stdout.close()
            current_process['process'] = None

    # Start the command execution in a separate task
    _spawn_background(run_and_stream_output())

    return {"message": "Command started!", "status": "started"}
async def stop_train():
    """Stop the running training command, if any.

    Sends SIGTERM to the process group, waits one second, then sends SIGKILL
    if the process is still alive. If signalling the group fails, falls back
    to terminating only the main process. Clients are notified with a
    ``command_stopped`` broadcast.

    Returns:
        A status dict: ``no_process``, ``already_finished`` or ``terminated``.

    Raises:
        HTTPException: 500 on an unexpected error.
    """
    try:
        # Take the handle once: the streaming task clears the shared slot
        # when the process ends, which can happen while this coroutine sleeps.
        process = current_process['process']

        if process is None:
            return {'message': 'No command is currently running.', 'status': 'no_process'}

        if process.poll() is not None:
            # Process already finished, just clean up
            reset_current_process()
            return {'message': 'Command has already finished.', 'status': 'already_finished'}

        try:
            pgid = os.getpgid(process.pid)
            os.killpg(pgid, signal.SIGTERM)  # Try SIGTERM first

            # Wait a bit for graceful shutdown
            await asyncio.sleep(1)

            # If still running, force kill
            if process.poll() is None:
                os.killpg(pgid, signal.SIGKILL)
        except ProcessLookupError:
            print("Process already terminated")
        except OSError as e:
            # Permission errors or other OS-level issues
            print(f"Error terminating process: {e}")
            # Try to kill just the main process if group kill fails
            try:
                process.terminate()
                await asyncio.sleep(0.5)
                if process.poll() is None:
                    process.kill()
            except Exception as fallback_error:
                # Best effort only, but report it instead of hiding it. Task
                # cancellation is not an Exception and still propagates.
                print(f"Fallback termination failed: {fallback_error}")

        # Always reset the process state
        reset_current_process()

        # Notify connected clients
        await manager.broadcast({
            'type': 'command_stopped',
            'message': 'Command terminated by user'
        })

        return {'message': 'Command terminated successfully.', 'status': 'terminated'}
    except Exception as e:
        print(f"Error in stop_command: {str(e)}")
        # Force reset even if there was an error
        reset_current_process()
        raise HTTPException(status_code=500, detail=f'Error stopping command: {str(e)}')
def is_process_running(name: str) -> bool:
    """Return True when any process's command line contains ``name``.

    Processes that vanish or deny access while being inspected are skipped.
    """
    for candidate in psutil.process_iter(['cmdline']):
        try:
            argv = candidate.info['cmdline']
            joined = " ".join(argv) if argv else ""
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
        if name in joined:
            return True
    return False