"""Document partitioning service.

Endpoints:
    POST /partition  -- split an uploaded PDF, DOCX or plain-text file into
                        text sections, titles, tables (as HTML) and images
                        (as base64 data URIs).
    POST /debug      -- dump the text-block and image positions PyMuPDF reports
                        for an uploaded PDF, to diagnose element ordering.
    GET  /health     -- liveness probe.
"""
import base64
import html
import itertools
import json
import logging
import os
import re
import tempfile
import traceback

import fitz
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Images whose top lies above this fraction of the page height, or whose bottom
# lies below the footer fraction, are treated as running header/footer art.
_HEADER_FRACTION = 0.08
_FOOTER_FRACTION = 0.92
# Minimum image side: in page points for the placed image, in pixels for the
# embedded bitmap. Smaller images (bullets, icons, rules) are skipped.
_MIN_IMAGE_SIDE = 50
# Body paragraphs longer than this are split into sentence-aligned chunks of
# roughly _CHUNK_TARGET characters.
_LONG_TEXT_THRESHOLD = 500
_CHUNK_TARGET = 400
# Break after sentence-ending punctuation followed by whitespace.
_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

async def _save_upload(file: UploadFile) -> tuple:
    """Persist an uploaded file to a named temporary file.

    Returns ``(suffix, path)``. *suffix* is the lower-cased extension of the
    client-supplied filename (``""`` when none was sent). The caller owns the
    temporary file and must delete it.
    """
    # UploadFile.filename may be None; os.path.splitext(None) raises TypeError.
    suffix = os.path.splitext(file.filename or "")[1].lower()
    # Read first so a failed read does not leave an orphaned temp file behind.
    content = await file.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(content)
    return suffix, tmp.name


def _rows_to_html(rows) -> str:
    """Render *rows* (an iterable of iterables of cell values) as an HTML table.

    ``None``/empty cells become empty ``<td>`` elements. Every value is
    HTML-escaped because it comes from an uploaded document and may itself
    contain markup.
    """
    parts = ["<table>"]
    for row in rows:
        parts.append("<tr>")
        parts.extend(f"<td>{html.escape(str(cell or ''))}</td>" for cell in row)
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)


def _build_response(elements: list, page_count=None) -> dict:
    """Group already-indexed *elements* by type into the API response shape.

    ``page_count`` is included only when given (PDF input).
    """
    text_sections = [e for e in elements if e["type"] == "text"]
    tables = [e for e in elements if e["type"] == "table"]
    images = [e for e in elements if e["type"] == "image"]
    titles = [e for e in elements if e["type"] == "title"]
    result = {
        "text_sections": text_sections,
        "tables": tables,
        "images": images,
        "titles": titles,
    }
    if page_count is not None:
        result["page_count"] = page_count
    result["counts"] = {
        "text_sections": len(text_sections),
        "tables": len(tables),
        "images": len(images),
        "titles": len(titles),
    }
    return result


def _block_text(block: dict) -> str:
    """Concatenate the text of every span in a PyMuPDF ``"dict"`` text block."""
    return "".join(
        span.get("text", "")
        for line in block.get("lines", [])
        for span in line.get("spans", [])
    )


def _image_info_bboxes(page) -> dict:
    """Map each image xref on *page* to the bbox of its first ``get_image_info`` entry."""
    bboxes = {}
    for info in page.get_image_info(xrefs=True):
        # Keep the first occurrence, matching a "first match, then break" search.
        bboxes.setdefault(info.get("xref"), info.get("bbox", (0, 0, 0, 0)))
    return bboxes


# ---------------------------------------------------------------------------
# /debug
# ---------------------------------------------------------------------------

