RAG / app.py
vedavyas1235's picture
Update app.py
09b5cf4 verified
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import fitz
import tempfile
import os
import base64
import re
import json
# Application instance. CORS is left fully open: any origin, method and
# header is accepted by the middleware configured below.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/debug")
async def debug_document(file: UploadFile = File(...)):
    """Show what PyMuPDF sees — helps diagnose image ordering.

    The upload is written to a temporary file, opened with PyMuPDF and
    summarised page by page:

    * ``text_blocks``: the top y-coordinate (rounded to one decimal) and the
      first 60 characters of every non-empty text block.
    * ``images``: for each image xref, the top y / width / height of its
      first rectangle from ``get_image_rects`` (``y`` is ``"no_rects"`` when
      none is returned, or ``"error: ..."`` when the call raises), plus
      ``info_bbox`` when ``get_image_info`` reports a bbox for that xref.

    Returns:
        The summary dict, or ``{"error": ..., "traceback": ...}`` if writing
        or inspecting the document fails. The temporary file is removed in
        either case and the document is always closed.
    """
    # The upload may carry no filename; treat that as "no extension".
    suffix = os.path.splitext(file.filename or "")[1].lower()
    # Read before creating the temp file so a failed read leaves nothing behind.
    content = await file.read()
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(content)
        doc = fitz.open(tmp_path)
        try:
            result = {"pymupdf_version": fitz.VersionBind, "pages": []}
            for page_num, page in enumerate(doc):
                pg = {"page": page_num + 1, "text_blocks": [], "images": []}

                # Text blocks
                for b in page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]:
                    if b["type"] != 0:
                        continue
                    text = "".join(
                        span.get("text", "")
                        for line in b.get("lines", [])
                        for span in line.get("spans", [])
                    ).strip()
                    if text:
                        pg["text_blocks"].append({
                            "y": round(b["bbox"][1], 1),
                            "text": text[:60],
                        })

                # Images
                page_images = page.get_images(full=True)

                # get_image_info does not depend on the individual image, so
                # query it once per page (first entry per xref wins). This is
                # best-effort: on failure the entries gathered so far are kept.
                info_bboxes = {}
                if page_images:
                    try:
                        for entry in page.get_image_info(xrefs=True):
                            key = entry.get("xref")
                            if key not in info_bboxes:
                                info_bboxes[key] = [
                                    round(x, 1) for x in entry.get("bbox", [])
                                ]
                    except Exception:
                        pass

                for img in page_images:
                    xref = img[0]
                    info = {"xref": xref}
                    try:
                        rects = page.get_image_rects(xref)
                        if rects:
                            r = rects[0]
                            info["y"] = round(r.y0, 1)
                            info["w"] = round(r.width, 1)
                            info["h"] = round(r.height, 1)
                        else:
                            info["y"] = "no_rects"
                    except Exception as e:
                        info["y"] = f"error: {e}"
                    if xref in info_bboxes:
                        info["info_bbox"] = info_bboxes[xref]
                    pg["images"].append(info)

                result["pages"].append(pg)
        finally:
            # Close the document even when inspection fails part-way.
            doc.close()
        return result
    except Exception as e:
        return {"error": str(e), "traceback": repr(e)}
    finally:
        os.unlink(tmp_path)
def extract_pdf(tmp_path: str) -> dict:
    """Extract titles, text chunks, tables and images from a PDF.

    Every page is scanned for tables, then images, then text blocks. The
    collected elements are ordered by page and by top y-coordinate, and each
    receives a running ``index`` in that order.

    Args:
        tmp_path: Path of the PDF file to open with PyMuPDF.

    Returns:
        A dict with the lists ``text_sections``, ``tables``, ``images`` and
        ``titles`` (each element has ``content``, ``type``, ``page`` and
        ``index``; images also carry ``width`` and ``height``), together
        with ``page_count`` and a ``counts`` dict holding the list lengths.

    Raises:
        Whatever ``fitz.open`` or text extraction raises; the document is
        closed before the exception propagates.
    """
    raw_elements = []  # (page_num, y, element) tuples, ordered below
    doc = fitz.open(tmp_path)
    try:
        page_count = len(doc)
        for page_num, page in enumerate(doc):
            table_rects = _pdf_tables(page, page_num, raw_elements)
            _pdf_images(doc, page, page_num, raw_elements)
            _pdf_text(page, page_num, table_rects, raw_elements)
    finally:
        # Release the document even if a page fails to parse.
        doc.close()

    # list.sort is stable, so elements sharing the same (page, y) key keep
    # the order in which they were emitted.
    raw_elements.sort(key=lambda item: (item[0], item[1]))
    all_elements = []
    for idx, (_, _, element) in enumerate(raw_elements):
        element["index"] = idx
        all_elements.append(element)

    text_sections = [e for e in all_elements if e["type"] == "text"]
    tables = [e for e in all_elements if e["type"] == "table"]
    images = [e for e in all_elements if e["type"] == "image"]
    titles = [e for e in all_elements if e["type"] == "title"]
    return {
        "text_sections": text_sections,
        "tables": tables,
        "images": images,
        "titles": titles,
        "page_count": page_count,
        "counts": {
            "text_sections": len(text_sections),
            "tables": len(tables),
            "images": len(images),
            "titles": len(titles),
        },
    }


