# deploy issue fix (commit ca23a7f)
from fastapi import APIRouter, HTTPException, UploadFile, File, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from .ws_manager import manager
from pydantic import BaseModel, field_validator
from typing import List
from PIL import Image, UnidentifiedImageError
import os
import base64
from io import BytesIO
import shutil
from typing import List, Optional, Union, Dict, Any
from . import utils
import copy
import traceback
import asyncio
import sys, signal
import psutil
import subprocess
from . import common
import fcntl
from .config import Config, load_config, update_toml_key
app = APIRouter()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a websocket client and hold the connection open until it drops."""
    await manager.connect(websocket)
    try:
        # Drain incoming messages; their content is not used yet.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        print("Client disconnected:", websocket.client)
        manager.disconnect(websocket)
# === Configuration ===
config = load_config()
# Flat directory that mirrors every YOLO label file (also written by save_yolo_annotations).
IMAGE_LABEL_ROOT = os.path.join(config.current_path, "image_labels")
# Default class id for new annotations (single-class setup).
CLASS_ID = 0
# === Pydantic Models ===
class Point(BaseModel):
    """A single 2D vertex of a segmentation polygon, in pixel coordinates."""
    x: float
    y: float
class Box(BaseModel):
    """One annotation: an axis-aligned bounding box or a segmentation polygon."""
    type: str = "bbox"  # "bbox" or "segmentation"
    # Bounding-box fields, pixel coordinates (rounded to ints by the validator below)
    left: Optional[int] = None
    top: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None
    # Segmentation polygon vertices, pixel coordinates
    points: Optional[List[Point]] = None
    # Common display/state fields mirrored from the canvas frontend
    classId: int = CLASS_ID
    stroke: str = "#00ff00"
    strokeWidth: int = 3
    fill: str = "rgba(0, 255, 0, 0.2)"
    saved: bool = True
    @field_validator("left", "top", "width", "height", mode="before")
    def round_floats(cls, v):
        # Clients may send fractional pixel values; coerce to ints before validation.
        return round(v) if v is not None else None
class SaveAnnotationsRequest(BaseModel):
    """Payload for POST /api/annotate/annotations."""
    annotations: List[Box]  # All annotations for the image (bbox and/or segmentation)
    image_name: str  # Relative path like train/image1.jpg
    original_width: int  # Image size used to normalize coordinates
    original_height: int
class ImageInfo(BaseModel):
    """Summary of one source image, returned by the image-list endpoint."""
    name: str  # Relative path like train/image1.jpg
    width: int
    height: int
    has_annotations: bool  # True when a YOLO label file exists for this image
class TrainConfig(BaseModel):
    """Training hyper-parameters exchanged with the train/config endpoints."""
    epoch: int  # Number of training epochs
    batch: int  # Batch size
    imgsz: int  # Training image size
    recreate_dataset: bool  # Rebuild the dataset before training
    resume_train: bool  # Resume from the last checkpoint
# === Helpers ===
def get_image_path(image_name: str) -> str:
    """Resolve a relative image name to its absolute path under the source dir."""
    source_root = config.IMAGE_SOURCE_PATH
    return os.path.join(source_root, image_name)
def get_label_path(image_name: str) -> str:
    """Map an image's relative name to its YOLO label (.txt) file path."""
    stem, _ext = os.path.splitext(image_name)
    return os.path.join(IMAGE_LABEL_ROOT, stem + ".txt")
# === Core Functions ===
def load_yolo_annotations(image_path: str, label_path: str, detect: bool = False):
    """Load bbox and segmentation annotations for one image from a YOLO label file.

    Args:
        image_path: Path of the image (opened only to read its pixel size).
        label_path: Path of the YOLO .txt label file (may not exist yet).
        detect: When True and no label file exists, run the trained YOLO model
            to auto-annotate the image first, then normalize and persist the result.

    Returns:
        (annotations, (width, height)) where annotations is a list of dicts in
        the canvas format produced by parse_yolo_line().

    Raises:
        HTTPException: 500 on any failure while loading or detecting.
    """
    try:
        # Context manager releases the file handle immediately; a bare
        # Image.open() is lazy and would keep the file open (handle leak).
        with Image.open(image_path) as img:
            w, h = img.size
        annotations = []
        # Auto-detect annotations with the trained model when none exist yet.
        normalise = False
        if detect and not os.path.exists(label_path):
            from .yolo_manager import YOLOManager
            with YOLOManager() as yolo_manager:
                weights_path = config.yolo_trained_model_path
                yolo_manager.load_model(weights_path)
                yolo_manager.annotate_images(
                    image_paths=[image_path],
                    output_dir=IMAGE_LABEL_ROOT,
                    save_image=False,
                    label_path=label_path
                )
            normalise = True
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    # parse_yolo_line handles both bbox and segmentation rows
                    # and returns None for malformed lines (previously this
                    # logic was duplicated inline here).
                    annotation = parse_yolo_line(line, w, h)
                    if annotation is not None:
                        annotations.append(annotation)
        if normalise:
            # Freshly detected polygons are normalized, then written back so
            # subsequent loads see the cleaned-up labels.
            annotations = utils.normalize_segmentation(annotations)
            save_yolo_annotations(
                copy.deepcopy(annotations),
                (w, h),
                label_path
            )
        return annotations, (w, h)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading annotations: {str(e)} {traceback.format_exc()}")
