Spaces:
Running
Running
| from fastapi import APIRouter, HTTPException, UploadFile, File | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel, field_validator | |
| from typing import List | |
| from PIL import Image | |
| import os | |
| import base64 | |
| from io import BytesIO | |
| import shutil | |
| from .config import Config | |
| from typing import List, Optional, Union, Dict, Any | |
| from . import utils | |
| import copy | |
| import traceback | |
# APIRouter instance exposed by this module.
app = APIRouter()

# === Configuration ===
# Every directory below is resolved against Config.current_path.
IMAGE_ROOT = os.path.join(Config.current_path, "dataset/images")  # source images
LABEL_ROOT = os.path.join(Config.current_path, "dataset/labels")  # YOLO .txt label files
IMAGE_LABEL_ROOT = os.path.join(Config.current_path, "image_labels")  # flat copy of saved labels / detector output dir
CLASS_ID = 0  # default classId given to a Box that does not specify one
| # === Pydantic Models === | |
class Point(BaseModel):
    """One polygon vertex of a segmentation annotation.

    Coordinates are expressed in pixels of the original image; they are
    divided by the image width/height when the annotation is saved.
    """

    x: float
    y: float
class Box(BaseModel):
    """A single annotation: either a bounding box or a segmentation polygon.

    ``type`` selects which group of fields is meaningful:

    * ``"bbox"``          -> ``left``, ``top``, ``width``, ``height`` (pixels)
    * ``"segmentation"``  -> ``points`` (polygon vertices, pixels)

    The remaining fields are class and drawing attributes shared by both kinds.
    """

    type: str = "bbox"  # "bbox" or "segmentation"

    # For bbox
    left: Optional[int] = None
    top: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None

    # For segmentation
    points: Optional[List[Point]] = None

    # Common fields
    classId: int = CLASS_ID
    stroke: str = "#00ff00"
    strokeWidth: int = 3
    fill: str = "rgba(0, 255, 0, 0.2)"
    saved: bool = True

    # The validator has to be registered explicitly; as a plain method it is
    # never called, and fractional floats are then rejected for the int fields
    # above. mode="before" lets it run ahead of the int conversion.
    @field_validator("left", "top", "width", "height", mode="before")
    @classmethod
    def round_floats(cls, v):
        """Round float pixel values to the nearest int; pass anything else through."""
        if isinstance(v, float):
            try:
                return round(v)
            except (OverflowError, ValueError) as exc:
                # NaN / infinity cannot be rounded; report them as a normal
                # validation failure instead of an unhandled exception.
                raise ValueError("coordinate must be a finite number") from exc
        return v
class SaveAnnotationsRequest(BaseModel):
    """Request body accepted by ``save_annotations``."""

    # Annotations to write for the image (field was previously named 'boxes').
    annotations: List[Box]
    # Image the annotations belong to; used to derive the label file path.
    image_name: str
    # Pixel size of the original image; used to normalise coordinates on save.
    original_width: int
    original_height: int
class ImageInfo(BaseModel):
    """Summary of one image as produced by ``list_all_images``."""

    # Path relative to the image root, with forward slashes (e.g. train/image1.jpg).
    name: str
    # Pixel dimensions of the image.
    width: int
    height: int
    # True when a label file exists for this image.
    has_annotations: bool