def _debug_pdf(tmp_path: str) -> dict:
    """Collect per-page text-block and image diagnostics for ``/debug``."""
    with fitz.open(tmp_path) as doc:
        result = {"pymupdf_version": fitz.VersionBind, "pages": []}
        for page_num, page in enumerate(doc):
            pg = {"page": page_num + 1, "text_blocks": [], "images": []}

            # Text blocks: top y and the first 60 characters of each.
            for block in page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]:
                if block["type"] != 0:
                    continue
                text = _block_text(block).strip()
                if text:
                    pg["text_blocks"].append({
                        "y": round(block["bbox"][1], 1),
                        "text": text[:60],
                    })

            # get_image_info is per page, so query it once rather than per image.
            try:
                info_bboxes = _image_info_bboxes(page)
            except Exception:
                info_bboxes = {}

            # Images: placement per get_image_rects, plus get_image_info's bbox.
            for img in page.get_images(full=True):
                xref = img[0]
                info = {"xref": xref}
                try:
                    rects = page.get_image_rects(xref)
                    if rects:
                        r = rects[0]
                        info["y"] = round(r.y0, 1)
                        info["w"] = round(r.width, 1)
                        info["h"] = round(r.height, 1)
                    else:
                        info["y"] = "no_rects"
                except Exception as e:
                    info["y"] = f"error: {e}"
                if xref in info_bboxes:
                    info["info_bbox"] = [round(x, 1) for x in info_bboxes[xref]]
                pg["images"].append(info)

            result["pages"].append(pg)
    return result


@app.post("/debug")
async def debug_document(file: UploadFile = File(...)):
    """Show what PyMuPDF sees — helps diagnose image ordering.

    On failure returns ``{"error": <message>, "traceback": <formatted traceback>}``
    instead of raising.
    """
    _, tmp_path = await _save_upload(file)
    try:
        # PDF parsing is CPU-bound; keep it off the event loop.
        return await run_in_threadpool(_debug_pdf, tmp_path)
    except Exception as e:
        return {"error": str(e), "traceback": traceback.format_exc()}
    finally:
        os.unlink(tmp_path)


# ---------------------------------------------------------------------------
# PDF extraction
# ---------------------------------------------------------------------------

def _extract_tables(page, page_num: int, raw_elements: list) -> list:
    """Append each table detected on *page* to *raw_elements* as HTML.

    Returns the bounding rects of the detected tables so overlapping text blocks
    can be skipped. Detection is best-effort: a failure is logged and the
    tables gathered so far are kept.
    """
    table_rects = []
    try:
        for tab in page.find_tables():
            rect = fitz.Rect(tab.bbox)
            # Record the rect before rendering, so text under the table is
            # excluded even if extraction of its cells fails.
            table_rects.append(rect)
            raw_elements.append((page_num, rect.y0, {
                "content": _rows_to_html(tab.extract()),
                "type": "table",
                "page": page_num + 1,
            }))
    except Exception:
        logger.warning("Table detection failed on page %d", page_num + 1, exc_info=True)
    return table_rects


def _extract_images(doc, page, page_num: int, raw_elements: list) -> None:
    """Append the content images of *page* to *raw_elements* as base64 data URIs.

    Skips repeated xrefs, images in the header/footer bands, and images smaller
    than ``_MIN_IMAGE_SIDE`` either as placed on the page or as embedded pixels.
    """
    page_height = page.rect.height
    header_zone = page_height * _HEADER_FRACTION
    footer_zone = page_height * _FOOTER_FRACTION
    info_bboxes = None  # get_image_info result, computed on first fallback only
    seen_xrefs = set()

    for img in page.get_images(full=True):
        xref = img[0]
        if xref in seen_xrefs:
            continue
        seen_xrefs.add(xref)

        img_y = None
        img_w = 0
        img_h = 0

        # Method 1: get_image_rects (most reliable).
        try:
            rects = page.get_image_rects(xref)
            if rects:
                r = rects[0]
                img_y, img_w, img_h = r.y0, r.width, r.height
        except Exception:
            logger.debug("get_image_rects failed for xref %d", xref, exc_info=True)

        # Method 2: get_image_info (fallback).
        if img_y is None:
            if info_bboxes is None:
                try:
                    info_bboxes = _image_info_bboxes(page)
                except Exception:
                    logger.debug("get_image_info failed on page %d", page_num + 1, exc_info=True)
                    info_bboxes = {}
            bbox = info_bboxes.get(xref)
            if bbox is not None:
                img_y = bbox[1]
                img_w = bbox[2] - bbox[0]
                img_h = bbox[3] - bbox[1]

        # Method 3: position unknown (or reported as y == 0) — assume mid-page.
        # NOTE: an image that was not located at all keeps img_w == img_h == 0
        # and is therefore dropped by the size filter below. This presumably also
        # filters out images listed in the page's resources but not drawn on it
        # (resources can be shared between pages) — verify before relaxing it.
        if img_y is None or img_y == 0:
            img_y = page_height * 0.5

        # Geometry filters: header/footer art, then tiny placements. (Thin
        # full-width rules are already excluded by the minimum-side check.)
        if img_y < header_zone or (img_y + img_h) > footer_zone:
            continue
        if img_w < _MIN_IMAGE_SIDE or img_h < _MIN_IMAGE_SIDE:
            continue

        try:
            base_img = doc.extract_image(xref)
        except Exception:
            logger.debug("extract_image failed for xref %d", xref, exc_info=True)
            continue
        if not base_img or not base_img.get("image"):
            continue
        w = base_img.get("width", 0)
        h = base_img.get("height", 0)
        if w < _MIN_IMAGE_SIDE or h < _MIN_IMAGE_SIDE:
            continue

        b64 = base64.b64encode(base_img["image"]).decode("utf-8")
        ext = base_img.get("ext", "png")
        raw_elements.append((page_num, img_y, {
            "content": f"data:image/{ext};base64,{b64}",
            "type": "image",
            "page": page_num + 1,
            "width": w,
            "height": h,
        }))