def normalize_annotations(annotations: List[Union[Box, dict]]) -> List[Box]:
    """Coerce a mixed list of Box models and plain dicts into Box objects."""
    def _as_box(ann):
        # Pass Box instances through; build a Box from dict payloads.
        if isinstance(ann, Box):
            return ann
        if isinstance(ann, dict):
            return Box(**ann)
        raise TypeError(f"Unsupported annotation type: {type(ann)}")
    return [_as_box(ann) for ann in annotations]
def save_yolo_annotations(annotations: List[Box], original_size: tuple, label_path: str):
    """Write annotations to label_path in YOLO format (bbox and segmentation).

    Coordinates are normalized by the original image size. A copy of the label
    file is also mirrored directly under IMAGE_LABEL_ROOT.

    Returns:
        True on success.

    Raises:
        HTTPException: 500 on any write/copy failure.
    """
    annotations = normalize_annotations(annotations)
    os.makedirs(os.path.dirname(label_path), exist_ok=True)
    w, h = original_size
    try:
        with open(label_path, "w") as f:
            for annotation in annotations:
                if annotation.type == "bbox":
                    left, top, width, height = annotation.left, annotation.top, annotation.width, annotation.height
                    # YOLO bbox line: class x_center y_center width height (normalized)
                    xc = (left + width / 2) / w
                    yc = (top + height / 2) / h
                    bw = width / w
                    bh = height / h
                    f.write(f"{annotation.classId} {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")
                elif annotation.type == "segmentation" and annotation.points:
                    # YOLO segmentation line: class x1 y1 x2 y2 ... (normalized)
                    normalized_points = []
                    for point in annotation.points:
                        normalized_points.extend([point.x / w, point.y / h])
                    coords_str = " ".join(f"{coord:.6f}" for coord in normalized_points)
                    f.write(f"{annotation.classId} {coords_str}\n")
        # Mirror the label file into the flat image_labels directory. Skip the
        # copy when label_path already IS that destination (get_label_path
        # returns IMAGE_LABEL_ROOT paths, and shutil.copy2 raises SameFileError
        # when source and destination are the same file).
        mirror_path = os.path.join(IMAGE_LABEL_ROOT, os.path.basename(label_path))
        if os.path.abspath(mirror_path) != os.path.abspath(label_path):
            shutil.copy2(label_path, mirror_path)
        return True
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving annotations: {str(e)} {traceback.format_exc()}")
def parse_yolo_line(line: str, image_width: int, image_height: int) -> Dict[str, Any]:
    """Parse one YOLO label line into an annotation dict.

    A 5-number line is a bounding box; an odd-length line of more than 5
    numbers with at least three coordinate pairs is a segmentation polygon.
    Anything else yields None. Normalized coordinates are scaled back to
    pixel space using the supplied image size.
    """
    values = [float(token) for token in line.strip().split()]
    if len(values) < 5:
        return None
    category = int(values[0])
    if len(values) == 5:  # Bounding box
        xc, yc, bw, bh = values[1:]
        return {
            "type": "bbox",
            "left": int((xc - bw / 2) * image_width),
            "top": int((yc - bh / 2) * image_height),
            "width": int(bw * image_width),
            "height": int(bh * image_height),
            "classId": category,
            "stroke": "#00ff00",
            "strokeWidth": 3,
            "fill": "rgba(0, 255, 0, 0.2)",
            "saved": True
        }
    if len(values) % 2 == 1:  # Segmentation: class id followed by x,y pairs
        coords = values[1:]
        if len(coords) >= 6:  # A polygon needs at least 3 points
            points = [
                {"x": px * image_width, "y": py * image_height}
                for px, py in zip(coords[0::2], coords[1::2])
            ]
            return {
                "type": "segmentation",
                "points": points,
                "classId": category,
                "stroke": "#00ff00",
                "strokeWidth": 3,
                "fill": "rgba(0, 255, 0, 0.2)",
                "saved": True
            }
    return None
