jebin2 committed on
Commit
4ecf9fd
·
1 Parent(s): bf00ae9

train config added

Browse files
comic_panel_extractor/annorator_server.py CHANGED
@@ -8,7 +8,6 @@ import os
8
  import base64
9
  from io import BytesIO
10
  import shutil
11
- from .config import Config
12
  from typing import List, Optional, Union, Dict, Any
13
  from . import utils
14
  import copy
@@ -19,6 +18,7 @@ import psutil
19
  import subprocess
20
  from . import common
21
  import fcntl
 
22
 
23
  app = APIRouter()
24
 
@@ -34,7 +34,8 @@ async def websocket_endpoint(websocket: WebSocket):
34
  manager.disconnect(websocket)
35
 
36
  # === Configuration ===
37
- IMAGE_LABEL_ROOT = os.path.join(Config.current_path, "image_labels")
 
38
 
39
  CLASS_ID = 0
40
 
@@ -75,9 +76,16 @@ class ImageInfo(BaseModel):
75
  height: int
76
  has_annotations: bool
77
 
 
 
 
 
 
 
 
78
  # === Helpers ===
79
  def get_image_path(image_name: str) -> str:
80
- return os.path.join(Config.IMAGE_SOURCE_PATH, image_name)
81
 
82
  def get_label_path(image_name: str) -> str:
83
  return os.path.join(IMAGE_LABEL_ROOT, os.path.splitext(image_name)[0] + ".txt")