def _block_lines(block: dict) -> list:
    """Return ``{"text", "font_size", "is_bold"}`` for each non-empty line of *block*.

    ``font_size`` is the largest span size on the line (0 if none is positive).
    ``is_bold`` is true when any span's font name contains "bold".
    """
    lines_data = []
    for line in block.get("lines", []):
        spans = line.get("spans", [])
        text = "".join(span.get("text", "") for span in spans).strip()
        if not text:
            continue
        font_size = max([0] + [span.get("size", 0) for span in spans])
        is_bold = any("bold" in span.get("font", "").lower() for span in spans)
        lines_data.append({"text": text, "font_size": font_size, "is_bold": is_bold})
    return lines_data


def _is_title_line(line_data: dict) -> bool:
    """Heuristic: large, short-bold, or short all-caps lines are titles."""
    text = line_data["text"]
    return (
        line_data["font_size"] > 13
        or (line_data["is_bold"] and len(text) < 120)
        or (text.isupper() and 2 < len(text) < 100)
    )


def _group_paragraphs(lines_data: list) -> list:
    """Merge runs of consecutive lines with the same title/body classification.

    Returns ``{"text", "is_title"}`` dicts; lines in a run are joined by spaces.
    """
    paragraphs = []
    for is_title, run in itertools.groupby(lines_data, key=_is_title_line):
        text = " ".join(ld["text"] for ld in run).strip()
        if text:
            paragraphs.append({"text": text, "is_title": is_title})
    return paragraphs


def _chunk_text(text: str, limit: int = _CHUNK_TARGET) -> list:
    """Split *text* at sentence boundaries into chunks of about *limit* characters.

    A new chunk starts when appending the next sentence would push the current
    one past *limit*; a single sentence longer than *limit* forms its own chunk.
    """
    chunks = []
    chunk = ""
    for sentence in _SENTENCE_SPLIT.split(text):
        if len(chunk) + len(sentence) > limit and chunk:
            chunks.append(chunk.strip())
            chunk = sentence
        else:
            chunk += (" " if chunk else "") + sentence
    if chunk.strip():
        chunks.append(chunk.strip())
    return chunks


def _extract_text(page, page_num: int, table_rects: list, raw_elements: list) -> None:
    """Append title/text paragraphs of *page* (outside tables) to *raw_elements*.

    Every piece emitted for a block carries that block's top y. The stable sort
    in ``extract_pdf`` then keeps the pieces in emission order. Using per-chunk
    y offsets instead would let a later paragraph of the same block sort
    between earlier chunks.
    """
    blocks = page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
    for block in blocks:
        if block["type"] != 0:
            continue
        block_rect = fitz.Rect(block["bbox"])
        if any(block_rect.intersects(tr) for tr in table_rects):
            continue  # already captured as part of a table
        y_pos = block_rect.y0

        for para in _group_paragraphs(_block_lines(block)):
            text = para["text"]
            if para["is_title"]:
                pieces = [(text, "title")]
            elif len(text) > _LONG_TEXT_THRESHOLD:
                pieces = [(chunk, "text") for chunk in _chunk_text(text)]
            else:
                pieces = [(text, "text")]
            for content, el_type in pieces:
                raw_elements.append((page_num, y_pos, {
                    "content": content,
                    "type": el_type,
                    "page": page_num + 1,
                }))