# === API Routes ===
@app.get("/api/annotate/images", response_model=List[ImageInfo])
async def list_all_images():
image_info_list = []
for root, _, files in os.walk(config.IMAGE_SOURCE_PATH):
for file in sorted(files):
if file.lower().endswith((".jpg", ".jpeg", ".png")):
try:
image_path = os.path.join(root, file)
rel_path = os.path.relpath(image_path, config.IMAGE_SOURCE_PATH)
label_path = get_label_path(rel_path)
img = Image.open(image_path)
width, height = img.size
image_info_list.append(ImageInfo(
name=rel_path.replace("\\", "/"),
width=width,
height=height,
has_annotations=os.path.exists(label_path)
))
except UnidentifiedImageError:
print(f"Cannot identify image file: {image_path}")
return image_info_list
@app.get("/api/annotate/image/{image_name:path}")
async def get_image(image_name: str):
image_path = get_image_path(image_name)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
with Image.open(image_path) as img:
if img.mode != "RGB":
img = img.convert("RGB")
buffer = BytesIO()
img.save(buffer, format="JPEG")
img_data = base64.b64encode(buffer.getvalue()).decode()
return {
"image_data": f"data:image/jpeg;base64,{img_data}",
"width": img.width,
"height": img.height
}
@app.get("/api/annotate/annotations/{image_name:path}")
async def get_annotations(image_name: str):
image_path = get_image_path(image_name)
label_path = get_label_path(image_name)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
annotations, (width, height) = load_yolo_annotations(image_path, label_path)
return {
"annotations": annotations,
"original_width": width,
"original_height": height
}
@app.get("/api/annotate/detect_annotations/{image_name:path}")
async def get_detected_annotations(image_name: str):
image_path = get_image_path(image_name)
label_path = get_label_path(image_name)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail="Image not found")
annotations, (width, height) = load_yolo_annotations(image_path, label_path, True)
return {
"annotations": annotations,
"original_width": width,
"original_height": height
}
@app.post("/api/annotate/annotations")
async def save_annotations(request: SaveAnnotationsRequest):
label_path = get_label_path(request.image_name)
success = save_yolo_annotations(
request.annotations,
(request.original_width, request.original_height),
label_path
)
return {"message": f"Saved {len(request.annotations)} annotations successfully"}
@app.delete("/api/annotate/annotations/{image_name:path}")
async def delete_annotations(image_name: str):
label_path = get_label_path(image_name)
if os.path.exists(label_path):
os.remove(label_path)
return {"message": "Annotations deleted"}
return {"message": "No annotations to delete"}
@app.get("/api/annotate/annotations/{image_name:path}/download")
async def download_annotations(image_name: str):
label_path = get_label_path(image_name)
if not os.path.exists(label_path):
raise HTTPException(status_code=404, detail="Annotations not found")
return FileResponse(
label_path,
media_type="text/plain",
filename=os.path.basename(label_path)
)
@app.post("/api/annotate/upload")
async def upload_image(file: UploadFile = File(...)):
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
file_path = os.path.join(config.IMAGE_SOURCE_PATH, file.filename)
with open(file_path, "wb") as f:
f.write(await file.read())
return {"message": f"Uploaded {file.filename} to train set"}
####################### ----train---- #############################
# Module-level handle on the (single) training subprocess.
current_process = {}
def reset_current_process():
    """Rebind the global process state to an idle state."""
    # "process" holds the subprocess.Popen object while training runs,
    # or None when idle.
    global current_process
    current_process = {
        "process": None
    }
reset_current_process()
# Define a function to handle cleanup on Ctrl+C.
def handle_exit(signal_received, frame):
    # Kill the whole training process group before exiting: the trainer is
    # started with preexec_fn=os.setsid, so it and its children share a group.
    if current_process["process"]:
        os.killpg(os.getpgid(current_process['process'].pid), signal.SIGKILL)
    sys.exit(0)