@@ -95,7 +103,7 @@ def load_yolo_annotations(image_path: str, label_path: str, detect: bool = False
95
  if detect and not os.path.exists(label_path):
96
  from .yolo_manager import YOLOManager
97
  with YOLOManager() as yolo_manager:
98
- weights_path = Config.yolo_trained_model_path
99
  yolo_manager.load_model(weights_path)
100
  yolo_manager.annotate_images(
101
  image_paths=[image_path],
@@ -265,12 +273,12 @@ def parse_yolo_line(line: str, image_width: int, image_height: int) -> Dict[str,
265
  @app.get("/api/annotate/images", response_model=List[ImageInfo])
266
  async def list_all_images():
267
  image_info_list = []
268
- for root, _, files in os.walk(Config.IMAGE_SOURCE_PATH):
269
  for file in sorted(files):
270
  if file.lower().endswith((".jpg", ".jpeg", ".png")):
271
  try:
272
  image_path = os.path.join(root, file)
273
- rel_path = os.path.relpath(image_path, Config.IMAGE_SOURCE_PATH)
274
  label_path = get_label_path(rel_path)
275
 
276
  img = Image.open(image_path)
@@ -370,7 +378,7 @@ async def upload_image(file: UploadFile = File(...)):
370
  if not file.content_type.startswith("image/"):
371
  raise HTTPException(status_code=400, detail="File must be an image")
372
 
373
- file_path = os.path.join(Config.IMAGE_SOURCE_PATH, file.filename)
374
  with open(file_path, "wb") as f:
375
  f.write(await file.read())
376
  return {"message": f"Uploaded {file.filename} to train set"}
@@ -396,16 +404,36 @@ def handle_exit(signal_received, frame):
396
 
397
  # Register the signal handler for SIGINT
398
  signal.signal(signal.SIGINT, handle_exit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
  @app.get("/api/annotate/train")
401
- async def upload_image(recreate_dataset: bool = False):
402
  os.environ['PYTHONUNBUFFERED'] = "1"
403
  # Skip if the training process is already running
404
  if is_process_running("comic_panel_extractor.train"):
405
  return {"status": "ignored", "message": "Training already in progress."}
406
  reset_current_process()
407
  cmd_to_run=""
408
- if recreate_dataset:
409
  cmd_to_run = "python -m comic_panel_extractor.create_dataset && "
410
  cmd_to_run += "python -m comic_panel_extractor.train"
411
 
 
8
  import base64
9
  from io import BytesIO
10
  import shutil
 
11
  from typing import List, Optional, Union, Dict, Any
12
  from . import utils
13
  import copy
 
18
  import subprocess
19
  from . import common
20
  import fcntl
21
+ from .config import load_config, update_toml_key
22
 
23
  app = APIRouter()
24
 
 
34
  manager.disconnect(websocket)
35
 
36
  # === Configuration ===
37
+ config = load_config()
38
+ IMAGE_LABEL_ROOT = os.path.join(config.current_path, "image_labels")
39
 
40
  CLASS_ID = 0
41
 
 
76
  height: int
77
  has_annotations: bool
78
 
79
+ class TrainConfig(BaseModel):
80
+ epoch: int # Number of training epochs
81
+ batch: int
82
+ imgsz: int
83
+ recreate_dataset: bool
84
+ resume_train: bool
85
+
86
  # === Helpers ===
87
  def get_image_path(image_name: str) -> str:
88
+ return os.path.join(config.IMAGE_SOURCE_PATH, image_name)
89
 
90
  def get_label_path(image_name: str) -> str:
91
  return os.path.join(IMAGE_LABEL_ROOT, os.path.splitext(image_name)[0] + ".txt")
 
103
  if detect and not os.path.exists(label_path):
104
  from .yolo_manager import YOLOManager
105
  with YOLOManager() as yolo_manager:
106
+ weights_path = config.yolo_trained_model_path
107
  yolo_manager.load_model(weights_path)
108
  yolo_manager.annotate_images(
109
  image_paths=[image_path],
 
273
  @app.get("/api/annotate/images", response_model=List[ImageInfo])
274
  async def list_all_images():
275
  image_info_list = []
276
+ for root, _, files in os.walk(config.IMAGE_SOURCE_PATH):
277
  for file in sorted(files):
278
  if file.lower().endswith((".jpg", ".jpeg", ".png")):
279
  try:
280
  image_path = os.path.join(root, file)
281
+ rel_path = os.path.relpath(image_path, config.IMAGE_SOURCE_PATH)
282
  label_path = get_label_path(rel_path)
283
 
284
  img = Image.open(image_path)
 
378
  if not file.content_type.startswith("image/"):
379
  raise HTTPException(status_code=400, detail="File must be an image")
380
 
381
+ file_path = os.path.join(config.IMAGE_SOURCE_PATH, file.filename)
382
  with open(file_path, "wb") as f:
383
  f.write(await file.read())
384
  return {"message": f"Uploaded {file.filename} to train set"}
 
404
 
405
  # Register the signal handler for SIGINT
406
  signal.signal(signal.SIGINT, handle_exit)
407
+ @app.get("/api/annotate/train/config")
408
+ async def get_config():
409
+ return {
410
+ "epoch": config.EPOCH,
411
+ "imgsz": config.DEFAULT_IMAGE_SIZE,
412
+ "batch": config.BATCH,
413
+ "resume_train": config.RESUME_TRAIN,
414
+ "recreate_dataset": config.RECREATE_DATASET
415
+ }
416
+
417
+ @app.post("/api/annotate/train/config")
418
+ async def save_config(request: TrainConfig):
419
+ update_toml_key("EPOCH", request.epoch)
420
+ update_toml_key("BATCH", request.batch)
421
+ update_toml_key("DEFAULT_IMAGE_SIZE", request.imgsz)
422
+ update_toml_key("RECREATE_DATASET", request.recreate_dataset)
423
+ update_toml_key("RESUME_TRAIN", request.resume_train)
424
+
425
+ return {'message': 'Config update successfully.', 'status': 'success'}
426
+
427
 
428
  @app.get("/api/annotate/train")
429
+ async def upload_image():
430
  os.environ['PYTHONUNBUFFERED'] = "1"
431
  # Skip if the training process is already running
432
  if is_process_running("comic_panel_extractor.train"):
433
  return {"status": "ignored", "message": "Training already in progress."}
434
  reset_current_process()
435
  cmd_to_run=""
436
+ if config.RECREATE_DATASET:
437
  cmd_to_run = "python -m comic_panel_extractor.create_dataset && "
438
  cmd_to_run += "python -m comic_panel_extractor.train"
439
 
comic_panel_extractor/config.py CHANGED
@@ -1,62 +1,125 @@
1
  from dataclasses import dataclass
2
  import os
3
  import toml
4
-
5
  from dotenv import load_dotenv
 
6
  load_dotenv()
7
 
8
  CURRENT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__)))
9
- CONFIG_FILE = f"{CURRENT_PATH}/config.toml"
10
 
11
- # Load TOML config
12
- if os.path.exists(CONFIG_FILE):
13
- config_data = toml.load(CONFIG_FILE)
14
- else:
15
- raise FileNotFoundError(f"Config file not found: {CONFIG_FILE}")
16
 
17
  @dataclass
18
  class Config:
19
- """Configuration settings for the comic-to-video pipeline."""
20
- current_path: str = CURRENT_PATH
21
-
22
- # Read from TOML config
23
- EPOCH: int = int(config_data.get("EPOCH", 200))
24
- DEFAULT_IMAGE_SIZE: int = int(config_data.get("DEFAULT_IMAGE_SIZE", 640))
25
- BATCH: int = int(config_data.get("BATCH", 10))
26
- RESUME_TRAIN: bool = str(config_data.get("RESUME_TRAIN", "True")).lower() in ("1", "true", "yes")
27
- YOLO_BASE_MODEL_NAME: str = config_data.get("YOLO_BASE_MODEL_NAME", "yolo11s-seg")
28
- YOLO_MODEL_NAME: str = config_data.get("YOLO_MODEL_NAME", f"comic_panel_{YOLO_BASE_MODEL_NAME}")
29
- image_path_from_config = config_data.get("IMAGE_SOURCE_PATH", "")
30
- # Ensure absolute path
31
- IMAGE_SOURCE_PATH: str = (
32
- image_path_from_config
33
- if os.path.isabs(image_path_from_config)
34
- else os.path.join(CURRENT_PATH, image_path_from_config)
35
- )
36
-
37
- # Derived paths
38
- yolo_base_model_path: str = f"{current_path}/{YOLO_BASE_MODEL_NAME}.pt"
39
- yolo_trained_model_path: str = f"{current_path}/{YOLO_MODEL_NAME}.pt"
40
-
41
- # Other parameters
42
- org_input_path: str = ""
43
- input_path: str = ""
44
- black_overlay_input_path: str = ""
45
- output_folder: str = "temp_dir"
46
- distance_threshold: int = 70
47
- vertical_threshold: int = 30
48
- text_cood_file_name: str = "detect_and_group_text.json"
49
- min_text_length: int = 2
50
- min_area_ratio: float = 0.05
51
- min_width_ratio: float = 0.15
52
- min_height_ratio: float = 0.15
53
-
54
- # Additional parameters for BorderPanelExtractor
55
- panel_filename_pattern: str = r"panel_\d+_\((\d+), (\d+), (\d+), (\d+)\)\.jpg"
56
-
57
- # Static constants
58
- SUPPORTED_EXTENSIONS: list = ('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')
59
-
60
- def get_text_cood_file_path(config: Config):
61
- """Return full path to text coordinate file."""
62
- return f"{config.output_folder}/{config.text_cood_file_name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from dataclasses import dataclass
2
  import os
3
  import toml
 
4
  from dotenv import load_dotenv
5
+
6
  load_dotenv()
7
 
8
  CURRENT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__)))
9
+ CONFIG_FILE = os.path.join(CURRENT_PATH, "config.toml")
10
 
 
 
 
 
 
11
 
12
  @dataclass
13
  class Config:
14
+ """Configuration settings for the comic-to-video pipeline."""
15
+
16
+ # Paths
17
+ current_path: str = CURRENT_PATH
18
+ config_path: str = CONFIG_FILE
19
+
20
+ # Core settings
21
+ EPOCH: int = 200
22
+ DEFAULT_IMAGE_SIZE: int = 640
23
+ BATCH: int = 10
24
+ RESUME_TRAIN: bool = True
25
+ RECREATE_DATASET: bool = True
26
+
27
+ # YOLO models
28
+ YOLO_BASE_MODEL_NAME: str = "yolo11s-seg"
29
+ YOLO_MODEL_NAME: str = "" # will be derived if empty
30
+ IMAGE_SOURCE_PATH: str = ""
31
+
32
+ # Derived paths
33
+ yolo_base_model_path: str = ""
34
+ yolo_trained_model_path: str = ""
35
+
36
+ # Pipeline parameters
37
+ org_input_path: str = ""
38
+ input_path: str = ""
39
+ black_overlay_input_path: str = ""
40
+ output_folder: str = "temp_dir"
41
+ distance_threshold: int = 70
42
+ vertical_threshold: int = 30
43
+ text_cood_file_name: str = "detect_and_group_text.json"
44
+ min_text_length: int = 2
45
+ min_area_ratio: float = 0.05
46
+ min_width_ratio: float = 0.15
47
+ min_height_ratio: float = 0.15
48
+
49
+ # BorderPanelExtractor
50
+ panel_filename_pattern: str = r"panel_\d+_\((\d+), (\d+), (\d+), (\d+)\)\.jpg"
51
+
52
+ # Constants
53
+ SUPPORTED_EXTENSIONS: tuple = ('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')
54
+
55
+ def __post_init__(self):
56
+ # Ensure absolute IMAGE_SOURCE_PATH
57
+ if self.IMAGE_SOURCE_PATH:
58
+ if not os.path.isabs(self.IMAGE_SOURCE_PATH):
59
+ self.IMAGE_SOURCE_PATH = os.path.join(self.current_path, self.IMAGE_SOURCE_PATH)
60
+
61
+ # Derive YOLO_MODEL_NAME if empty
62
+ if not self.YOLO_MODEL_NAME:
63
+ self.YOLO_MODEL_NAME = f"comic_panel_{self.YOLO_BASE_MODEL_NAME}"
64
+
65
+ # Derived paths
66
+ self.yolo_base_model_path = os.path.join(self.current_path, f"{self.YOLO_BASE_MODEL_NAME}.pt")
67
+ self.yolo_trained_model_path = os.path.join(self.current_path, f"{self.YOLO_MODEL_NAME}.pt")
68
+
69
+
70
+ def load_config(file_path=CONFIG_FILE) -> Config:
71
+ """Load the latest config from TOML file and return a Config instance."""
72
+ if not os.path.exists(file_path):
73
+ raise FileNotFoundError(f"Config file not found: {file_path}")
74
+
75
+ data = toml.load(file_path)
76
+
77
+ # Convert boolean strings to actual bool
78
+ def to_bool(val):
79
+ if isinstance(val, bool):
80
+ return val
81
+ return str(val).lower() in ("1", "true", "yes")
82
+
83
+ return Config(
84
+ EPOCH=int(data.get("EPOCH", 200)),
85
+ DEFAULT_IMAGE_SIZE=int(data.get("DEFAULT_IMAGE_SIZE", 640)),
86
+ BATCH=int(data.get("BATCH", 10)),
87
+ RESUME_TRAIN=to_bool(data.get("RESUME_TRAIN", True)),
88
+ RECREATE_DATASET=to_bool(data.get("RECREATE_DATASET", True)),
89
+ YOLO_BASE_MODEL_NAME=data.get("YOLO_BASE_MODEL_NAME", "yolo11s-seg"),
90
+ YOLO_MODEL_NAME=data.get("YOLO_MODEL_NAME", ""), # derived in __post_init__
91
+ IMAGE_SOURCE_PATH=data.get("IMAGE_SOURCE_PATH", "")
92
+ )
93
+
94
+
95
+ def update_toml_key(key: str, value, file_path=CONFIG_FILE) -> Config:
96
+ """Update a key in the TOML file and reload config."""
97
+ if not os.path.exists(file_path):
98
+ raise FileNotFoundError(f"Config file not found: {file_path}")
99
+
100
+ data = toml.load(file_path)
101
+ data[key] = value
102
+ with open(file_path, "w") as f:
103
+ toml.dump(data, f)
104
+
105
+ # Reload and return new Config
106
+ return load_config(file_path)
107
+
108
+
109
+ def get_text_cood_file_path(config: Config) -> str:
110
+ """Return full path to text coordinate file."""
111
+ return os.path.join(config.output_folder, config.text_cood_file_name)
112
+
113
+
114
+ # Example usage:
115
+ if __name__ == "__main__":
116
+ # Load config
117
+ config = load_config()
118
+ print("EPOCH:", config.EPOCH)
119
+
120
+ # Update TOML key and reload
121
+ config = update_toml_key("EPOCH", 500)
122
+ print("Updated EPOCH:", config.EPOCH)
123
+
124
+ # Get text coord file path
125
+ print("Text coord path:", get_text_cood_file_path(config))
comic_panel_extractor/config.toml CHANGED
@@ -1,7 +1,8 @@
1
- EPOCH=200
2
- DEFAULT_IMAGE_SIZE=640
3
- BATCH=10
4
- RESUME_TRAIN="true"
5
- YOLO_BASE_MODEL_NAME="yolo11s-seg"
6
- YOLO_MODEL_NAME="comic_panel_yolo11s-seg"
7
- IMAGE_SOURCE_PATH="images"
 
 
1
+ EPOCH = 200
2
+ DEFAULT_IMAGE_SIZE = 640
3
+ BATCH = 10
4
+ RESUME_TRAIN = true
5
+ RECREATE_DATASET = true
6
+ YOLO_BASE_MODEL_NAME = "yolo11s-seg"
7
+ YOLO_MODEL_NAME = "comic_panel_yolo11s-seg"
8
+ IMAGE_SOURCE_PATH = "images"
comic_panel_extractor/config.toml.bak ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ EPOCH=200
2
+ DEFAULT_IMAGE_SIZE=640
3
+ BATCH=10
4
+ RESUME_TRAIN="true"
5
+ YOLO_BASE_MODEL_NAME="yolo11s-seg"
6
+ YOLO_MODEL_NAME="comic_panel_yolo11s-seg"
7
+ IMAGE_SOURCE_PATH="images"
comic_panel_extractor/create_dataset.py CHANGED
@@ -4,10 +4,11 @@ import random
4
  from pathlib import Path
