jebin2's picture
path isue fix
c13ce0c
raw
history blame
7.93 kB
from fastapi import APIRouter, HTTPException, UploadFile, File
from fastapi.responses import FileResponse
from pydantic import BaseModel, field_validator
from typing import List
from PIL import Image
import os
import base64
from io import BytesIO
import shutil
from .config import Config
app = APIRouter()
# === Configuration ===
IMAGE_ROOT = os.path.join(Config.current_path, "dataset/images")
LABEL_ROOT = os.path.join(Config.current_path, "dataset/labels")
IMAGE_LABEL_ROOT = os.path.join(Config.current_path, "image_labels")
CLASS_ID = 0
# === Pydantic Models ===
class Box(BaseModel):
left: int
top: int
width: int
height: int
type: str = "rect"
stroke: str = "#00ff00"
strokeWidth: int = 3
fill: str = "rgba(0, 255, 0, 0.2)"
saved: bool = True
@field_validator("left", "top", "width", "height", mode="before")
def round_floats(cls, v):
return round(v)
class SaveAnnotationsRequest(BaseModel):
boxes: List[Box]
image_name: str # Relative path like train/image1.jpg
original_width: int
original_height: int
class ImageInfo(BaseModel):
name: str # Relative path like train/image1.jpg
width: int
height: int
has_annotations: bool
# === Helpers ===
def get_image_path(image_name: str) -> str:
return os.path.join(IMAGE_ROOT, image_name)
def get_label_path(image_name: str) -> str:
return os.path.join(LABEL_ROOT, os.path.splitext(image_name)[0] + ".txt")
# === Core Functions ===
def load_yolo_boxes(image_path: str, label_path: str, detect: bool = False):
try:
img = Image.open(image_path)
w, h = img.size
boxes = []
if detect and not os.path.exists(label_path):
from .yolo_manager import YOLOManager
with YOLOManager() as yolo_manager:
weights_path = f'{Config.current_path}/{Config.YOLO_MODEL_NAME}.pt'
yolo_manager.load_model(weights_path)
# Run inference
_, label_path = yolo_manager.annotate_images(image_paths=[image_path], output_dir=IMAGE_LABEL_ROOT, save_image=False, label_path=label_path)
if os.path.exists(label_path):
with open(label_path, "r") as f:
for line in f:
parts = list(map(float, line.strip().split()))
if len(parts) != 5:
continue
_, xc, yc, bw, bh = parts
left = int((xc - bw / 2) * w)
top = int((yc - bh / 2) * h)
width = int(bw * w)
height = int(bh * h)
boxes.append({
"type": "rect",
"left": left,
"top": top,
"width": width,
"height": height,
"stroke": "#00ff00",
"strokeWidth": 3,
"fill": "rgba(0, 255, 0, 0.2)",
"saved": True
})
return boxes, (w, h)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error loading data: {str(e)}")
def save_yolo_annotations(boxes: List[Box], original_size: tuple, label_path: str):
os.makedirs(os.path.dirname(label_path), exist_ok=True)
w, h = original_size
try:
with open(label_path, "w") as f:
for box in boxes:
left, top, width, height = box.left, box.top, box.width, box.height
xc = (left + width / 2) / w
yc = (top + height / 2) / h
bw = width / w
bh = height / h
f.write(f"{CLASS_ID} {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")
shutil.copy2(label_path, f"{IMAGE_LABEL_ROOT}/{os.path.basename(label_path)}")
return True
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error saving annotations: {str(e)}")
# === API Routes ===
@app.get("/api/annotate/images", response_model=List[ImageInfo])
async def list_all_images():
image_info_list = []
for root, _, files in os.walk(IMAGE_ROOT):
for file in files:
if file.lower().endswith((".jpg", ".jpeg", ".png")):
image_path = os.path.join(root, file)
rel_path = os.path.relpath(image_path, IMAGE_ROOT)
label_path = get_label_path(rel_path)
img = Image.open(image_path)
width, height = img.size
image_info_list.append(ImageInfo(
name=rel_path.replace("\\", "/"),
width=width,
height=height,
has_annotations=os.path.exists(label_path)
))
return image_info_list
@app.get("/api/annotate/image/{image_name:path}")
async def get_image(image_name: str):
image_path = get_image_path(image_name)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
with Image.open(image_path) as img:
if img.mode != "RGB":
img = img.convert("RGB")
buffer = BytesIO()
img.save(buffer, format="JPEG")
img_data = base64.b64encode(buffer.getvalue()).decode()
return {
"image_data": f"data:image/jpeg;base64,{img_data}",
"width": img.width,
"height": img.height
}
@app.get("/api/annotate/annotations/{image_name:path}")
async def get_annotations(image_name: str):
image_path = get_image_path(image_name)
label_path = get_label_path(image_name)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
boxes, (width, height) = load_yolo_boxes(image_path, label_path)
return {
"boxes": boxes,
"original_width": width,
"original_height": height
}
@app.get("/api/annotate/detect_annotations/{image_name:path}")
async def get_annotations(image_name: str):
image_path = get_image_path(image_name)
label_path = get_label_path(image_name)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
boxes, (width, height) = load_yolo_boxes(image_path, label_path, True)
return {
"boxes": boxes,
"original_width": width,
"original_height": height
}
@app.post("/api/annotate/annotations")
async def save_annotations(request: SaveAnnotationsRequest):
label_path = get_label_path(request.image_name)
success = save_yolo_annotations(
request.boxes,
(request.original_width, request.original_height),
label_path
)
return {"message": f"Saved {len(request.boxes)} annotations successfully"}
@app.delete("/api/annotate/annotations/{image_name:path}")
async def delete_annotations(image_name: str):
label_path = get_label_path(image_name)
if os.path.exists(label_path):
os.remove(label_path)
return {"message": "Annotations deleted"}
return {"message": "No annotations to delete"}
@app.get("/api/annotate/annotations/{image_name:path}/download")
async def download_annotations(image_name: str):
label_path = get_label_path(image_name)
if not os.path.exists(label_path):
raise HTTPException(status_code=404, detail="Annotations not found")
return FileResponse(
label_path,
media_type="text/plain",
filename=os.path.basename(label_path)
)
@app.post("/api/annotate/upload")
async def upload_image(file: UploadFile = File(...)):
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
file_path = os.path.join(IMAGE_ROOT, "train", file.filename)
with open(file_path, "wb") as f:
f.write(await file.read())
return {"message": f"Uploaded {file.filename} to train set"}