AkJeond commited on
Commit
5b217af
·
1 Parent(s): 22d59b7

feat(backend): 문서 타입별 모델 선택 및 프리로드 추가

Browse files
app/main.py CHANGED
@@ -26,6 +26,7 @@ from sqlalchemy.orm import Session
26
  from .database import engine, get_db, init_db, test_connection
27
  from . import models
28
  from .routers import analysis, downloads, pages, projects
 
29
 
30
  # 환경 변수 로드
31
  load_dotenv()
@@ -126,6 +127,19 @@ async def startup_event():
126
  print("✅ Database tables initialized")
127
  except Exception as e:
128
  print(f"⚠️ Table initialization warning: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  print("=" * 60)
131
  print("✅ SmartEyeSsen Backend Ready!")
 
26
  from .database import engine, get_db, init_db, test_connection
27
  from . import models
28
  from .routers import analysis, downloads, pages, projects
29
+ from .services.model_registry import model_registry
30
 
31
  # 환경 변수 로드
32
  load_dotenv()
 
127
  print("✅ Database tables initialized")
128
  except Exception as e:
129
  print(f"⚠️ Table initialization warning: {e}")
130
+
131
+ preload_env = os.getenv("MODEL_PRELOAD", "SmartEyeSsen")
132
+ preload_targets = [
133
+ name.strip()
134
+ for name in preload_env.split(",")
135
+ if name.strip()
136
+ ]
137
+ if preload_targets:
138
+ try:
139
+ model_registry.preload(preload_targets)
140
+ print(f"🧠 Preloaded models: {', '.join(preload_targets)}")
141
+ except Exception as e:
142
+ print(f"⚠️ Model preload failed: {e}")
143
 
144
  print("=" * 60)
145
  print("✅ SmartEyeSsen Backend Ready!")
app/routers/analysis.py CHANGED
@@ -15,6 +15,8 @@ from ..services.batch_analysis import (
15
  analyze_project_batch_async_parallel,
16
  _get_analysis_service,
17
  _process_single_page_async,
 
 
18
  )
19
  from ..services.formatter import TextFormatter
20
 
@@ -37,12 +39,14 @@ class ProjectAnalysisRequest(BaseModel):
37
  api_key: Optional[str] = None
38
  use_parallel: bool = True # False → True (병렬 처리 기본값)
39
  max_concurrent_pages: int = 8 # 4 → 8 (성능 최적화)
 
40
 
41
 
42
  class PageAnalysisRequest(BaseModel):
43
  """단일 페이지 비동기 분석 요청"""
44
  use_ai_descriptions: bool = True
45
  api_key: Optional[str] = None
 
46
 
47
 
48
  @router.post(
@@ -71,9 +75,14 @@ async def analyze_project(
71
  - 모델: 싱글톤 패턴으로 메모리 효율적 (중복 로드 방지)
72
  - 권장: 모든 환경 (CPU 4코어 이상, RAM 4GB+)
73
  """
74
- project_exists = db.query(Project.project_id).filter(Project.project_id == project_id).scalar()
75
- if not project_exists:
76
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="프로젝트를 찾을 수 없습니다.")
 
 
 
 
 
77
 
78
  if payload.use_parallel:
79
  logger.info(f"병렬 분석 시작: project_id={project_id}, max_concurrent={payload.max_concurrent_pages}")
@@ -83,6 +92,7 @@ async def analyze_project(
83
  use_ai_descriptions=payload.use_ai_descriptions,
84
  api_key=payload.api_key,
85
  max_concurrent_pages=payload.max_concurrent_pages,
 
86
  )
87
  else:
88
  logger.info(f"순차 분석 시작: project_id={project_id}")
@@ -91,6 +101,7 @@ async def analyze_project(
91
  project_id=project_id,
92
  use_ai_descriptions=payload.use_ai_descriptions,
93
  api_key=payload.api_key,
 
94
  )
95
 
96
  return analysis_result
@@ -133,6 +144,11 @@ def analyze_page_async(
133
  status_code=status.HTTP_404_NOT_FOUND,
134
  detail=f"페이지 ID {page_id}를 찾을 수 없습니다."
135
  )
 
 
 
 
 
136
 
137
  # 작업 ID 생성
138
  job_id = str(uuid.uuid4())
@@ -142,6 +158,7 @@ def analyze_page_async(
142
  "page_id": page_id,
143
  "page_number": page.page_number,
144
  "project_id": page.project_id,
 
145
  "result": None,
146
  "error": None,
147
  "progress": "작업 대기 중...",
@@ -156,6 +173,7 @@ def analyze_page_async(
156
  page_id=page_id,
157
  use_ai_descriptions=payload.use_ai_descriptions,
158
  api_key=payload.api_key,
 
159
  )
160
 
161
  return {
@@ -196,6 +214,7 @@ async def _run_async_page_analysis(
196
  page_id: int,
197
  use_ai_descriptions: bool,
198
  api_key: Optional[str],
 
199
  ) -> None:
200
  """
201
  백그라운드에서 실행되는 단일 페이지 비동기 분석 작업
@@ -227,7 +246,8 @@ async def _run_async_page_analysis(
227
  raise ValueError(f"프로젝트 ID {page.project_id}를 찾을 수 없습니다.")
228
 
229
  # AnalysisService 및 TextFormatter 초기화
230
- analysis_service = _get_analysis_service()
 
231
  formatter = TextFormatter(
232
  doc_type_id=project.doc_type_id,
233
  db=db,
 
15
  analyze_project_batch_async_parallel,
16
  _get_analysis_service,
17
  _process_single_page_async,
18
+ is_supported_model,
19
+ resolve_model_choice,
20
  )
21
  from ..services.formatter import TextFormatter
22
 
 
39
  api_key: Optional[str] = None
40
  use_parallel: bool = True # False → True (병렬 처리 기본값)
41
  max_concurrent_pages: int = 8 # 4 → 8 (성능 최적화)
42
+ analysis_model: Optional[str] = None
43
 
44
 
45
  class PageAnalysisRequest(BaseModel):
46
  """단일 페이지 비동기 분석 요청"""
47
  use_ai_descriptions: bool = True
48
  api_key: Optional[str] = None
49
+ analysis_model: Optional[str] = None
50
 
51
 
52
  @router.post(
 
75
  - 모델: 싱글톤 패턴으로 메모리 효율적 (중복 로드 방지)
76
  - 권장: 모든 환경 (CPU 4코어 이상, RAM 4GB+)
77
  """
78
+ project = db.query(Project).filter(Project.project_id == project_id).first()
79
+ if not project:
80
  raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="프로젝트를 찾을 수 없습니다.")
81
+ if payload.analysis_model and not is_supported_model(payload.analysis_model):
82
+ raise HTTPException(
83
+ status_code=status.HTTP_400_BAD_REQUEST,
84
+ detail=f"지원하지 않는 모델입니다: {payload.analysis_model}",
85
+ )
86
 
87
  if payload.use_parallel:
88
  logger.info(f"병렬 분석 시작: project_id={project_id}, max_concurrent={payload.max_concurrent_pages}")
 
92
  use_ai_descriptions=payload.use_ai_descriptions,
93
  api_key=payload.api_key,
94
  max_concurrent_pages=payload.max_concurrent_pages,
95
+ analysis_model=payload.analysis_model or None,
96
  )
97
  else:
98
  logger.info(f"순차 분석 시작: project_id={project_id}")
 
101
  project_id=project_id,
102
  use_ai_descriptions=payload.use_ai_descriptions,
103
  api_key=payload.api_key,
104
+ analysis_model=payload.analysis_model or None,
105
  )
106
 
107
  return analysis_result
 
144
  status_code=status.HTTP_404_NOT_FOUND,
145
  detail=f"페이지 ID {page_id}를 찾을 수 없습니다."
146
  )
147
+ if payload.analysis_model and not is_supported_model(payload.analysis_model):
148
+ raise HTTPException(
149
+ status_code=status.HTTP_400_BAD_REQUEST,
150
+ detail=f"지원하지 않는 모델입니다: {payload.analysis_model}",
151
+ )
152
 
153
  # 작업 ID 생성
154
  job_id = str(uuid.uuid4())
 
158
  "page_id": page_id,
159
  "page_number": page.page_number,
160
  "project_id": page.project_id,
161
+ "analysis_model": payload.analysis_model,
162
  "result": None,
163
  "error": None,
164
  "progress": "작업 대기 중...",
 
173
  page_id=page_id,
174
  use_ai_descriptions=payload.use_ai_descriptions,
175
  api_key=payload.api_key,
176
+ analysis_model=payload.analysis_model,
177
  )
178
 
179
  return {
 
214
  page_id: int,
215
  use_ai_descriptions: bool,
216
  api_key: Optional[str],
217
+ analysis_model: Optional[str],
218
  ) -> None:
219
  """
220
  백그라운드에서 실행되는 단일 페이지 비동기 분석 작업
 
246
  raise ValueError(f"프로젝트 ID {page.project_id}를 찾을 수 없습니다.")
247
 
248
  # AnalysisService 및 TextFormatter 초기화
249
+ model_choice = resolve_model_choice(project.doc_type_id, analysis_model)
250
+ analysis_service = _get_analysis_service(model_choice)
251
  formatter = TextFormatter(
252
  doc_type_id=project.doc_type_id,
253
  db=db,
app/services/analysis_service.py CHANGED
@@ -27,12 +27,12 @@ import openai
27
  import pytesseract
28
  import torch
29
  from PIL import Image
30
- from huggingface_hub import hf_hub_download
31
  from loguru import logger
32
  from openai import AsyncOpenAI
33
  from sqlalchemy.orm import Session
34
 
35
  from .. import models
 
36
 
37
  # --- 신규: 이미지 설명을 위한 프롬프트 템플릿 추가 ---
38
  figure_prompt = """
@@ -239,77 +239,34 @@ class AnalysisService:
239
  model_choice: 사용할 모델 선택 (기본값: "SmartEyeSsen")
240
  auto_load: True이면 초기화 시 자동으로 모델 로드 (기본값: False, 하위 호환성 유지)
241
  """
242
- self.model = None
243
  self.device = device
244
  self.model_choice = model_choice
 
 
245
  self._model_loaded = False
246
 
247
  # 자동 로드 옵션이 활성화된 경우 즉시 모델 로드
248
  if auto_load:
249
  self._ensure_model_loaded()
250
 
251
- def download_model(self, model_choice="SmartEyeSsen"):
252
- """모델 다운로드 (기존과 동일)"""
253
- models = {
254
- "doclaynet_docsynth": {
255
- "repo_id": "juliozhao/DocLayout-YOLO-DocLayNet-Docsynth300K_pretrained",
256
- "filename": "doclayout_yolo_doclaynet_imgsz1120_docsynth_pretrain.pt",
257
- },
258
- "docstructbench": {
259
- "repo_id": "juliozhao/DocLayout-YOLO-DocStructBench",
260
- "filename": "doclayout_yolo_docstructbench_imgsz1024.pt",
261
- },
262
- "docsynth300k": {
263
- "repo_id": "juliozhao/DocLayout-YOLO-DocSynth300K-pretrain",
264
- "filename": "doclayout_yolo_docsynth300k_imgsz1600.pt",
265
- },
266
- "SmartEyeSsen": {"repo_id": "AkJeond/SmartEye", "filename": "best.pt"},
267
- }
268
- selected_model = models.get(model_choice, models["SmartEyeSsen"])
269
- try:
270
- logger.info(f"모델 다운로드 중: {selected_model['repo_id']}")
271
- filepath = hf_hub_download(
272
- repo_id=selected_model["repo_id"], filename=selected_model["filename"]
273
- )
274
- logger.info(f"모델 다운로드 완료: {filepath}")
275
- return filepath
276
- except Exception as e:
277
- logger.error(f"모델 다운로드 실패: {e}")
278
- raise
279
-
280
- def load_model(self, model_path):
281
- """모델 로드 (기존과 동일)"""
282
- try:
283
- try:
284
- from doclayout_yolo import YOLOv10
285
- except ImportError:
286
- logger.error("DocLayout-YOLO가 설치되지 않았습니다.")
287
- return False
288
- logger.info("모델 로드 중...")
289
- self.model = YOLOv10(model_path, task="predict")
290
- self.model.to(self.device)
291
- if hasattr(self.model, "training"):
292
- self.model.training = False
293
- logger.info("모델 로드 완료!")
294
- return True
295
- except Exception as e:
296
- logger.error(f"모델 로드 실패: {e}")
297
- return False
298
-
299
- def _ensure_model_loaded(self):
300
  """
301
  Lazy Loading: 모델이 로드되지 않았으면 자동으로 로드
302
  (다중 페이지 처리 시 모델을 한 번만 로드하도록 최적화)
303
  """
304
- if self._model_loaded and self.model is not None:
305
- return # 이미 로드됨
306
-
307
- logger.info(f"모델 자동 로드 시작 (선택: {self.model_choice})...")
308
- model_path = self.download_model(self.model_choice)
309
- if not self.load_model(model_path):
310
- raise RuntimeError(f"모델 로드 실패: {self.model_choice}")
 
 
 
 
311
  self._model_loaded = True
312
- logger.info("모델 자동 로드 완료!")
313
 
314
  def analyze_layout(
315
  self,
@@ -341,27 +298,24 @@ class AnalysisService:
341
  self._model_loaded = False
342
 
343
  # Lazy Loading: 모델이 없으면 자동 로드
344
- self._ensure_model_loaded()
 
 
345
 
346
  logger.info("레이아웃 분석 시작...")
347
  temp_path = "temp_image.jpg"
348
  cv2.imwrite(temp_path, image)
349
 
350
- if active_model == "SmartEyeSsen":
351
- imgsz, conf = 1024, 0.25
352
- elif active_model == "docsynth300k":
353
- imgsz, conf = 1600, 0.15
354
- else:
355
- imgsz, conf = 1024, 0.25
356
 
357
- results = self.model.predict(
358
  temp_path, imgsz=imgsz, conf=conf, iou=0.45, device=self.device
359
  )
360
 
361
  boxes = results[0].boxes.xyxy.cpu().numpy() # [x1, y1, x2, y2]
362
  classes = results[0].boxes.cls.cpu().numpy()
363
  confs = results[0].boxes.conf.cpu().numpy()
364
- class_names = self.model.names # 클래스 ID → 이름
365
 
366
  detection_records: List[Dict[str, float]] = []
367
 
 
27
  import pytesseract
28
  import torch
29
  from PIL import Image
 
30
  from loguru import logger
31
  from openai import AsyncOpenAI
32
  from sqlalchemy.orm import Session
33
 
34
  from .. import models
35
+ from .model_registry import model_registry
36
 
37
  # --- 신규: 이미지 설명을 위한 프롬프트 템플릿 추가 ---
38
  figure_prompt = """
 
239
  model_choice: 사용할 모델 선택 (기본값: "SmartEyeSsen")
240
  auto_load: True이면 초기화 시 자동으로 모델 로드 (기본값: False, 하위 호환성 유지)
241
  """
 
242
  self.device = device
243
  self.model_choice = model_choice
244
+ self.model_registry = model_registry
245
+ self._model_handle = None
246
  self._model_loaded = False
247
 
248
  # 자동 로드 옵션이 활성화된 경우 즉시 모델 로드
249
  if auto_load:
250
  self._ensure_model_loaded()
251
 
252
+ def _ensure_model_loaded(self, model_choice: Optional[str] = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  """
254
  Lazy Loading: 모델이 로드되지 않았으면 자동으로 로드
255
  (다중 페이지 처리 시 모델을 한 번만 로드하도록 최적화)
256
  """
257
+ target_model = model_choice or self.model_choice
258
+ if (
259
+ self._model_loaded
260
+ and self._model_handle is not None
261
+ and self._model_handle.name == target_model
262
+ ):
263
+ return self._model_handle
264
+
265
+ handle = self.model_registry.get_model(target_model, device=self.device)
266
+ self._model_handle = handle
267
+ self.model_choice = target_model
268
  self._model_loaded = True
269
+ return handle
270
 
271
  def analyze_layout(
272
  self,
 
298
  self._model_loaded = False
299
 
300
  # Lazy Loading: 모델이 없으면 자동 로드
301
+ handle = self._ensure_model_loaded(active_model)
302
+ model = handle.model
303
+ model_spec = handle.spec
304
 
305
  logger.info("레이아웃 분석 시작...")
306
  temp_path = "temp_image.jpg"
307
  cv2.imwrite(temp_path, image)
308
 
309
+ imgsz, conf = model_spec.imgsz, model_spec.conf
 
 
 
 
 
310
 
311
+ results = model.predict(
312
  temp_path, imgsz=imgsz, conf=conf, iou=0.45, device=self.device
313
  )
314
 
315
  boxes = results[0].boxes.xyxy.cpu().numpy() # [x1, y1, x2, y2]
316
  classes = results[0].boxes.cls.cpu().numpy()
317
  confs = results[0].boxes.conf.cpu().numpy()
318
+ class_names = model.names # 클래스 ID → 이름
319
 
320
  detection_records: List[Dict[str, float]] = []
321
 
app/services/batch_analysis.py CHANGED
@@ -38,7 +38,7 @@ import time
38
  from contextlib import asynccontextmanager
39
  from datetime import datetime
40
  from pathlib import Path
41
- from typing import Any, Dict, List, Optional
42
 
43
  import aiofiles
44
  import cv2
@@ -49,6 +49,7 @@ from sqlalchemy.orm import Session, selectinload
49
 
50
  from ..models import LayoutElement, Page, Project
51
  from .analysis_service import AnalysisService
 
52
  from .formatter import TextFormatter
53
  from .mock_models import MockElement
54
  from .sorter import save_sorting_results_to_db, sort_layout_elements
@@ -67,6 +68,51 @@ DEFAULT_MAX_CONCURRENT_PAGES = int(os.getenv("MAX_CONCURRENT_PAGES", "8")) # CP
67
  _model_instances: Dict[str, AnalysisService] = {}
68
  _model_lock = threading.Lock()
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def _get_analysis_service(model_choice: str = "SmartEyeSsen") -> AnalysisService:
72
  """
@@ -515,6 +561,7 @@ async def analyze_project_batch_async(
515
  use_ai_descriptions: bool = True,
516
  api_key: Optional[str] = None,
517
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
 
518
  ) -> Dict[str, Any]:
519
  """
520
  프로젝트 내 'pending' 상태 페이지를 순차적으로 분석하고 결과 요약을 반환합니다.
@@ -555,7 +602,14 @@ async def analyze_project_batch_async(
555
  _update_project_status(project, "in_progress")
556
  db.commit()
557
 
558
- analysis_service = _get_analysis_service()
 
 
 
 
 
 
 
559
  formatter = TextFormatter(
560
  doc_type_id=project.doc_type_id,
561
  db=db,
@@ -612,6 +666,7 @@ def analyze_project_batch(
612
  use_ai_descriptions: bool = True,
613
  api_key: Optional[str] = None,
614
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
 
615
  ) -> Dict[str, Any]:
616
  """
617
  동기 컨텍스트 호환용 래퍼.
@@ -623,6 +678,7 @@ def analyze_project_batch(
623
  use_ai_descriptions=use_ai_descriptions,
624
  api_key=api_key,
625
  ai_max_concurrency=ai_max_concurrency,
 
626
  )
627
  )
628
 
@@ -635,6 +691,7 @@ async def analyze_project_batch_async_parallel(
635
  api_key: Optional[str] = None,
636
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
637
  max_concurrent_pages: int = 8,
 
638
  ) -> Dict[str, Any]:
639
  """
640
  프로젝트 내 'pending' 상태 페이지를 병렬로 분석하고 결과 요약을 반환합니다.
@@ -696,7 +753,14 @@ async def analyze_project_batch_async_parallel(
696
  _update_project_status(project, "in_progress")
697
  db.commit()
698
 
699
- analysis_service = _get_analysis_service()
 
 
 
 
 
 
 
700
  formatter = TextFormatter(
701
  doc_type_id=project.doc_type_id,
702
  db=db,
@@ -798,6 +862,7 @@ def analyze_project_batch_parallel(
798
  api_key: Optional[str] = None,
799
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
800
  max_concurrent_pages: int = DEFAULT_MAX_CONCURRENT_PAGES,
 
801
  ) -> Dict[str, Any]:
802
  """
803
  동기 컨텍스트 호환용 래퍼 (병렬 처리 버전).
@@ -810,6 +875,7 @@ def analyze_project_batch_parallel(
810
  api_key=api_key,
811
  ai_max_concurrency=ai_max_concurrency,
812
  max_concurrent_pages=max_concurrent_pages,
 
813
  )
814
  )
815
 
@@ -823,4 +889,6 @@ __all__ = [
823
  "_process_single_page",
824
  "_process_single_page_async",
825
  "DEFAULT_AI_CONCURRENCY",
 
 
826
  ]
 
38
  from contextlib import asynccontextmanager
39
  from datetime import datetime
40
  from pathlib import Path
41
+ from typing import Any, Dict, List, Optional, Set
42
 
43
  import aiofiles
44
  import cv2
 
49
 
50
  from ..models import LayoutElement, Page, Project
51
  from .analysis_service import AnalysisService
52
+ from .model_registry import model_registry
53
  from .formatter import TextFormatter
54
  from .mock_models import MockElement
55
  from .sorter import save_sorting_results_to_db, sort_layout_elements
 
68
  _model_instances: Dict[str, AnalysisService] = {}
69
  _model_lock = threading.Lock()
70
 
71
+ # 문서 타입별 기본 모델 매핑
72
+ DOC_TYPE_MODEL_MAP = {
73
+ 1: "SmartEyeSsen",
74
+ 2: "docstructbench",
75
+ }
76
+ DEFAULT_MODEL_CHOICE = "SmartEyeSsen"
77
+
78
+
79
+ def _available_model_names() -> Set[str]:
80
+ return set(model_registry.list_registered().keys())
81
+
82
+
83
+ def is_supported_model(model_name: str) -> bool:
84
+ return model_name in _available_model_names()
85
+
86
+
87
+ def resolve_model_choice(
88
+ doc_type_id: Optional[int],
89
+ requested_model: Optional[str] = None,
90
+ ) -> str:
91
+ """
92
+ doc_type 또는 사용자 요청에 맞는 모델명을 반환합니다.
93
+
94
+ Args:
95
+ doc_type_id: document_types.doc_type_id
96
+ requested_model: 사용자가 명시적으로 지정한 모델 이름
97
+
98
+ Raises:
99
+ ValueError: 지원되지 않는 모델명이 요청된 경우
100
+ """
101
+ if requested_model:
102
+ if not is_supported_model(requested_model):
103
+ raise ValueError(f"지원하지 않는 AI 모델입니다: {requested_model}")
104
+ return requested_model
105
+
106
+ if doc_type_id in DOC_TYPE_MODEL_MAP:
107
+ return DOC_TYPE_MODEL_MAP[doc_type_id]
108
+
109
+ logger.warning(
110
+ "알 수 없는 doc_type_id ({})에 대해 기본 모델({})을 사용합니다.",
111
+ doc_type_id,
112
+ DEFAULT_MODEL_CHOICE,
113
+ )
114
+ return DEFAULT_MODEL_CHOICE
115
+
116
 
117
  def _get_analysis_service(model_choice: str = "SmartEyeSsen") -> AnalysisService:
118
  """
 
561
  use_ai_descriptions: bool = True,
562
  api_key: Optional[str] = None,
563
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
564
+ analysis_model: Optional[str] = None,
565
  ) -> Dict[str, Any]:
566
  """
567
  프로젝트 내 'pending' 상태 페이지를 순차적으로 분석하고 결과 요약을 반환합니다.
 
602
  _update_project_status(project, "in_progress")
603
  db.commit()
604
 
605
+ model_choice = resolve_model_choice(project.doc_type_id, analysis_model)
606
+ logger.info(
607
+ "프로젝트 분석 모델 선택: project_id={}, doc_type_id={}, model={}",
608
+ project.project_id,
609
+ project.doc_type_id,
610
+ model_choice,
611
+ )
612
+ analysis_service = _get_analysis_service(model_choice)
613
  formatter = TextFormatter(
614
  doc_type_id=project.doc_type_id,
615
  db=db,
 
666
  use_ai_descriptions: bool = True,
667
  api_key: Optional[str] = None,
668
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
669
+ analysis_model: Optional[str] = None,
670
  ) -> Dict[str, Any]:
671
  """
672
  동기 컨텍스트 호환용 래퍼.
 
678
  use_ai_descriptions=use_ai_descriptions,
679
  api_key=api_key,
680
  ai_max_concurrency=ai_max_concurrency,
681
+ analysis_model=analysis_model,
682
  )
683
  )
684
 
 
691
  api_key: Optional[str] = None,
692
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
693
  max_concurrent_pages: int = 8,
694
+ analysis_model: Optional[str] = None,
695
  ) -> Dict[str, Any]:
696
  """
697
  프로젝트 내 'pending' 상태 페이지를 병렬로 분석하고 결과 요약을 반환합니다.
 
753
  _update_project_status(project, "in_progress")
754
  db.commit()
755
 
756
+ model_choice = resolve_model_choice(project.doc_type_id, analysis_model)
757
+ logger.info(
758
+ "병렬 프로젝트 분석 모델 선택: project_id={}, doc_type_id={}, model={}",
759
+ project.project_id,
760
+ project.doc_type_id,
761
+ model_choice,
762
+ )
763
+ analysis_service = _get_analysis_service(model_choice)
764
  formatter = TextFormatter(
765
  doc_type_id=project.doc_type_id,
766
  db=db,
 
862
  api_key: Optional[str] = None,
863
  ai_max_concurrency: int = DEFAULT_AI_CONCURRENCY,
864
  max_concurrent_pages: int = DEFAULT_MAX_CONCURRENT_PAGES,
865
+ analysis_model: Optional[str] = None,
866
  ) -> Dict[str, Any]:
867
  """
868
  동기 컨텍스트 호환용 래퍼 (병렬 처리 버전).
 
875
  api_key=api_key,
876
  ai_max_concurrency=ai_max_concurrency,
877
  max_concurrent_pages=max_concurrent_pages,
878
+ analysis_model=analysis_model,
879
  )
880
  )
881
 
 
889
  "_process_single_page",
890
  "_process_single_page_async",
891
  "DEFAULT_AI_CONCURRENCY",
892
+ "is_supported_model",
893
+ "resolve_model_choice",
894
  ]
app/services/model_registry.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from shutil import copy2
7
+ from threading import Lock
8
+ from typing import Dict, Iterable, Optional
9
+
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ from loguru import logger
13
+
14
+
15
+ try:
16
+ from doclayout_yolo import YOLOv10
17
+ except ImportError as exc: # pragma: no cover - 환경 의존
18
+ YOLOv10 = None # type: ignore[assignment]
19
+ _IMPORT_ERROR = exc
20
+ else:
21
+ _IMPORT_ERROR = None
22
+
23
+
24
+ @dataclass(frozen=True)
25
+ class ModelSpec:
26
+ name: str
27
+ repo_id: str
28
+ filename: str
29
+ imgsz: int = 1024
30
+ conf: float = 0.25
31
+
32
+
33
+ @dataclass
34
+ class ModelHandle:
35
+ name: str
36
+ spec: ModelSpec
37
+ model: "YOLOv10"
38
+ device: str
39
+ weight_path: Path
40
+
41
+
42
+ class ModelRegistry:
43
+ """
44
+ DocLayout-YOLO 계열 모델을 전역으로 캐싱/재사용하기 위한 레지스트리.
45
+ - 모델별 가중치 다운로드는 한 번만 수행
46
+ - 디바이스(CPU/GPU)별 인스턴스를 필요 시 별도로 유지
47
+ """
48
+
49
+ def __init__(self) -> None:
50
+ self._specs: Dict[str, ModelSpec] = {}
51
+ self._models: Dict[str, ModelHandle] = {}
52
+ self._locks: Dict[str, Lock] = {}
53
+ self._default_device = "cuda" if torch.cuda.is_available() else "cpu"
54
+
55
+ @staticmethod
56
+ def _make_key(name: str, device: str) -> str:
57
+ return f"{name}:{device}"
58
+
59
+ def register(self, spec: ModelSpec) -> None:
60
+ self._specs[spec.name] = spec
61
+ self._locks.setdefault(spec.name, Lock())
62
+ logger.debug(f"📘 모델 스펙 등록: {spec.name} (imgsz={spec.imgsz}, conf={spec.conf})")
63
+
64
+ def list_registered(self) -> Dict[str, ModelSpec]:
65
+ return dict(self._specs)
66
+
67
+ def preload(self, targets: Optional[Iterable[str]] = None, *, device: Optional[str] = None) -> None:
68
+ names = list(targets) if targets else list(self._specs.keys())
69
+ for name in names:
70
+ try:
71
+ self.get_model(name, device=device)
72
+ except Exception as exc: # pragma: no cover - 초기화 단계
73
+ logger.error(f"❌ 모델 프리로드 실패 ({name}): {exc}")
74
+ raise
75
+
76
+ def get_model(self, name: str, *, device: Optional[str] = None) -> ModelHandle:
77
+ if name not in self._specs:
78
+ raise KeyError(f"등록되지 않은 모델입니다: {name}")
79
+
80
+ if _IMPORT_ERROR is not None:
81
+ raise RuntimeError(
82
+ "doclayout_yolo 패키지가 설치되지 않아 모델을 로드할 수 없습니다."
83
+ ) from _IMPORT_ERROR
84
+
85
+ resolved_device = device or self._default_device
86
+ key = self._make_key(name, resolved_device)
87
+
88
+ if key in self._models:
89
+ return self._models[key]
90
+
91
+ lock = self._locks.setdefault(name, Lock())
92
+ with lock:
93
+ if key in self._models:
94
+ return self._models[key]
95
+
96
+ spec = self._specs[name]
97
+ weight_path = self._download_weights(name, spec)
98
+ model = self._load_model(weight_path, resolved_device)
99
+
100
+ handle = ModelHandle(
101
+ name=name,
102
+ spec=spec,
103
+ model=model,
104
+ device=resolved_device,
105
+ weight_path=weight_path,
106
+ )
107
+ self._models[key] = handle
108
+ logger.info(f"✅ 모델 로드 완료: {name} (device={resolved_device})")
109
+ return handle
110
+
111
+ @staticmethod
112
+ def _download_weights(name: str, spec: ModelSpec) -> Path:
113
+ override_env = os.getenv(f"{name.upper()}_MODEL_PATH")
114
+ if override_env:
115
+ override_path = Path(override_env)
116
+ if override_path.exists():
117
+ logger.info(f"📂 {name} 가중치 경로 override 사용: {override_path}")
118
+ return override_path.resolve()
119
+ logger.warning(
120
+ f"⚠️ {name.upper()}_MODEL_PATH 가 지정되었지만 파일을 찾을 수 없습니다: {override_path}"
121
+ )
122
+
123
+ cache_root = Path(
124
+ os.getenv("MODEL_CACHE_DIR", Path.home() / ".cache" / "smarteye_models")
125
+ ).resolve()
126
+ target_dir = (cache_root / name).resolve()
127
+ target_dir.mkdir(parents=True, exist_ok=True)
128
+ target_path = target_dir / spec.filename
129
+
130
+ if target_path.exists():
131
+ logger.debug(f"📦 캐시된 가중치 사용: {target_path}")
132
+ return target_path
133
+
134
+ logger.info(f"⬇️ {name} 가중치 다운로드 중 ({spec.repo_id}/{spec.filename})")
135
+ downloaded_path = hf_hub_download(
136
+ repo_id=spec.repo_id,
137
+ filename=spec.filename,
138
+ local_dir=str(target_dir),
139
+ local_dir_use_symlinks=False,
140
+ )
141
+
142
+ downloaded_path = Path(downloaded_path).resolve()
143
+ if downloaded_path != target_path:
144
+ copy2(downloaded_path, target_path)
145
+ logger.debug(f"📁 가중치 복사: {downloaded_path.name} -> {target_path}")
146
+
147
+ return target_path
148
+
149
+ @staticmethod
150
+ def _load_model(weight_path: Path, device: str) -> "YOLOv10":
151
+ if YOLOv10 is None: # pragma: no cover
152
+ raise RuntimeError("doclayout_yolo 패키지가 없습니다.")
153
+
154
+ logger.info(f"🧠 모델 로딩: {weight_path.name} (device={device})")
155
+ model = YOLOv10(str(weight_path), task="predict")
156
+ model.to(device)
157
+ if hasattr(model, "training"):
158
+ model.training = False
159
+ return model
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # 전역 레지스트리 인스턴스 및 기본 모델 스펙 등록
164
+ # ---------------------------------------------------------------------------
165
+ DEFAULT_MODEL_SPECS = [
166
+ ModelSpec(
167
+ name="SmartEyeSsen",
168
+ repo_id="AkJeond/SmartEye",
169
+ filename="best.pt",
170
+ imgsz=1024,
171
+ conf=0.25,
172
+ ),
173
+ ModelSpec(
174
+ name="docstructbench",
175
+ repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
176
+ filename="doclayout_yolo_docstructbench_imgsz1024.pt",
177
+ imgsz=1024,
178
+ conf=0.25,
179
+ )
180
+ ]
181
+
182
+ model_registry = ModelRegistry()
183
+ for spec in DEFAULT_MODEL_SPECS:
184
+ model_registry.register(spec)