5
  from dotenv import load_dotenv
6
  from tqdm import tqdm
7
- from .config import Config
8
 
9
  load_dotenv()
10
- SOURCE_PATHS = Config.IMAGE_SOURCE_PATH
 
11
 
12
  if not SOURCE_PATHS:
13
  raise ValueError("SOURCE_PATH not set")
@@ -15,8 +16,8 @@ if not SOURCE_PATHS:
15
  # Split by comma and strip whitespace
16
  source_paths = [Path(p.strip()) for p in SOURCE_PATHS.split(',')]
17
 
18
- images_dir = Path(f'{Config.current_path}/images')
19
- dataset_dir = Path(f'{Config.current_path}/dataset')
20
 
21
  image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
22
  label_exts = {'.txt'}
@@ -72,7 +73,7 @@ splits = {
72
  'test': all_images[val_end:]
73
  }
74
 
75
- label_src_dir = Path(f'{Config.current_path}/image_labels')
76
 
77
  # Move/copy images and labels to their split folders with tqdm
78
  for split, files in splits.items():
 
4
  from pathlib import Path
5
  from dotenv import load_dotenv
6
  from tqdm import tqdm
7
+ from .config import load_config
8
 
9
  load_dotenv()
10
+ config = load_config()
11
+ SOURCE_PATHS = config.IMAGE_SOURCE_PATH
12
 
13
  if not SOURCE_PATHS:
14
  raise ValueError("SOURCE_PATH not set")
 
16
  # Split by comma and strip whitespace
17
  source_paths = [Path(p.strip()) for p in SOURCE_PATHS.split(',')]
18
 
19
+ images_dir = Path(f'{config.current_path}/images')
20
+ dataset_dir = Path(f'{config.current_path}/dataset')
21
 
22
  image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
23
  label_exts = {'.txt'}
 
73
  'test': all_images[val_end:]
74
  }
75
 
76
+ label_src_dir = Path(f'{config.current_path}/image_labels')
77
 
78
  # Move/copy images and labels to their split folders with tqdm
79
  for split, files in splits.items():
comic_panel_extractor/extractor_server.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import APIRouter, File, UploadFile, HTTPException
2
  from fastapi.responses import FileResponse
3
  import os
4
- from .config import Config
5
  from .main import ComicPanelExtractor
6
  import traceback
7
  from pathlib import Path
@@ -9,8 +9,10 @@ import shutil
9
  import time
10
  import mimetypes
11
 
 
 
12
  base_output_folder = "api_outputs"
13
- output_folder = os.path.join(Config.current_path, base_output_folder)
14
 
15
  app = APIRouter()
16
 
 
1
  from fastapi import APIRouter, File, UploadFile, HTTPException
2
  from fastapi.responses import FileResponse
3
  import os
4
+ from .config import load_config
5
  from .main import ComicPanelExtractor
6
  import traceback
7
  from pathlib import Path
 
9
  import time
10
  import mimetypes
11
 
12
+ config = load_config()
13
+
14
  base_output_folder = "api_outputs"
15
+ output_folder = os.path.join(config.current_path, base_output_folder)
16
 
17
  app = APIRouter()
18
 
comic_panel_extractor/inference.py CHANGED
@@ -2,7 +2,9 @@
2
  from .yolo_manager import YOLOManager
3
  from .utils import get_abs_path, get_image_paths
4
  import os
5
- from .config import Config
 
 
6
 
7
  def run_inference(weights_path: str, images_dirs, output_dir: str = 'temp_dir') -> None:
8
  """
@@ -41,7 +43,7 @@ def run_inference(weights_path: str, images_dirs, output_dir: str = 'temp_dir')
41
 
42
  def main():
43
  """Main inference function."""
44
- weights_path = Config.yolo_trained_model_path
45
  images_dirs = [
46
  './dataset/images/train',
47
  './dataset/images/val',
@@ -52,10 +54,10 @@ def main():
52
 
53
  def annotate_all_image():
54
  with YOLOManager() as yolo_manager:
55
- weights_path = Config.yolo_trained_model_path
56
  yolo_manager.load_model(weights_path)
57
- IMAGE_ROOT = os.path.join(Config.current_path, "dataset/images")
58
- IMAGE_LABEL_ROOT = os.path.join(Config.current_path, "image_labels")
59
  for root, _, files in os.walk(IMAGE_ROOT):
60
  for file in sorted(files):
61
  if file.lower().endswith((".jpg", ".jpeg", ".png")):
 
2
  from .yolo_manager import YOLOManager
3
  from .utils import get_abs_path, get_image_paths
4
  import os
5
+ from .config import load_config
6
+
7
+ config = load_config()
8
 
9
  def run_inference(weights_path: str, images_dirs, output_dir: str = 'temp_dir') -> None:
10
  """
 
43
 
44
  def main():
45
  """Main inference function."""