# Register the signal handler for SIGINT
signal.signal(signal.SIGINT, handle_exit)
@app.get("/api/annotate/train/config")
async def get_config():
return {
"epoch": config.EPOCH,
"imgsz": config.DEFAULT_IMAGE_SIZE,
"batch": config.BATCH,
"resume_train": config.RESUME_TRAIN,
"recreate_dataset": config.RECREATE_DATASET
}
@app.post("/api/annotate/train/config")
async def save_config(request: TrainConfig):
update_toml_key("EPOCH", request.epoch)
update_toml_key("BATCH", request.batch)
update_toml_key("DEFAULT_IMAGE_SIZE", request.imgsz)
update_toml_key("RECREATE_DATASET", request.recreate_dataset)
update_toml_key("RESUME_TRAIN", request.resume_train)
return {'message': 'Config update successfully.', 'status': 'success'}
@app.post("/api/annotate/model_reset")
async def reset_model():
from pathlib import Path
file_path = Path(config.yolo_trained_model_path)
if file_path.exists():
file_path.unlink()
return {'message': 'Model Reseted', 'status': 'success'}
@app.get("/api/annotate/deploy")
async def deploy_model(app_name: str):
from .yolo_manager import YOLOManager
with YOLOManager() as yolo_manager:
yolo_manager.deploy()
return {'message': 'Model Deployed', 'status': 'success'}
@app.get("/api/annotate/train")
async def upload_image():
os.environ['PYTHONUNBUFFERED'] = "1"
# Skip if the training process is already running
if is_process_running("comic_panel_extractor.train"):
return {"status": "ignored", "message": "Training already in progress."}
reset_current_process()
cmd_to_run=""
if config.RECREATE_DATASET:
cmd_to_run = "python -m comic_panel_extractor.create_dataset && "
cmd_to_run += "python -m comic_panel_extractor.train"
async def run_and_stream_output():
process = None
try:
process = subprocess.Popen(
cmd_to_run,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=1,
universal_newlines=True,
preexec_fn=os.setsid,
env={**os.environ, 'PYTHONUNBUFFERED': '1', 'CUDA_LAUNCH_BLOCKING': '1', 'USE_CPU_IF_POSSIBLE': str(common.get_device() == "cpu")}
)
# Set non-blocking I/O
fd = process.stdout.fileno()
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
current_process['process'] = process
# Stream the output and send it via WebSocket in real-time
while True:
try:
output = process.stdout.readline()
if output:
print(output.strip())
print("Active connections:", len(manager.active_connections))
asyncio.create_task(manager.broadcast({
'type': 'command_output',
'data': output.strip()
}))
sys.stdout.flush()
if process.poll() is not None:
break
# Small delay to prevent CPU spinning
await asyncio.sleep(0.01)
except Exception as e:
print(f"Error reading process output: {e}")
break
# Process finished
return_code = process.returncode if process else -1
asyncio.create_task(manager.broadcast({
'type': 'command_finished',
'return_code': return_code
}))
except Exception as e:
print(f"Error in run_and_stream_output: {e}")
asyncio.create_task(manager.broadcast({
'type': 'command_error',
'error': str(e)
}))
finally:
current_process['process'] = None
# Start the command execution in a separate task
asyncio.create_task(run_and_stream_output())
return {"message": "Command started!", "status": "started"}
@app.get("/api/annotate/stopTrain")
async def stop_train():
try:
# Check if there's actually a process to stop
if current_process['process'] is None:
return {'message': 'No command is currently running.', 'status': 'no_process'}
# Check if process has already terminated naturally
if current_process['process'].poll() is not None:
# Process already finished, just clean up
reset_current_process()
return {'message': 'Command has already finished.', 'status': 'already_finished'}
try:
# Get the process group ID before attempting to kill
pgid = os.getpgid(current_process['process'].pid)
# Kill the entire process group
os.killpg(pgid, signal.SIGTERM) # Try SIGTERM first
# Wait a bit for graceful shutdown
await asyncio.sleep(1)
# If still running, force kill
if current_process['process'] and current_process['process'].poll() is None:
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
# Process already dead
print("Process already terminated")
except OSError as e:
# Handle permission errors or other OS-level issues
print(f"Error terminating process: {e}")
# Try to kill just the main process if group kill fails
try:
current_process['process'].terminate()
await asyncio.sleep(0.5)
if current_process['process'].poll() is None:
current_process['process'].kill()
except:
pass
# Always reset the process state
reset_current_process()
# Notify connected clients
await manager.broadcast({
'type': 'command_stopped',
'message': 'Command terminated by user'
})
return {'message': 'Command terminated successfully.', 'status': 'terminated'}
except Exception as e:
print(f"Error in stop_command: {str(e)}")
# Force reset even if there was an error
reset_current_process()
raise HTTPException(status_code=500, detail=f'Error stopping command: {str(e)}')
def is_process_running(name: str) -> bool:
    """Return True when any running process's command line contains *name*."""
    for proc in psutil.process_iter(['cmdline']):
        try:
            # cmdline may be None/empty for kernel threads or zombies.
            argv = proc.info['cmdline'] or []
            if name in " ".join(argv):
                return True
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Process vanished or is not inspectable; keep scanning.
            continue
    return False