moazx commited on
Commit
b857db3
·
1 Parent(s): 6a3aab8

Reimplement

Browse files
Files changed (1) hide show
  1. hf_app.py +356 -4
hf_app.py CHANGED
@@ -1,7 +1,359 @@
1
- from fastapi import FastAPI
2
- from fastapi.middleware.wsgi import WSGIMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- from app import app as flask_app
5
 
6
  app = FastAPI()
7
- app.mount("/", WSGIMiddleware(flask_app))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Any, Dict, List
3
+
4
+ import json
5
+ import os
6
+ import shutil
7
+
8
+ import torch
9
+ from fastapi import FastAPI, File, Form, Request, UploadFile
10
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
11
+ from fastapi.staticfiles import StaticFiles
12
+ from fastapi.templating import Jinja2Templates
13
+ from loguru import logger
14
+ from werkzeug.utils import secure_filename
15
+
16
+ import main as extractor
17
 
 
18
 
19
  app = FastAPI()
20
+
21
+
22
+ # Static files and templates -------------------------------------------------
23
+ app.mount("/static", StaticFiles(directory="static"), name="static")
24
+ templates = Jinja2Templates(directory="templates")
25
+
26
+
27
+ def flask_like_url_for(endpoint: str, **kwargs: Any) -> str:
28
+ """Minimal Flask-like url_for for templates using filename= for static.
29
+
30
+ The Jinja template calls url_for('static', filename='css/styles.css'),
31
+ which is Flask style. We emulate that here so templates work unchanged.
32
+ """
33
+
34
+ if endpoint == "static":
35
+ filename = str(kwargs.get("filename", ""))
36
+ return "/static/" + filename.lstrip("/")
37
+
38
+ # Fallback: just return "/<endpoint>"; templates only use static.
39
+ return "/" + endpoint.lstrip("/")
40
+
41
+
42
+ templates.env.globals["url_for"] = flask_like_url_for
43
+
44
+
45
+ # Configuration -------------------------------------------------------------
46
+ UPLOAD_FOLDER = Path("./uploads")
47
+ OUTPUT_FOLDER = Path("./output")
48
+ MAX_CONTENT_LENGTH = 500 * 1024 * 1024 # 500MB
49
+
50
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
51
+ os.makedirs(OUTPUT_FOLDER, exist_ok=True)
52
+
53
+
54
+ # Global model cache --------------------------------------------------------
55
+ _model: Any = None
56
+
57
+
58
+ def get_device_info() -> Dict[str, Any]:
59
+ """Get information about GPU/CPU availability."""
60
+
61
+ cuda_available = torch.cuda.is_available()
62
+ device = "cuda" if cuda_available else "cpu"
63
+
64
+ info: Dict[str, Any] = {
65
+ "device": device,
66
+ "cuda_available": cuda_available,
67
+ "device_name": None,
68
+ "device_count": 0,
69
+ }
70
+
71
+ if cuda_available:
72
+ info["device_name"] = torch.cuda.get_device_name(0)
73
+ info["device_count"] = torch.cuda.device_count()
74
+
75
+ return info
76
+
77
+
78
+ def load_model_once() -> Any:
79
+ """Load the DocLayout-YOLO model once and cache it in this process."""
80
+
81
+ global _model
82
+ if _model is None:
83
+ logger.info("Loading DocLayout-YOLO model...")
84
+ _model = extractor.get_model()
85
+ logger.info("Model loaded successfully")
86
+ return _model
87
+
88
+
89
+ # Routes --------------------------------------------------------------------
90
+ @app.get("/", response_class=HTMLResponse)
91
+ async def index(request: Request) -> HTMLResponse:
92
+ """Main page, equivalent to the Flask index route."""
93
+
94
+ device_info = get_device_info()
95
+ return templates.TemplateResponse(
96
+ "index.html", {"request": request, "device_info": device_info}
97
+ )
98
+
99
+
100
+ @app.get("/api/device-info")
101
+ async def device_info() -> Dict[str, Any]:
102
+ """API endpoint to get device information."""
103
+
104
+ return get_device_info()
105
+
106
+
107
+ @app.post("/api/upload")
108
+ async def upload_files(
109
+ request: Request,
110
+ files: List[UploadFile] = File(..., alias="files[]"),
111
+ extraction_mode: str = Form("images"),
112
+ ) -> JSONResponse:
113
+ """Handle multiple PDF file uploads (FastAPI version of Flask route)."""
114
+
115
+ if not files or all((f.filename or "") == "" for f in files):
116
+ return JSONResponse({"error": "No files selected"}, status_code=400)
117
+
118
+ include_images = extraction_mode != "markdown"
119
+ include_markdown = extraction_mode != "images"
120
+
121
+ results: List[Dict[str, Any]] = []
122
+
123
+ for upload in files:
124
+ filename = upload.filename or ""
125
+ if not filename.endswith(".pdf"):
126
+ continue
127
+
128
+ try:
129
+ safe_name = secure_filename(filename)
130
+ stem = Path(safe_name).stem
131
+
132
+ upload_path = UPLOAD_FOLDER / safe_name
133
+ # Save uploaded file to disk
134
+ with upload_path.open("wb") as out_f:
135
+ while True:
136
+ chunk = await upload.read(1024 * 1024)
137
+ if not chunk:
138
+ break
139
+ out_f.write(chunk)
140
+
141
+ # Prepare output directory
142
+ output_dir = OUTPUT_FOLDER / stem
143
+ output_dir.mkdir(parents=True, exist_ok=True)
144
+
145
+ # Move PDF into output directory
146
+ pdf_path = output_dir / safe_name
147
+ upload_path.replace(pdf_path)
148
+
149
+ # Process PDF
150
+ extractor.USE_MULTIPROCESSING = False
151
+ logger.info(
152
+ f"Processing {safe_name} (images={include_images}, markdown={include_markdown})"
153
+ )
154
+
155
+ if include_images:
156
+ load_model_once()
157
+
158
+ extractor.process_pdf_with_pool(
159
+ pdf_path,
160
+ output_dir,
161
+ pool=None,
162
+ extract_images=include_images,
163
+ extract_markdown=include_markdown,
164
+ )
165
+
166
+ # Collect results
167
+ json_path = output_dir / f"{stem}_content_list.json"
168
+ elements: List[Dict[str, Any]] = []
169
+ if include_images and json_path.exists():
170
+ elements = json.loads(json_path.read_text(encoding="utf-8"))
171
+
172
+ annotated_pdf: str | None = None
173
+ if include_images:
174
+ candidate_pdf = output_dir / f"{stem}_layout.pdf"
175
+ if candidate_pdf.exists():
176
+ annotated_pdf = str(candidate_pdf.relative_to(OUTPUT_FOLDER))
177
+
178
+ markdown_path: str | None = None
179
+ if include_markdown:
180
+ candidate_md = output_dir / f"{stem}.md"
181
+ if candidate_md.exists():
182
+ markdown_path = str(candidate_md.relative_to(OUTPUT_FOLDER))
183
+
184
+ figures = [e for e in elements if e.get("type") == "figure"]
185
+ tables = [e for e in elements if e.get("type") == "table"]
186
+
187
+ results.append(
188
+ {
189
+ "filename": safe_name,
190
+ "stem": stem,
191
+ "output_dir": str(output_dir.relative_to(OUTPUT_FOLDER)),
192
+ "figures_count": len(figures),
193
+ "tables_count": len(tables),
194
+ "elements_count": len(elements),
195
+ "annotated_pdf": annotated_pdf,
196
+ "markdown_path": markdown_path,
197
+ "include_images": include_images,
198
+ "include_markdown": include_markdown,
199
+ }
200
+ )
201
+
202
+ except Exception as e: # pragma: no cover - runtime error path
203
+ logger.error(f"Error processing {filename}: {e}")
204
+ results.append({"filename": filename, "error": str(e)})
205
+
206
+ return JSONResponse({"results": results})
207
+
208
+
209
+ @app.get("/api/pdf-list")
210
+ async def pdf_list() -> Dict[str, Any]:
211
+ """Get list of processed PDFs."""
212
+
213
+ pdfs: List[Dict[str, Any]] = []
214
+ output_dir = OUTPUT_FOLDER
215
+
216
+ if not output_dir.exists():
217
+ return {"pdfs": pdfs}
218
+
219
+ for item in output_dir.iterdir():
220
+ if item.is_dir():
221
+ json_files = list(item.glob("*_content_list.json"))
222
+ md_files = list(item.glob("*.md"))
223
+ pdf_files = list(item.glob("*.pdf"))
224
+
225
+ if json_files or md_files or pdf_files:
226
+ stem = item.name
227
+ pdfs.append(
228
+ {
229
+ "stem": stem,
230
+ "output_dir": str(item.relative_to(output_dir)),
231
+ }
232
+ )
233
+
234
+ return {"pdfs": pdfs}
235
+
236
+
237
+ @app.get("/api/pdf-details/{pdf_stem:path}")
238
+ async def pdf_details(pdf_stem: str) -> JSONResponse:
239
+ """Get detailed information about a processed PDF."""
240
+
241
+ output_dir = OUTPUT_FOLDER / pdf_stem
242
+
243
+ if not output_dir.exists():
244
+ return JSONResponse({"error": "PDF not found"}, status_code=404)
245
+
246
+ json_files = list(output_dir.glob("*_content_list.json"))
247
+ elements: List[Dict[str, Any]] = []
248
+ if json_files:
249
+ elements = json.loads(json_files[0].read_text(encoding="utf-8"))
250
+
251
+ figures = [e for e in elements if e.get("type") == "figure"]
252
+ tables = [e for e in elements if e.get("type") == "table"]
253
+
254
+ annotated_pdf: str | None = None
255
+ pdf_files = list(output_dir.glob("*_layout.pdf"))
256
+ if pdf_files:
257
+ annotated_pdf = str(pdf_files[0].relative_to(OUTPUT_FOLDER))
258
+
259
+ markdown_path: str | None = None
260
+ md_files = list(output_dir.glob("*.md"))
261
+ if md_files:
262
+ markdown_path = str(md_files[0].relative_to(OUTPUT_FOLDER))
263
+
264
+ figure_dir = output_dir / "figures"
265
+ table_dir = output_dir / "tables"
266
+
267
+ figure_images: List[str] = []
268
+ if figure_dir.exists():
269
+ figure_images = [
270
+ str(f.relative_to(OUTPUT_FOLDER)) for f in sorted(figure_dir.glob("*.png"))
271
+ ]
272
+
273
+ table_images: List[str] = []
274
+ if table_dir.exists():
275
+ table_images = [
276
+ str(t.relative_to(OUTPUT_FOLDER)) for t in sorted(table_dir.glob("*.png"))
277
+ ]
278
+
279
+ return JSONResponse(
280
+ {
281
+ "stem": pdf_stem,
282
+ "figures": figures,
283
+ "tables": tables,
284
+ "figures_count": len(figures),
285
+ "tables_count": len(tables),
286
+ "elements_count": len(elements),
287
+ "annotated_pdf": annotated_pdf,
288
+ "markdown_path": markdown_path,
289
+ "figure_images": figure_images,
290
+ "table_images": table_images,
291
+ }
292
+ )
293
+
294
+
295
+ @app.get("/output/{filename:path}")
296
+ async def output_file(filename: str) -> FileResponse | JSONResponse:
297
+ """Serve output files (PDFs, images, markdown)."""
298
+
299
+ output_root = OUTPUT_FOLDER.resolve()
300
+ file_path = (output_root / filename).resolve()
301
+
302
+ if output_root not in file_path.parents and file_path != output_root:
303
+ return JSONResponse({"error": "Invalid path"}, status_code=400)
304
+
305
+ if not file_path.exists() or not file_path.is_file():
306
+ return JSONResponse({"error": "Not found"}, status_code=404)
307
+
308
+ return FileResponse(file_path)
309
+
310
+
311
+ def _delete_by_stem(stem_raw: str) -> JSONResponse:
312
+ stem = (stem_raw or "").strip()
313
+ if not stem:
314
+ return JSONResponse({"error": "Missing stem"}, status_code=400)
315
+
316
+ output_root = OUTPUT_FOLDER.resolve()
317
+ target_dir = (output_root / stem).resolve()
318
+
319
+ if output_root not in target_dir.parents and target_dir != output_root:
320
+ return JSONResponse({"error": "Invalid stem path"}, status_code=400)
321
+
322
+ if not target_dir.exists() or not target_dir.is_dir():
323
+ return JSONResponse({"error": "Not found"}, status_code=404)
324
+
325
+ shutil.rmtree(target_dir, ignore_errors=False)
326
+ logger.info(f"Deleted processed output: {target_dir}")
327
+
328
+ return JSONResponse({"ok": True, "deleted": stem})
329
+
330
+
331
+ @app.post("/api/delete")
332
+ async def delete_pdf(request: Request, stem_form: str | None = Form(default=None)) -> JSONResponse:
333
+ """Delete a processed PDF directory by stem (JSON or form body)."""
334
+
335
+ try:
336
+ stem = (stem_form or "").strip()
337
+ if not stem:
338
+ data: Dict[str, Any] = {}
339
+ try:
340
+ data = await request.json()
341
+ except Exception:
342
+ data = {}
343
+ stem = (str(data.get("stem") or "")).strip()
344
+ return _delete_by_stem(stem)
345
+ except Exception as e: # pragma: no cover - runtime error path
346
+ logger.error(f"Delete failed: {e}")
347
+ return JSONResponse({"error": str(e)}, status_code=500)
348
+
349
+
350
+ @app.api_route("/api/delete/{stem:path}", methods=["POST", "GET"])
351
+ async def delete_pdf_by_path(stem: str) -> JSONResponse:
352
+ """Alternate endpoint to delete using URL path, for clients avoiding bodies."""
353
+
354
+ try:
355
+ return _delete_by_stem(stem)
356
+ except Exception as e: # pragma: no cover - runtime error path
357
+ logger.error(f"Delete failed: {e}")
358
+ return JSONResponse({"error": str(e)}, status_code=500)
359
+