| # === Helpers === | |
def get_image_path(image_name: str) -> str:
    """Return the path of ``image_name`` inside ``IMAGE_ROOT``.

    Args:
        image_name: Image path relative to ``IMAGE_ROOT`` (e.g. ``train/a.jpg``).

    Returns:
        ``IMAGE_ROOT`` joined with ``image_name``.

    Raises:
        HTTPException: 400 when the name would resolve outside ``IMAGE_ROOT``
            (``..`` segments or an absolute path).
    """
    image_path = os.path.join(IMAGE_ROOT, image_name)

    # image_name comes from the client. The check is lexical (abspath, not
    # realpath) so symlinked datasets inside the root keep working.
    root = os.path.abspath(IMAGE_ROOT)
    try:
        inside_root = os.path.commonpath([root, os.path.abspath(image_path)]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        inside_root = False
    if not inside_root:
        raise HTTPException(status_code=400, detail="Invalid image name")
    return image_path
def get_label_path(image_name: str) -> str:
    """Return the YOLO label path that corresponds to ``image_name``.

    The image extension is replaced by ``.txt`` and the result is placed
    under ``LABEL_ROOT``, mirroring the image's relative directory.

    Args:
        image_name: Image path relative to the image root (e.g. ``train/a.jpg``).

    Returns:
        ``LABEL_ROOT`` joined with the extension-swapped name.

    Raises:
        HTTPException: 400 when the name would resolve outside ``LABEL_ROOT``
            (``..`` segments or an absolute path).
    """
    label_path = os.path.join(LABEL_ROOT, os.path.splitext(image_name)[0] + ".txt")

    # The result is written to, deleted and served to clients, so a
    # client-supplied name must not be able to leave the label directory.
    # Lexical check (abspath) so symlinks inside the root keep working.
    root = os.path.abspath(LABEL_ROOT)
    try:
        inside_root = os.path.commonpath([root, os.path.abspath(label_path)]) == root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        inside_root = False
    if not inside_root:
        raise HTTPException(status_code=400, detail="Invalid image name")
    return label_path
| # === Core Functions === | |
def load_yolo_annotations(image_path: str, label_path: str, detect: bool = False):
    """Load both bbox and segmentation annotations from YOLO format.

    Args:
        image_path: Path of the image; only its pixel size is read.
        label_path: Path of the YOLO label file for that image.
        detect: When True and no label file exists yet, run the trained YOLO
            model on the image first so that a label file is produced.

    Returns:
        ``(annotations, (width, height))`` where ``annotations`` is a list of
        annotation dicts in pixel coordinates and the tuple is the image size.

    Raises:
        HTTPException: 500 on any failure while reading the image, running
            detection or parsing the label file.
    """
    try:
        # Only the size is needed; the context manager releases the file
        # handle immediately instead of leaving it to the garbage collector.
        with Image.open(image_path) as img:
            w, h = img.size

        # Auto-detect if needed
        normalise = False
        if detect and not os.path.exists(label_path):
            from .yolo_manager import YOLOManager

            with YOLOManager() as yolo_manager:
                yolo_manager.load_model(Config.yolo_trained_model_path)
                yolo_manager.annotate_images(
                    image_paths=[image_path],
                    output_dir=IMAGE_LABEL_ROOT,
                    save_image=False,
                    label_path=label_path,
                )
            normalise = True

        annotations = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    # parse_yolo_line returns None for lines that are too
                    # short or match neither the bbox nor the polygon layout.
                    annotation = parse_yolo_line(line, w, h)
                    if annotation is not None:
                        annotations.append(annotation)

            # Freshly detected labels are normalised and written back. This is
            # done only when detection really produced a file, so a failed
            # detection does not leave an empty label file behind.
            if normalise:
                annotations = utils.normalize_segmentation(annotations)
                save_yolo_annotations(
                    copy.deepcopy(annotations),
                    (w, h),
                    label_path,
                )
        return annotations, (w, h)
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from save_yolo_annotations);
        # do not wrap it in a second 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading annotations: {str(e)} {traceback.format_exc()}")
def normalize_annotations(annotations: List[Union[Box, dict]]) -> List[Box]:
    """Return the annotations as ``Box`` objects.

    ``Box`` instances are kept as they are, dicts are expanded into the
    ``Box`` constructor, and any other element type raises ``TypeError``.
    """

    def _as_box(item):
        # Coerce one element, rejecting anything that is neither a Box nor a dict.
        if isinstance(item, Box):
            return item
        if isinstance(item, dict):
            return Box(**item)
        raise TypeError(f"Unsupported annotation type: {type(item)}")

    return [_as_box(item) for item in annotations]