def _pdf_table_html(rows) -> str:
    """Render extracted table rows as an HTML table with escaped cell text."""
    from html import escape

    parts = ["<table>"]
    for row in rows:
        parts.append("<tr>")
        for cell in row:
            # Cell text comes from the uploaded file, so escape markup
            # characters instead of splicing it into the HTML verbatim.
            cell_text = escape(str(cell), quote=False) if cell else ""
            parts.append(f"<td>{cell_text}</td>")
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)


def _pdf_tables(page, page_num, sink) -> list:
    """Append one table element per detected table; return the table rects."""
    table_rects = []
    try:
        for tab in page.find_tables():
            rect = fitz.Rect(tab.bbox)
            markup = _pdf_table_html(tab.extract())
            # Register the rect only once the table rendered successfully;
            # otherwise its text would be dropped from the text pass as well.
            table_rects.append(rect)
            sink.append((page_num, rect.y0, {
                "content": markup, "type": "table", "page": page_num + 1,
            }))
    except Exception:
        # Table detection is best-effort: keep the tables gathered so far
        # and let the remaining content fall through as ordinary text.
        pass
    return table_rects


def _pdf_image_info_boxes(page) -> dict:
    """Map xref -> bbox using get_image_info (first entry per xref wins)."""
    boxes = {}
    try:
        for entry in page.get_image_info(xrefs=True):
            key = entry.get("xref")
            if "bbox" in entry and key not in boxes:
                boxes[key] = entry["bbox"]
    except Exception:
        # Best-effort fallback source; keep whatever was read so far.
        pass
    return boxes


def _pdf_images(doc, page, page_num, sink) -> None:
    """Append an image element for every sufficiently large body image."""
    page_height = page.rect.height
    header_zone = page_height * 0.08
    footer_zone = page_height * 0.92

    seen_xrefs = set()
    info_boxes = None  # filled lazily, at most once per page
    for img in page.get_images(full=True):
        xref = img[0]
        if xref in seen_xrefs:
            continue
        seen_xrefs.add(xref)

        # Method 1: get_image_rects.
        geometry = None
        try:
            rects = page.get_image_rects(xref)
            if rects:
                first = rects[0]
                geometry = (first.y0, first.width, first.height)
        except Exception:
            pass

        # Method 2: get_image_info, queried once and reused for the page.
        if geometry is None:
            if info_boxes is None:
                info_boxes = _pdf_image_info_boxes(page)
            bbox = info_boxes.get(xref)
            if bbox is not None:
                geometry = (bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1])

        if geometry is None:
            # No placement found, hence no on-page size or position to
            # filter or order by: skip the image.
            continue

        img_y, img_w, img_h = geometry
        if img_y == 0:
            # A top coordinate of exactly 0 is treated as "position unknown"
            # and the image is ordered at mid-page instead.
            # NOTE(review): this also lets such an image bypass the header
            # filter below; confirm that is intended.
            img_y = page_height * 0.5

        # Skip images in the header/footer bands and images drawn small.
        if img_y < header_zone or (img_y + img_h) > footer_zone:
            continue
        if img_w < 50 or img_h < 50:
            continue

        try:
            base_img = doc.extract_image(xref)
            if not base_img or not base_img.get("image"):
                continue
            w = base_img.get("width", 0)
            h = base_img.get("height", 0)
            # Also skip images whose stored pixel size is tiny.
            if w < 50 or h < 50:
                continue
            b64 = base64.b64encode(base_img["image"]).decode("utf-8")
            ext = base_img.get("ext", "png")
            sink.append((page_num, img_y, {
                "content": f"data:image/{ext};base64,{b64}",
                "type": "image", "page": page_num + 1,
                "width": w, "height": h,
            }))
        except Exception:
            # An image that cannot be extracted is simply left out.
            continue


def _pdf_block_lines(block) -> list:
    """Return non-empty lines of a text block with max font size and bold flag."""
    lines = []
    for line in block.get("lines", []):
        pieces = []
        max_size = 0
        bold = False
        for span in line.get("spans", []):
            pieces.append(span.get("text", ""))
            if span.get("size", 0) > max_size:
                max_size = span["size"]
            if "bold" in span.get("font", "").lower():
                bold = True
        text = "".join(pieces).strip()
        if text:
            lines.append({"text": text, "font_size": max_size, "is_bold": bold})
    return lines


def _pdf_is_title(line) -> bool:
    """Heuristic: large font, short bold line, or short all-caps line."""
    text = line["text"]
    return (
        line["font_size"] > 13
        or (line["is_bold"] and len(text) < 120)
        or (text.isupper() and 2 < len(text) < 100)
    )


def _pdf_paragraphs(lines) -> list:
    """Merge consecutive lines sharing the same title flag into paragraphs."""
    runs = []  # [is_title, [line texts]] per run of equal flags
    for line in lines:
        flag = _pdf_is_title(line)
        if runs and runs[-1][0] == flag:
            runs[-1][1].append(line["text"])
        else:
            runs.append([flag, [line["text"]]])
    paragraphs = []
    for flag, texts in runs:
        joined = " ".join(texts).strip()
        if joined:
            paragraphs.append({"text": joined, "is_title": flag})
    return paragraphs