46
+ weights_path = config.yolo_trained_model_path
47
  images_dirs = [
48
  './dataset/images/train',
49
  './dataset/images/val',
 
54
 
55
  def annotate_all_image():
56
  with YOLOManager() as yolo_manager:
57
+ weights_path = config.yolo_trained_model_path
58
  yolo_manager.load_model(weights_path)
59
+ IMAGE_ROOT = os.path.join(config.current_path, "dataset/images")
60
+ IMAGE_LABEL_ROOT = os.path.join(config.current_path, "image_labels")
61
  for root, _, files in os.walk(IMAGE_ROOT):
62
  for file in sorted(files):
63
  if file.lower().endswith((".jpg", ".jpeg", ".png")):
comic_panel_extractor/server.py CHANGED
@@ -4,7 +4,7 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from .extractor_server import app as extractor_app, delete_folder_if_old_or_empty, output_folder
5
  from .annorator_server import app as annotator_app
6
  import os, json
7
- from .config import Config
8
 
9
  from fastapi import Request
10
  from fastapi.responses import HTMLResponse
@@ -13,9 +13,10 @@ import os
13
  from jinja2 import Environment, FileSystemLoader, select_autoescape
14
 
15
  fast_api = FastAPI()
 
16
 
17
  # Mount static files ONCE
18
- static_folder = os.path.join(Config.current_path, "static")
19
  fast_api.mount("/static", StaticFiles(directory=static_folder), name="static")
20
 
21
  fast_api.include_router(extractor_app)
 
4
  from .extractor_server import app as extractor_app, delete_folder_if_old_or_empty, output_folder
5
  from .annorator_server import app as annotator_app
6
  import os, json
7
+ from .config import load_config
8
 
9
  from fastapi import Request
10
  from fastapi.responses import HTMLResponse
 
13
  from jinja2 import Environment, FileSystemLoader, select_autoescape
14
 
15
  fast_api = FastAPI()
16
+ config = load_config()
17
 
18
  # Mount static files ONCE
19
+ static_folder = os.path.join(config.current_path, "static")
20
  fast_api.mount("/static", StaticFiles(directory=static_folder), name="static")
21
 
22
  fast_api.include_router(extractor_app)
comic_panel_extractor/static/annotator.html CHANGED
@@ -790,9 +790,14 @@
790
  <!-- Quick Actions -->
791
  <div class="sidebar-section">
792
  <div class="section-title">Actions</div>
793
- <button class="btn btn-primary btn-sm trainBtn" id="trainBtn">
794
- Train
795
- </button>
 
 
 
 
 
796
  <button class="btn btn-primary btn-sm trainBtn" id="deployModalBtn">
797
  Deploy Model
798
  </button>
@@ -860,6 +865,52 @@
860
  <!-- Alerts Container -->
861
  <div class="alerts" id="alerts"></div>
862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
 
864
  <div id="outputModal" class="modal">
865
  <div class="modal-content" style="max-width: none; margin: auto;">
@@ -953,6 +1004,7 @@
953
  init() {
954
  this.setupEventListeners();
955
  this.loadImages();
 
956
  }
957
 
958
  setupEventListeners() {
@@ -1032,7 +1084,7 @@
1032
  document.getElementById('trainBtn').addEventListener('click', async (e) => {
1033
  try {
1034
  this.openXterm();
1035
- const response = await fetch('/api/annotate/train?recreate_dataset=true');
1036
 
1037
  if (!response.ok) {
1038
  throw new Error(`Server error: ${response.status}`);
@@ -1076,8 +1128,124 @@
1076
  }
1077
  }
1078
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1079
  }
1080
 
 
1081
  updateCanvasCursor() {
1082
  if (this.annotationMode === 'segmentation') {
1083
  this.canvas.style.cursor = 'crosshair';
 
790
  <!-- Quick Actions -->
791
  <div class="sidebar-section">
792
  <div class="section-title">Actions</div>
793
+ <div style="display: flex; gap: 8px;">
794
+ <button class="btn btn-primary btn-sm trainBtn" id="trainBtn" style="flex: 1;">
795
+ Train
796
+ </button>
797
+ <button class="btn btn-ghost btn-sm" id="settingsBtn" style="padding: 6px 8px; font-size: 11px;">
798
+ ⚙️
799
+ </button>
800
+ </div>
801
  <button class="btn btn-primary btn-sm trainBtn" id="deployModalBtn">
802
  Deploy Model
803
  </button>
 
865
  <!-- Alerts Container -->
866
  <div class="alerts" id="alerts"></div>
867
 
868
+ <!-- Settings Modal -->
869
+ <div id="settingsModal" class="modal">
870
+ <div class="modal-content">
871
+ <span class="close" id="closeSettingsModal">×</span>
872
+ <h2>Train Settings</h2>
873
+ <div class="form-field">
874
+ <label class="form-label">Epoch</label>
875
+ <input type="number" class="form-input" id="epoch" value="10" min="1">
876
+ </div>
877
+ <div class="form-field">
878
+ <label class="form-label">Batch Size</label>
879
+ <input type="number" class="form-input" id="batch" value="8" min="1">
880
+ </div>
881
+ <div class="form-field">
882
+ <label class="form-label">Image Size</label>
883
+ <input type="number" class="form-input" id="imgsz" value="640" min="1">
884
+ </div>
885
+ <div class="form-field">
886
+ <label class="form-label">
887
+ <input type="checkbox" id="recreateDataset" checked>
888
+ Recreate Dataset
889
+ </label>
890
+ </div>
891
+ <div class="form-field">
892
+ <label class="form-label">
893
+ <input type="checkbox" id="resumeTrain" checked>
894
+ Resume Train
895
+ </label>
896
+ </div>
897
+ <button class="btn btn-primary" id="saveSettingsBtn">Save</button>
898
+ </div>
899
+ </div>
900
+
901
+ <!-- Deploy Modal -->
902
+ <div id="deployModal" class="modal">
903
+ <div class="modal-content">
904
+ <span class="close" id="closeDeployModal">×</span>
905
+ <h2>Deploy Model</h2>
906
+ <div class="form-field">
907
+ <label class="form-label">App Name</label>
908
+ <input type="text" class="form-input" id="appName" placeholder="Enter a unique app name">
909
+ </div>
910
+ <button class="btn btn-primary" id="deployBtn">Deploy</button>
911
+ </div>
912
+ </div>
913
+
914
 
915
  <div id="outputModal" class="modal">
916
  <div class="modal-content" style="max-width: none; margin: auto;">
 
1004
  init() {
1005
  this.setupEventListeners();
1006
  this.loadImages();
1007
+ this.loadTrainConfig();
1008
  }
1009
 
1010
  setupEventListeners() {
 
1084
  document.getElementById('trainBtn').addEventListener('click', async (e) => {
1085
  try {
1086
  this.openXterm();
1087
+ const response = await fetch('/api/annotate/train');
1088
 
1089
  if (!response.ok) {
1090
  throw new Error(`Server error: ${response.status}`);
 
1128
  }
1129
  }
1130
  });
1131
+
1132
+ // Settings Modal
1133
+ document.getElementById('settingsBtn').addEventListener('click', () => {
1134
+ document.getElementById('settingsModal').style.display = 'block';
1135
+ });
1136
+
1137
+ document.getElementById('closeSettingsModal').addEventListener('click', () => {
1138
+ document.getElementById('settingsModal').style.display = 'none';
1139
+ });
1140
+
1141
+ document.getElementById('saveSettingsBtn').addEventListener('click', async () => {
1142
+ const newSettings = {
1143
+ epoch: parseInt(document.getElementById('epoch').value) || 200,
1144
+ batch: parseInt(document.getElementById('batch').value) || 10,
1145
+ imgsz: parseInt(document.getElementById('imgsz').value) || 640,
1146
+ recreate_dataset: document.getElementById('recreateDataset').checked,
1147
+ resume_train: document.getElementById('resumeTrain').checked
1148
+ };
1149
+
1150
+ try {
1151
+ const response = await fetch('/api/annotate/train/config', {
1152
+ method: 'POST',
1153
+ headers: {
1154
+ 'Content-Type': 'application/json',
1155
+ },
1156
+ body: JSON.stringify(newSettings)
1157
+ });
1158
+
1159
+ if (!response.ok) {
1160
+ const errorData = await response.json();
1161
+ throw new Error(errorData.message || 'Failed to save settings');
1162
+ }
1163
+
1164
+ const result = await response.json();
1165
+ this.trainSettings = newSettings;
1166
+ document.getElementById('settingsModal').style.display = 'none';
1167
+ this.showAlert(result.message || 'Settings saved successfully', 'success');
1168
+
1169
+ } catch (error) {
1170
+ console.error('Error saving settings:', error);
1171
+ this.showAlert('Error: ' + error.message, 'error');
1172
+ }
1173
+ });
1174
+
1175
+
1176
+ // Deploy Modal
1177
+ document.getElementById('deployModalBtn').addEventListener('click', () => {
1178
+ document.getElementById('deployModal').style.display = 'block';
1179
+ });
1180
+
1181
+ document.getElementById('closeDeployModal').addEventListener('click', () => {
1182
+ document.getElementById('deployModal').style.display = 'none';
1183
+ });
1184
+
1185
+ document.getElementById('deployBtn').addEventListener('click', async () => {
1186
+ const appName = document.getElementById('appName').value;
1187
+ if (!appName) {
1188
+ this.showAlert('Please enter an app name', 'error');
1189
+ return;
1190
+ }
1191
+
1192
+ try {
1193
+ this.openXterm();
1194
+ const response = await fetch(`/api/annotate/deploy?app_name=${appName}`);
1195
+ if (!response.ok) {
1196
+ throw new Error(`Server error: ${response.status}`);
1197
+ }
1198
+ const result = await response.json();
1199
+ this.showAlert(result.message, 'success');
1200
+ } catch (error) {
1201
+ if (term) {
1202
+ term.write(`\x1b[31m[Error starting command: ${error.message}]\x1b[0m\r\n`);
1203
+ } else {
1204
+ this.showAlert('Error starting command: ' + error.message, 'error');
1205
+ }
1206
+ }
1207
+ document.getElementById('deployModal').style.display = 'none';
1208
+ });
1209
+
1210
+ // Reset Model
1211
+ document.getElementById('resetModalBtn').addEventListener('click', async () => {
1212
+ if (confirm('Are you sure you want to reset the model? This action cannot be undone.')) {
1213
+ try {
1214
+ this.openXterm();
1215
+ const response = await fetch('/api/annotate/model_reset', { method: 'POST' });
1216
+ if (!response.ok) {
1217
+ throw new Error(`Server error: ${response.status}`);
1218
+ }
1219
+ const result = await response.json();
1220
+ this.showAlert(result.message, 'success');
1221
+ } catch (error) {
1222
+ if (term) {
1223
+ term.write(`\x1b[31m[Error starting command: ${error.message}]\x1b[0m\r\n`);
1224
+ } else {
1225
+ this.showAlert('Error starting command: ' + error.message, 'error');
1226
+ }
1227
+ }
1228
+ }
1229
+ });
1230
+
1231
+ }
1232
+
1233
+ async loadTrainConfig() {
1234
+ try {
1235
+ const response = await fetch('/api/annotate/train/config');
1236
+ const config = await response.json();
1237
+ this.trainSettings = config;
1238
+ document.getElementById('epoch').value = config.epoch;
1239
+ document.getElementById('batch').value = config.batch;
1240
+ document.getElementById('imgsz').value = config.imgsz;
1241
+ document.getElementById('recreateDataset').checked = config.recreate_dataset;
1242
+ document.getElementById('resumeTrain').checked = config.resume_train;
1243
+ } catch (error) {
1244
+ this.showAlert('Error loading training config: ' + error.message, 'error');
1245
+ }
1246
  }
1247
 
1248
+
1249
  updateCanvasCursor() {
1250
  if (this.annotationMode === 'segmentation') {
1251
  this.canvas.style.cursor = 'crosshair';
comic_panel_extractor/train.py CHANGED
@@ -2,12 +2,14 @@
2
  from .yolo_manager import YOLOManager
3
  from .utils import get_abs_path, backup_file
4
  import os
5
- from .config import Config
6
  import yaml
7
  import os
8
  from pathlib import Path
9
  import shutil
10
 
 
 
11
  def convert_box_to_polygon(label_file: Path):
12
  """
13
  Converts YOLO box-format labels (class xc yc w h) to YOLO polygon-format labels
@@ -138,7 +140,7 @@ def create_filtered_yaml(output_filtered_dataset_path, filtered_counts):
138
  Create the YAML file for the filtered dataset
139
  """
140
  output_path = Path(output_filtered_dataset_path)
141
- yaml_path = f'{Config.current_path}/filtered_comic.yaml'
142
 
143
  # Create YAML structure
144
  yaml_data = {
@@ -167,17 +169,17 @@ def main():
167
  yolo_manager = YOLOManager()
168
 
169
  # Configuration
170
- data_yaml_path = f'{Config.current_path}/filtered_comic.yaml'
171
 
172
  if not os.path.isfile(data_yaml_path):
173
  raise FileNotFoundError(f"❌ Dataset YAML not found: {data_yaml_path}")
174
 
175
- print(f"🎯 Training model: {Config.YOLO_MODEL_NAME}")
176
 
177
  # Train model
178
  model = yolo_manager.train(
179
  data_yaml_path=data_yaml_path,
180
- run_name=Config.YOLO_MODEL_NAME
181
  )
182
 
183
  # Validate model
@@ -185,7 +187,7 @@ def main():
185
 
186
  # Backup best weights
187
  weights_path = yolo_manager.get_best_weights_path()
188
- backup_path = Config.yolo_trained_model_path
189
  backup_file(weights_path, backup_path)
190
 
191
  print("🎉 Training completed successfully!")
 
2
  from .yolo_manager import YOLOManager
3
  from .utils import get_abs_path, backup_file
4
  import os
5
+ from .config import load_config
6
  import yaml
7
  import os
8
  from pathlib import Path
9
  import shutil
10
 
11
+ config = load_config()
12
+
13
  def convert_box_to_polygon(label_file: Path):
14
  """
15
  Converts YOLO box-format labels (class xc yc w h) to YOLO polygon-format labels
 
140
  Create the YAML file for the filtered dataset
141
  """
142
  output_path = Path(output_filtered_dataset_path)
143
+ yaml_path = f'{config.current_path}/filtered_comic.yaml'
144
 
145
  # Create YAML structure
146
  yaml_data = {
 
169
  yolo_manager = YOLOManager()
170
 
171
  # Configuration
172
+ data_yaml_path = f'{config.current_path}/filtered_comic.yaml'
173
 
174
  if not os.path.isfile(data_yaml_path):
175
  raise FileNotFoundError(f"❌ Dataset YAML not found: {data_yaml_path}")
176
 
177
+ print(f"🎯 Training model: {config.YOLO_MODEL_NAME}")
178
 
179
  # Train model
180
  model = yolo_manager.train(
181
  data_yaml_path=data_yaml_path,
182
+ run_name=config.YOLO_MODEL_NAME
183
  )
184
 
185
  # Validate model
 
187
 
188
  # Backup best weights
189
  weights_path = yolo_manager.get_best_weights_path()
190
+ backup_path = config.yolo_trained_model_path
191
  backup_file(weights_path, backup_path)
192
 
193
  print("🎉 Training completed successfully!")
comic_panel_extractor/utils.py CHANGED
@@ -7,9 +7,11 @@ import os
7
  import shutil
8
  from glob import glob
9
  from typing import List, Union
10
- from .config import Config
11
  from shapely.geometry import Polygon
12
 
 
 
13
  def remove_duplicate_boxes(boxes, compare_single=None, iou_threshold=0.7):
14
  """
15
  Removes duplicate or highly overlapping boxes, keeping the larger one.
@@ -508,7 +510,7 @@ def get_image_paths(directories: Union[str, List[str]]) -> List[str]:
508
  continue
509
 
510
  # Support multiple image extensions
511
- for ext in Config.SUPPORTED_EXTENSIONS:
512
  pattern = os.path.join(abs_dir, f'*.{ext}')
513
  images = sorted(glob(pattern))
514
  all_images.extend(images)
 
7
  import shutil
8
  from glob import glob
9
  from typing import List, Union
10
+ from .config import load_config
11
  from shapely.geometry import Polygon
12
 
13
+ config = load_config()
14
+
15
  def remove_duplicate_boxes(boxes, compare_single=None, iou_threshold=0.7):
16
  """
17
  Removes duplicate or highly overlapping boxes, keeping the larger one.
 
510
  continue
511
 
512
  # Support multiple image extensions
513
+ for ext in config.SUPPORTED_EXTENSIONS:
514
  pattern = os.path.join(abs_dir, f'*.{ext}')
515
  images = sorted(glob(pattern))
516
  all_images.extend(images)
comic_panel_extractor/yolo_manager.py CHANGED
@@ -4,6 +4,8 @@ import shutil
4
  from glob import glob
5
  from typing import List, Union
6
  from . import utils
 
 
7
 
8
  os.environ["TORCH_USE_CUDA_DSA"] = "1"
9
  os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@@ -33,7 +35,7 @@ def get_image_paths(directories: Union[str, List[str]]) -> List[str]:
33
  continue
34
 
35
  # Support multiple image extensions
36
- for ext in Config.SUPPORTED_EXTENSIONS:
37
  pattern = os.path.join(abs_dir, f'*.{ext}')
38
  images = sorted(glob(pattern))
39
  all_images.extend(images)
@@ -62,7 +64,7 @@ class YOLOManager:
62
  """Manages YOLO model training and inference operations."""
63
 
64
  def __init__(self, model_name: Optional[str] = None):
65
- self.model_name = model_name or Config.YOLO_MODEL_NAME
66
  self.model = None
67
 
68
  def load_model(self, weights_path: Optional[str] = None) -> YOLO:
@@ -71,15 +73,15 @@ class YOLOManager:
71
  print(f"📦 Loading model from: {weights_path}")
72
  self.model = YOLO(weights_path)
73
  else:
74
- print(f"✨ Loading pretrained model '{Config.yolo_base_model_path}'")
75
- self.model = YOLO(f"{Config.yolo_base_model_path}")
76
  return self.model
77
 
78
  def train(self,
79
  data_yaml_path: str,
80
  run_name: Optional[str] = None,
81
  device: int = 0,
82
- resume: bool = Config.RESUME_TRAIN,
83
  **kwargs) -> YOLO:
84
  """
85
  Train YOLO model with given parameters.
@@ -92,7 +94,7 @@ class YOLOManager:
92
  **kwargs: Additional training parameters
93
  """
94
  run_name = run_name or self.model_name
95
- checkpoint_path = f"{Config.current_path}/runs/detect/{run_name}/weights/last.pt"
96
 
97
  # Check for existing checkpoint
98
  if resume and os.path.isfile(checkpoint_path):
@@ -106,13 +108,13 @@ class YOLOManager:
106
  # Default training parameters
107
  train_params = {
108
  'data': data_yaml_path,
109
- 'imgsz': Config.DEFAULT_IMAGE_SIZE,
110
- 'epochs': Config.EPOCH,
111
- 'batch': Config.BATCH,
112
  'name': run_name,
113
  'device': device,
114
  'cache': True,
115
- 'project': f'{Config.current_path}/runs/detect',
116
  'exist_ok': True,
117
  'pose': False,
118
  'resume': resume_flag,
@@ -139,7 +141,7 @@ class YOLOManager:
139
  def get_best_weights_path(self, run_name: Optional[str] = None) -> str:
140
  """Get path to best trained weights."""
141
  run_name = run_name or self.model_name
142
- weights_path = os.path.join(Config.current_path, 'runs', 'detect', run_name, 'weights', 'best.pt')
143
 
144
  if not os.path.isfile(weights_path):
145
  raise FileNotFoundError(f"❌ Trained weights not found at: {weights_path}")
@@ -163,7 +165,7 @@ class YOLOManager:
163
  if not image_paths:
164
  raise ValueError("❌ No images provided for annotation.")
165
 
166
- image_size = image_size or Config.DEFAULT_IMAGE_SIZE
167
  # clean_directory(output_dir)
168
  total_images = len(image_paths)
169
  print(f"🎨 Annotating {total_images} images and saving labels...")
 
4
  from glob import glob
5
  from typing import List, Union
6
  from . import utils
7
+ from .config import load_config
8
+ config = load_config()
9
 
10
  os.environ["TORCH_USE_CUDA_DSA"] = "1"
11
  os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
35
  continue
36
 
37
  # Support multiple image extensions
38
+ for ext in config.SUPPORTED_EXTENSIONS:
39
  pattern = os.path.join(abs_dir, f'*.{ext}')
40
  images = sorted(glob(pattern))
41
  all_images.extend(images)
 
64
  """Manages YOLO model training and inference operations."""
65
 
66
  def __init__(self, model_name: Optional[str] = None):
67
+ self.model_name = model_name or config.YOLO_MODEL_NAME
68
  self.model = None
69
 
70
  def load_model(self, weights_path: Optional[str] = None) -> YOLO:
 
73
  print(f"📦 Loading model from: {weights_path}")
74
  self.model = YOLO(weights_path)
75
  else:
76
+ print(f"✨ Loading pretrained model '{config.yolo_base_model_path}'")
77
+ self.model = YOLO(f"{config.yolo_base_model_path}")
78
  return self.model
79
 
80
  def train(self,
81
  data_yaml_path: str,
82
  run_name: Optional[str] = None,
83
  device: int = 0,
84
+ resume: bool = config.RESUME_TRAIN,
85
  **kwargs) -> YOLO:
86
  """
87
  Train YOLO model with given parameters.
 
94
  **kwargs: Additional training parameters
95
  """
96
  run_name = run_name or self.model_name
97
+ checkpoint_path = f"{config.current_path}/runs/detect/{run_name}/weights/last.pt"
98
 
99
  # Check for existing checkpoint
100
  if resume and os.path.isfile(checkpoint_path):
 
108
  # Default training parameters
109
  train_params = {
110
  'data': data_yaml_path,
111
+ 'imgsz': config.DEFAULT_IMAGE_SIZE,
112
+ 'epochs': config.EPOCH,
113
+ 'batch': config.BATCH,
114
  'name': run_name,
115
  'device': device,
116
  'cache': True,
117
+ 'project': f'{config.current_path}/runs/detect',
118
  'exist_ok': True,
119
  'pose': False,
120
  'resume': resume_flag,
 
141
  def get_best_weights_path(self, run_name: Optional[str] = None) -> str:
142
  """Get path to best trained weights."""
143
  run_name = run_name or self.model_name
144
+ weights_path = os.path.join(config.current_path, 'runs', 'detect', run_name, 'weights', 'best.pt')
145
 
146
  if not os.path.isfile(weights_path):
147
  raise FileNotFoundError(f"❌ Trained weights not found at: {weights_path}")
 
165
  if not image_paths:
166
  raise ValueError("❌ No images provided for annotation.")
167
 
168
+ image_size = image_size or config.DEFAULT_IMAGE_SIZE
169
  # clean_directory(output_dir)
170
  total_images = len(image_paths)
171
  print(f"🎨 Annotating {total_images} images and saving labels...")