def save_yolo_annotations(annotations: List[Box], original_size: tuple, label_path: str):
    """Save annotations in YOLO format (both bbox and segmentation).

    Args:
        annotations: ``Box`` objects or equivalent dicts, in pixel coordinates.
        original_size: ``(width, height)`` of the image, used to normalise
            every coordinate.
        label_path: Destination label file. A copy with the same base name is
            also placed in ``IMAGE_LABEL_ROOT``.

    Returns:
        True on success.

    Raises:
        TypeError: If an element of ``annotations`` is neither ``Box`` nor dict.
        HTTPException: 500 when formatting or writing the labels fails.
    """
    annotations = normalize_annotations(annotations)
    w, h = original_size
    try:
        # Format every line before touching the filesystem, so a bad
        # annotation cannot leave a truncated label file behind.
        lines = []
        for annotation in annotations:
            if annotation.type == "bbox":
                left, top, width, height = annotation.left, annotation.top, annotation.width, annotation.height
                xc = (left + width / 2) / w
                yc = (top + height / 2) / h
                bw = width / w
                bh = height / h
                lines.append(f"{annotation.classId} {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")
            elif annotation.type == "segmentation" and annotation.points:
                # Convert points to normalized coordinates
                normalized_points = []
                for point in annotation.points:
                    normalized_points.extend([point.x / w, point.y / h])
                coords_str = " ".join(f"{coord:.6f}" for coord in normalized_points)
                lines.append(f"{annotation.classId} {coords_str}\n")

        # dirname is empty for a bare file name, and makedirs("") would fail.
        label_dir = os.path.dirname(label_path)
        if label_dir:
            os.makedirs(label_dir, exist_ok=True)
        with open(label_path, "w") as f:
            f.writelines(lines)

        # Copy to image_labels directory (created on demand).
        os.makedirs(IMAGE_LABEL_ROOT, exist_ok=True)
        shutil.copy2(label_path, os.path.join(IMAGE_LABEL_ROOT, os.path.basename(label_path)))
        return True
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving annotations: {str(e)} {traceback.format_exc()}")
def parse_yolo_line(line: str, image_width: int, image_height: int) -> Optional[Dict[str, Any]]:
    """Parse a single YOLO format line and return an annotation dict.

    Two layouts are recognised, both with coordinates normalised to 0..1:

    * ``class xc yc w h``            -> bounding box (exactly 5 values)
    * ``class x1 y1 x2 y2 x3 y3 ...`` -> polygon (odd count above 5, i.e. at
      least three x/y pairs)

    Args:
        line: One line of a YOLO label file.
        image_width: Image width in pixels, used to scale x values.
        image_height: Image height in pixels, used to scale y values.

    Returns:
        The annotation dict in pixel coordinates, or None when the line is
        blank, too short, or matches neither layout.

    Raises:
        ValueError: If a token on the line is not a number.
    """
    parts = [float(token) for token in line.split()]
    if len(parts) < 5:
        return None

    class_id = int(parts[0])
    # Drawing attributes shared by both annotation kinds.
    style = {
        "classId": class_id,
        "stroke": "#00ff00",
        "strokeWidth": 3,
        "fill": "rgba(0, 255, 0, 0.2)",
        "saved": True,
    }

    if len(parts) == 5:  # Bounding box
        _, xc, yc, bw, bh = parts
        # round(), not int(): label files carry six decimals, so a 1 px wide
        # box can read back as 0.999999 px and truncation would turn it into 0.
        return {
            "type": "bbox",
            "left": round((xc - bw / 2) * image_width),
            "top": round((yc - bh / 2) * image_height),
            "width": round(bw * image_width),
            "height": round(bh * image_height),
            **style,
        }

    if len(parts) % 2 == 1:  # Segmentation: class id + complete x/y pairs
        coords = parts[1:]
        # An odd total above 5 guarantees an even number (>= 6) of coordinates,
        # so pairing every other value is safe.
        points = [
            {"x": x * image_width, "y": y * image_height}
            for x, y in zip(coords[0::2], coords[1::2])
        ]
        return {"type": "segmentation", "points": points, **style}

    return None
| # === API Routes === | |
async def list_all_images():
    """List every .jpg/.jpeg/.png image under ``IMAGE_ROOT``.

    Returns:
        A list of ``ImageInfo`` holding, for each image, its path relative to
        ``IMAGE_ROOT`` (forward slashes), its pixel size and whether a label
        file exists for it.
    """
    image_info_list = []
    for root, dirs, files in os.walk(IMAGE_ROOT):
        # Sorting in place makes os.walk descend in a stable order, matching
        # the sorted file order below.
        dirs.sort()
        for file in sorted(files):
            if not file.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, file)
            rel_path = os.path.relpath(image_path, IMAGE_ROOT)
            label_path = get_label_path(rel_path)
            # Close each image right after reading its size; otherwise one
            # file handle per image stays open until garbage collection.
            with Image.open(image_path) as img:
                width, height = img.size
            image_info_list.append(ImageInfo(
                name=rel_path.replace("\\", "/"),
                width=width,
                height=height,
                has_annotations=os.path.exists(label_path),
            ))
    return image_info_list