def _pdf_sentence_chunks(text) -> list:
    """Split text at sentence ends into chunks of roughly 400 characters."""
    chunks = []
    current = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        if current and len(current) + len(sentence) > 400:
            chunks.append(current.strip())
            current = sentence
        else:
            current += (" " if current else "") + sentence
    if current.strip():
        chunks.append(current.strip())
    return chunks


def _pdf_text(page, page_num, table_rects, sink) -> None:
    """Append title/text elements for text blocks that lie outside tables."""
    blocks = page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
    for block in blocks:
        if block["type"] != 0:
            continue
        block_rect = fitz.Rect(block["bbox"])
        # Text inside a detected table is already part of the table element.
        if any(block_rect.intersects(tr) for tr in table_rects):
            continue
        y_pos = block_rect.y0
        for para in _pdf_paragraphs(_pdf_block_lines(block)):
            el_type = "title" if para["is_title"] else "text"
            text = para["text"]
            if el_type == "text" and len(text) > 500:
                pieces = _pdf_sentence_chunks(text)
            else:
                pieces = [text]
            # Every piece of a block shares the block's y: the stable sort in
            # extract_pdf keeps them in emission order, so a later paragraph
            # can never land between the chunks of an earlier one.
            for piece in pieces:
                sink.append((page_num, y_pos, {
                    "content": piece, "type": el_type, "page": page_num + 1,
                }))
def extract_docx(tmp_path: str) -> dict:
    """Extract headings, paragraphs and tables from a .docx file.

    Non-empty paragraphs are emitted first, in document order; a paragraph
    whose style name contains "Heading" becomes a ``title``, any other a
    ``text`` element. All tables follow, each rendered as an HTML string
    with escaped cell text. Elements are indexed in emission order.

    Args:
        tmp_path: Path of the .docx file to open with python-docx.

    Returns:
        A dict with the lists ``text_sections``, ``tables`` and ``titles``,
        an always-empty ``images`` list, and a ``counts`` dict holding the
        list lengths.
    """
    from html import escape

    from docx import Document as DocxDocument

    doc = DocxDocument(tmp_path)
    elements = []

    for para in doc.paragraphs:
        text = para.text.strip()
        if not text:
            continue
        is_heading = bool(
            para.style and para.style.name and "Heading" in para.style.name
        )
        elements.append({
            "index": len(elements),
            "content": text,
            "type": "title" if is_heading else "text",
        })

    for table in doc.tables:
        parts = ["<table>"]
        for row in table.rows:
            parts.append("<tr>")
            for cell in row.cells:
                # Cell text comes from the uploaded file, so escape markup
                # characters instead of splicing it into the HTML verbatim.
                parts.append(f"<td>{escape(cell.text, quote=False)}</td>")
            parts.append("</tr>")
        parts.append("</table>")
        elements.append({
            "index": len(elements),
            "content": "".join(parts),
            "type": "table",
        })

    text_sections = [e for e in elements if e["type"] == "text"]
    tables_list = [e for e in elements if e["type"] == "table"]
    titles = [e for e in elements if e["type"] == "title"]
    return {
        "text_sections": text_sections, "tables": tables_list, "images": [],
        "titles": titles,
        "counts": {
            "text_sections": len(text_sections),
            "tables": len(tables_list),
            "images": 0,
            "titles": len(titles),
        },
    }
@app.post("/partition")
async def partition_document(file: UploadFile = File(...)):
    """Partition an uploaded document into text, title, table and image parts.

    The upload is written to a temporary file and dispatched on its
    lower-cased extension: ``.pdf`` goes to ``extract_pdf``, ``.docx`` to
    ``extract_docx``. Anything else is read as UTF-8 text (undecodable bytes
    ignored) and split into paragraphs on blank lines.

    Returns:
        The extractor's result dict, or ``{"error": ..., "type": ...}`` if
        writing or processing the document fails. The temporary file is
        removed in either case.
    """
    # The upload may carry no filename; treat that as "no extension",
    # which selects the plain-text branch.
    suffix = os.path.splitext(file.filename or "")[1].lower()
    # Read before creating the temp file so a failed read leaves nothing behind.
    content = await file.read()
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(content)

        if suffix == ".pdf":
            return extract_pdf(tmp_path)
        if suffix == ".docx":
            return extract_docx(tmp_path)

        with open(tmp_path, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()
        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
        return {
            "text_sections": [
                {"index": i, "content": p, "type": "text"}
                for i, p in enumerate(paragraphs)
            ],
            "tables": [], "images": [], "titles": [],
            "counts": {
                "text_sections": len(paragraphs),
                "tables": 0, "images": 0, "titles": 0,
            },
        }
    except Exception as e:
        return {"error": str(e), "type": type(e).__name__}
    finally:
        os.unlink(tmp_path)
@app.get("/health")
def health():
    """Liveness probe: always reports the service as ok."""
    payload = {"status": "ok"}
    return payload