def extract_pdf(tmp_path: str) -> dict:
    """Partition the PDF at *tmp_path* into text, titles, tables and images.

    Elements are ordered by (page, top y) and numbered with a global ``index``
    in that reading order before being grouped by type.
    """
    raw_elements = []
    with fitz.open(tmp_path) as doc:
        page_count = len(doc)
        for page_num, page in enumerate(doc):
            table_rects = _extract_tables(page, page_num, raw_elements)
            _extract_images(doc, page, page_num, raw_elements)
            _extract_text(page, page_num, table_rects, raw_elements)

    # list.sort is stable: elements sharing (page, y) keep their emission order.
    raw_elements.sort(key=lambda item: (item[0], item[1]))
    all_elements = []
    for idx, (_, _, element) in enumerate(raw_elements):
        element["index"] = idx
        all_elements.append(element)
    return _build_response(all_elements, page_count=page_count)


# ---------------------------------------------------------------------------
# DOCX / plain-text extraction
# ---------------------------------------------------------------------------

def extract_docx(tmp_path: str) -> dict:
    """Partition the DOCX at *tmp_path* into text, titles and tables.

    Top-level body paragraphs and tables are visited in document order, so
    ``index`` reflects reading order. Paragraphs whose style name contains
    "Heading" become titles; empty paragraphs are skipped.
    """
    from docx import Document as DocxDocument
    from docx.oxml.ns import qn
    from docx.table import Table
    from docx.text.paragraph import Paragraph

    doc = DocxDocument(tmp_path)
    paragraph_tag = qn("w:p")
    table_tag = qn("w:tbl")
    elements = []

    for child in doc.element.body.iterchildren():
        if child.tag == paragraph_tag:
            para = Paragraph(child, doc)
            text = para.text.strip()
            if not text:
                continue
            style = para.style
            is_heading = bool(style and style.name and "Heading" in style.name)
            elements.append({
                "index": len(elements),
                "content": text,
                "type": "title" if is_heading else "text",
            })
        elif child.tag == table_tag:
            table = Table(child, doc)
            markup = _rows_to_html([cell.text for cell in row.cells] for row in table.rows)
            elements.append({"index": len(elements), "content": markup, "type": "table"})

    return _build_response(elements)


def _extract_plain_text(tmp_path: str) -> dict:
    """Treat *tmp_path* as UTF-8 text (undecodable bytes dropped); split on blank lines."""
    with open(tmp_path, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    elements = [{"index": i, "content": p, "type": "text"} for i, p in enumerate(paragraphs)]
    return _build_response(elements)


# ---------------------------------------------------------------------------
# HTTP endpoints
# ---------------------------------------------------------------------------

@app.post("/partition")
async def partition_document(file: UploadFile = File(...)):
    """Partition an uploaded document, choosing the extractor by file extension.

    ``.pdf`` and ``.docx`` get structured extraction; anything else is read as
    plain text. On failure returns ``{"error": <message>, "type": <exception name>}``.
    """
    suffix, tmp_path = await _save_upload(file)
    try:
        if suffix == ".pdf":
            extractor = extract_pdf
        elif suffix == ".docx":
            extractor = extract_docx
        else:
            extractor = _extract_plain_text
        # Extraction is CPU-bound; run it in a worker thread so the event loop
        # (and e.g. /health) stays responsive.
        return await run_in_threadpool(extractor, tmp_path)
    except Exception as e:
        logger.exception("Partitioning failed for suffix %r", suffix)
        return {"error": str(e), "type": type(e).__name__}
    finally:
        os.unlink(tmp_path)


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}