async def get_image(image_name: str):
    """Return an image as a base64 JPEG data URI together with its size.

    Raises:
        HTTPException: 404 when the image file does not exist.
    """
    path = get_image_path(image_name)
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="Image not found")

    with Image.open(path) as source:
        # JPEG output needs RGB; any other mode is converted first.
        picture = source if source.mode == "RGB" else source.convert("RGB")
        jpeg_bytes = BytesIO()
        picture.save(jpeg_bytes, format="JPEG")
        width, height = picture.width, picture.height

    encoded = base64.b64encode(jpeg_bytes.getvalue()).decode()
    return {
        "image_data": f"data:image/jpeg;base64,{encoded}",
        "width": width,
        "height": height,
    }
async def get_annotations(image_name: str):
    """Return the stored annotations of an image plus the image size.

    No detection is run; an image without a label file yields an empty list.

    Raises:
        HTTPException: 404 when the image file does not exist.
    """
    image_path = get_image_path(image_name)
    label_path = get_label_path(image_name)
    if os.path.exists(image_path):
        loaded, (img_w, img_h) = load_yolo_annotations(image_path, label_path)
        return {
            "annotations": loaded,
            "original_width": img_w,
            "original_height": img_h,
        }
    raise HTTPException(status_code=404, detail="Image not found")
async def get_detected_annotations(image_name: str):
    """Return the annotations of an image, running detection when none are stored.

    Same as loading stored annotations, except that ``detect`` is enabled so a
    missing label file triggers the YOLO model.

    Raises:
        HTTPException: 404 when the image file does not exist.
    """
    image_path = get_image_path(image_name)
    label_path = get_label_path(image_name)
    if os.path.exists(image_path):
        loaded, (img_w, img_h) = load_yolo_annotations(image_path, label_path, True)
        return {
            "annotations": loaded,
            "original_width": img_w,
            "original_height": img_h,
        }
    raise HTTPException(status_code=404, detail="Image not found")
async def save_annotations(request: SaveAnnotationsRequest):
    """Write the annotations in ``request`` to the image's YOLO label file.

    Returns:
        A dict whose ``message`` reports how many annotations were saved.
    """
    target = get_label_path(request.image_name)
    image_size = (request.original_width, request.original_height)
    # The helper raises on failure, so its return value is not needed here.
    save_yolo_annotations(request.annotations, image_size, target)
    return {"message": f"Saved {len(request.annotations)} annotations successfully"}
async def delete_annotations(image_name: str):
    """Delete the label file of an image, if there is one.

    Returns:
        A dict whose ``message`` says whether a file was removed.
    """
    target = get_label_path(image_name)
    if not os.path.exists(target):
        return {"message": "No annotations to delete"}
    os.remove(target)
    return {"message": "Annotations deleted"}
async def download_annotations(image_name: str):
    """Send the label file of an image as a plain-text download.

    Raises:
        HTTPException: 404 when no label file exists for the image.
    """
    target = get_label_path(image_name)
    if os.path.exists(target):
        download_name = os.path.basename(target)
        return FileResponse(target, media_type="text/plain", filename=download_name)
    raise HTTPException(status_code=404, detail="Annotations not found")
async def upload_image(file: UploadFile = File(...)):
    """Store an uploaded image in the ``train`` folder under ``IMAGE_ROOT``.

    Args:
        file: The uploaded file; its content type must start with ``image/``.

    Returns:
        A dict whose ``message`` names the stored file.

    Raises:
        HTTPException: 400 when the upload is not an image or has no usable
            file name.
    """
    # content_type can be None when the client sends no Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # The file name is client-controlled: keep only its last path component
    # (for both / and \ separators) so it cannot point outside the train dir.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name")

    train_dir = os.path.join(IMAGE_ROOT, "train")
    os.makedirs(train_dir, exist_ok=True)  # a fresh dataset may lack it
    file_path = os.path.join(train_dir, safe_name)
    with open(file_path, "wb") as f:
        f.write(await file.read())
    return {"message": f"Uploaded {safe_name} to train set"}