| import os
|
| import re
|
| import json
|
| import math
|
| import tempfile
|
| import fitz
|
| import cv2
|
| import numpy as np
|
| from PIL import Image
|
| import streamlit as st
|
|
|
|
|
|
|
|
|
# Resolution (dots per inch) at which PDF pages are rasterised.
DPI = 300

# Root directory; each processed paper gets its own sub-directory here.
OUT_DIR = "outputs"

# When True, run_extraction only processes captions matched by SS_KW.
KEEP_ONLY_STRESS_STRAIN = False

# Caption detector: matched at the START of a whitespace-normalised text block,
# e.g. "Fig. 3 ...", "Fig 3 ...", "Figure 12 ..." (case-insensitive).
CAP_RE = re.compile(r"^(Fig\.?\s*\d+|Figure\s*\d+)\b", re.IGNORECASE)

# Keyword filter (searched anywhere in a caption) for stress/strain-style
# figures: mechanical-test vocabulary and typical units.
SS_KW = re.compile(
    r"(stress\s*[-–]?\s*strain|stress|strain|tensile|MPa|GPa|kN|yield|elongation)",
    flags=re.IGNORECASE,
)
|
|
|
|
|
|
|
|
|
def render_page(page, dpi=DPI):
    """Rasterise a PyMuPDF page into an RGB ``PIL.Image``.

    Args:
        page: PyMuPDF page to render.
        dpi: output resolution. PDF user space has 72 units per inch, so the
            page is scaled by ``dpi / 72`` on both axes.

    Returns:
        ``(image, matrix)``: the rendered image (mode ``"RGB"``) and the
        ``fitz.Matrix`` used to render it, so PDF coordinates can later be
        mapped onto image pixels (see ``pdf_to_px_bbox``).
    """
    zoom = dpi / 72
    scale = fitz.Matrix(zoom, zoom)
    # No alpha channel is requested, so the raw samples can be read as plain RGB.
    pixmap = page.get_pixmap(matrix=scale, alpha=False)
    size = (pixmap.width, pixmap.height)
    image = Image.frombytes("RGB", size, pixmap.samples)
    return image, scale
|
|
|
def pdf_to_px_bbox(bbox_pdf, mat):
    """Scale a PDF-space bounding box into integer pixel coordinates.

    Only the diagonal scale terms of ``mat`` are used (``mat.a`` for x and
    ``mat.d`` for y). Translation and rotation terms are ignored, which suits
    the pure scaling matrix returned by ``render_page``. Each coordinate is
    converted with ``float`` and then truncated toward zero with ``int``.

    Args:
        bbox_pdf: ``(x0, y0, x1, y1)`` in PDF units.
        mat: object exposing scale factors ``a`` and ``d`` (a ``fitz.Matrix``
            here).

    Returns:
        ``(x0, y0, x1, y1)`` as a tuple of ints.
    """
    left, top, right, bottom = bbox_pdf
    scale_x = mat.a
    scale_y = mat.d
    pairs = ((left, scale_x), (top, scale_y), (right, scale_x), (bottom, scale_y))
    return tuple(int(float(value) * factor) for value, factor in pairs)
|
|
|
def safe_crop_px(pil_img, box):
    """Crop ``pil_img`` to a pixel box, returning ``None`` instead of raising.

    The box is validated and normalised before cropping:

    * it must be a 4-element tuple/list ``(x0, y0, x1, y1)``; a single such
      box wrapped in a one-element tuple/list is unwrapped;
    * each coordinate is converted with ``int`` (floats are truncated toward
      zero). Nested sequences, and values ``int`` rejects (non-numeric, NaN,
      +/-infinity), make the box invalid;
    * swapped corners are reordered, then the box is clamped to the image.

    Args:
        pil_img: object exposing ``.size`` as ``(width, height)`` and
            ``.crop(box)`` (a ``PIL.Image`` in this module).
        box: candidate crop box in pixel coordinates.

    Returns:
        The result of ``pil_img.crop`` for the normalised box, or ``None``
        when the box is malformed or has zero area after clamping.
    """
    if not isinstance(box, (tuple, list)):
        return None
    # Accept a box wrapped in a one-element container, e.g. [(x0, y0, x1, y1)].
    if len(box) == 1 and isinstance(box[0], (tuple, list)) and len(box[0]) == 4:
        box = box[0]
    if len(box) != 4:
        return None
    if any(isinstance(v, (tuple, list)) for v in box):
        return None

    try:
        x0, y0, x1, y1 = (int(v) for v in box)
    except (TypeError, ValueError, OverflowError):
        # int(float("inf")) raises OverflowError (NaN raises ValueError);
        # both mean "invalid box", so report it the same way.
        return None

    x0, x1 = min(x0, x1), max(x0, x1)
    y0, y1 = min(y0, y1), max(y0, y1)

    W, H = pil_img.size
    x0, x1 = max(0, min(W, x0)), max(0, min(W, x1))
    y0, y1 = max(0, min(H, y0)), max(0, min(H, y1))
    if x1 <= x0 or y1 <= y0:
        return None
    return pil_img.crop((x0, y0, x1, y1))
|
|
|
|
|
|
|
|
|
def find_caption_blocks(page):
    """Return the text blocks on ``page`` whose text starts like a figure caption.

    Each block from ``page.get_text("blocks")`` has its whitespace collapsed
    to single spaces and is kept if ``CAP_RE`` matches at its start.

    Args:
        page: PyMuPDF page.

    Returns:
        A list of ``{"bbox": (x0, y0, x1, y1), "text": normalised_text}``
        dicts in the order the blocks were returned. The bbox is in PDF
        coordinates.
    """
    captions = []
    for block in page.get_text("blocks"):
        bbox = tuple(block[:4])
        normalised = " ".join(str(block[4]).split())
        if not CAP_RE.match(normalised):
            continue
        captions.append({"bbox": bbox, "text": normalised})
    return captions
|
|
|
|
|
|
|
|
|
def dhash64(pil_img):
    """Return a 64-bit difference hash (dHash) of ``pil_img`` as a Python int.

    The image is converted to grayscale and resized to 9x8 (width x height)
    with Lanczos resampling. Each of the resulting 8x8 bits is 1 when a pixel
    is strictly brighter than its right-hand neighbour. Bits are packed row
    by row, left to right, with the first comparison in the most significant
    position.
    """
    thumb = pil_img.convert("L").resize((9, 8), Image.LANCZOS)
    grid = np.asarray(thumb)  # shape (8, 9): 8 rows of 9 columns
    brighter = grid[:, :-1] > grid[:, 1:]
    value = 0
    for bit in brighter.ravel():
        value = (value << 1) | int(bit)
    return value
|
|
|
|
|
|
|
|
|
def has_colorbar_like_strip(pil_img):
    """Heuristically detect a colour bar along the right edge of an image.

    The rightmost strip (7% of the width, at least 18 px) is quantised into
    coarse colour bins (integer division by 24 per channel). More than 70
    distinct binned colours is treated as a colour bar; ``is_real_plot``
    rejects such crops.

    Only the first three channels are used. RGBA/LA input is therefore
    handled correctly (alpha ignored) instead of being regrouped into bogus
    RGB triples or crashing in ``reshape``.

    Args:
        pil_img: ``PIL.Image`` or anything ``np.array`` can convert.

    Returns:
        ``True`` if the strip looks like a colour bar. 2-D (single-channel)
        input, and images narrower than 250 px or shorter than 150 px, always
        return ``False``.
    """
    img = np.array(pil_img)
    if img.ndim != 3:
        return False
    H, W = img.shape[:2]
    if W < 250 or H < 150:
        return False
    strip_w = max(18, int(0.07 * W))
    # Keep at most the colour channels; otherwise an alpha channel would be
    # folded into the colour tuples by the reshape below.
    strip = img[:, W - strip_w:, :3]
    binned = (strip // 24).reshape(-1, strip.shape[-1])
    return len(np.unique(binned, axis=0)) > 70
|
|
|
def texture_score(pil_img):
    """Return the variance of the grayscale Laplacian of ``pil_img`` as a float.

    Larger values mean more fine-grained intensity variation.
    ``is_real_plot`` rejects crops whose score exceeds 2200.
    """
    luminance = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    response = cv2.Laplacian(luminance, cv2.CV_64F)
    return float(response.var())
|
|
|
def is_mostly_legend(pil_img):
    """Return True for small, sparsely inked crops (treated as legend-only).

    The grayscale image is Otsu-binarised with dark pixels set to 255 and
    then median-filtered (3x3). The crop qualifies when the foreground
    fraction is strictly between 3% and 18% AND its shorter side is below
    260 px.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    _, ink = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    ink = cv2.medianBlur(ink, 3)
    height, width = ink.shape
    fill_ratio = np.count_nonzero(ink) / (height * width)
    is_sparse = 0.03 < fill_ratio < 0.18
    is_small = min(height, width) < 260
    return is_sparse and is_small
|
|
|
|
|
|
|
|
|
def detect_axes_lines(pil_img):
    """Find the longest near-horizontal and near-vertical line segments.

    Canny edges (thresholds 50/150) are passed to a probabilistic Hough
    transform (rho 1 px, theta 1 degree, vote threshold 90, minimum length
    28% of the shorter image side, maximum gap 14 px).

    Classification rules:

    * horizontal: spans more than 35% of the width and drifts fewer than
      18 px vertically;
    * vertical: spans more than 35% of the height and drifts fewer than
      18 px horizontally.

    Returns:
        ``(x_axis, y_axis)``: the longest horizontal and the longest vertical
        segment, each as ``(x1, y1, x2, y2)``. Ties go to the segment
        reported first by Hough. Returns ``(None, None)`` when no segments
        are found or either kind is missing.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    H, W = gray.shape
    segments = cv2.HoughLinesP(
        cv2.Canny(gray, 50, 150),
        1,
        np.pi / 180,
        threshold=90,
        minLineLength=int(0.28 * min(H, W)),
        maxLineGap=14,
    )
    if segments is None:
        return None, None

    horizontal = []
    vertical = []
    for seg in segments[:, 0]:
        x1, y1, x2, y2 = seg
        span_x = abs(x2 - x1)
        span_y = abs(y2 - y1)
        entry = (math.hypot(span_x, span_y), (x1, y1, x2, y2))
        if span_y < 18 and span_x > 0.35 * W:
            horizontal.append(entry)
        if span_x < 18 and span_y > 0.35 * H:
            vertical.append(entry)

    if not (horizontal and vertical):
        return None, None

    def _longest(found):
        # max() returns the first maximal element, matching a stable
        # descending sort followed by [0].
        return max(found, key=lambda item: item[0])[1]

    return _longest(horizontal), _longest(vertical)
|
|
|
def axis_intersection_ok(x_axis, y_axis, W, H):
    """Check that the detected axes sit at plausible positions in a W x H image.

    The x-axis row is the rounded mean y of its endpoints; the y-axis column
    is the rounded mean x of its endpoints. Both must lie inside the image.
    In addition, the y-axis must not be in the rightmost 5% of the width and
    the x-axis must not be in the top 5% of the height.

    Args:
        x_axis: ``(x1, y1, x2, y2)`` of the horizontal axis.
        y_axis: ``(x1, y1, x2, y2)`` of the vertical axis.
        W, H: image width and height in pixels.
    """
    x_axis_row = int(round((x_axis[1] + x_axis[3]) / 2))
    y_axis_col = int(round((y_axis[0] + y_axis[2]) / 2))
    inside = 0 <= x_axis_row < H and 0 <= y_axis_col < W
    return (
        inside
        and y_axis_col <= int(0.95 * W)
        and x_axis_row >= int(0.05 * H)
    )
|
|
|
def tick_text_presence_score(pil_img, x_axis, y_axis):
    """Count small, label-sized blobs in bands along the detected axes.

    The image is Otsu-binarised (dark pixels -> 255) and median-filtered
    (3x3). Two bands are scanned:

    * around the x-axis: from 40 px above to 110 px below it, full width;
    * around the y-axis: from 180 px left to 50 px right of it, full height.

    In each band, 8-connected components with width and height in
    [4, 150] px and area in [20, 5000] px are counted. The bands overlap
    near the axes' crossing, so blobs there are counted in both.

    Returns:
        The combined count as an ``int``.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    _, ink = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    ink = cv2.medianBlur(ink, 3)
    rows, cols = gray.shape

    axis_row = int(round((x_axis[1] + x_axis[3]) / 2))
    axis_col = int(round((y_axis[0] + y_axis[2]) / 2))

    x_band = ink[max(0, axis_row - 40):min(rows, axis_row + 110), :]
    y_band = ink[:, max(0, axis_col - 180):min(cols, axis_col + 50)]

    def _count_label_sized(mask):
        n, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
        # stats columns: x, y, width, height, area; row 0 is the background.
        widths = stats[1:n, 2]
        heights = stats[1:n, 3]
        areas = stats[1:n, 4]
        keep = (
            (widths >= 4) & (widths <= 150)
            & (heights >= 4) & (heights <= 150)
            & (areas >= 20) & (areas <= 5000)
        )
        return int(np.count_nonzero(keep))

    return _count_label_sized(x_band) + _count_label_sized(y_band)
|
|
|
def is_real_plot(pil_img):
    """Decide whether a crop looks like a genuine axis-based plot.

    The checks run in order and stop at the first failure. The crop is
    rejected if:

    1. it has a colour-bar-like right edge or looks legend-only;
    2. it lacks a long horizontal or a long vertical line;
    3. its axes sit in implausible positions;
    4. its Laplacian-variance texture score exceeds 2200;
    5. fewer than 18 small label-sized blobs appear near the axes.

    Returns:
        ``True`` only when every check passes.
    """
    if has_colorbar_like_strip(pil_img) or is_mostly_legend(pil_img):
        return False

    x_axis, y_axis = detect_axes_lines(pil_img)
    if x_axis is None or y_axis is None:
        return False

    height, width = np.array(pil_img).shape[:2]
    return (
        axis_intersection_ok(x_axis, y_axis, width, height)
        and not texture_score(pil_img) > 2200
        and tick_text_presence_score(pil_img, x_axis, y_axis) >= 18
    )
|
|
|
|
|
|
|
|
|
def connected_components_boxes(pil_img):
    """Return bounding boxes of non-background regions, largest area first.

    Pixels with grayscale value below 245 (anything not near-white) form a
    mask. A 7x7 morphological close is applied twice to merge nearby
    strokes, then 8-connected components are labelled.

    Returns:
        A list of ``(area, (x0, y0, x1, y1))`` with exclusive right/bottom
        edges, sorted by component pixel area in descending order. Ties keep
        label order. The background component is excluded.
    """
    bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    content = (gray < 245).astype(np.uint8) * 255
    kernel = np.ones((7, 7), np.uint8)
    content = cv2.morphologyEx(content, cv2.MORPH_CLOSE, kernel, iterations=2)
    count, _, stats, _ = cv2.connectedComponentsWithStats(content, connectivity=8)

    # Row 0 of stats is the background label.
    boxes = [
        (int(area), (int(x), int(y), int(x + w), int(y + h)))
        for x, y, w, h, area in stats[1:count]
    ]
    return sorted(boxes, key=lambda item: item[0], reverse=True)
|
|
|
def expand_box(box, W, H, left=0.10, right=0.06, top=0.06, bottom=0.18):
    """Pad a pixel box by fractions of its own size, clamped to ``(0, 0, W, H)``.

    Horizontal margins are ``left``/``right`` times the box width; vertical
    margins are ``top``/``bottom`` times the box height. Each padded edge is
    truncated with ``int`` before clamping. The larger default ``bottom``
    presumably leaves room for x-axis tick labels under the detected
    component (TODO confirm).

    Returns:
        ``(x0, y0, x1, y1)`` as ints.
    """
    x0, y0, x1, y1 = box
    width, height = x1 - x0, y1 - y0
    new_left = int(x0 - left * width)
    new_top = int(y0 - top * height)
    new_right = int(x1 + right * width)
    new_bottom = int(y1 + bottom * height)
    return (max(0, new_left), max(0, new_top), min(W, new_right), min(H, new_bottom))
|
|
|
|
|
|
|
|
|
def crop_plot_from_caption(page_img, cap_bbox_pdf, mat):
    """Locate and crop the plot associated with a figure caption.

    Search region: the full page width, vertically from ``0.95 * page
    height`` above the caption's top edge to ``0.20 * page height`` below
    its bottom edge (clamped to the page).

    Candidates: the 35 largest non-background components in that region
    (``connected_components_boxes``), keeping those at least 220 px wide
    and 180 px tall. Each is padded with ``expand_box`` and checked with
    ``is_real_plot``.

    Args:
        page_img: rendered page as a ``PIL.Image``.
        cap_bbox_pdf: caption bounding box in PDF coordinates.
        mat: render matrix returned by ``render_page``.

    Returns:
        The crop of the largest qualifying component, or ``None``.

    NOTE(review): the window covers nearly everything above the caption and
    the largest passing component wins. With several figures on one page,
    captions may therefore receive the same plot. Nothing here prefers the
    component nearest the caption.
    """
    cap_px = pdf_to_px_bbox(cap_bbox_pdf, mat)
    cap_top, cap_bottom = cap_px[1], cap_px[3]

    W, H = page_img.size
    search_top = max(0, cap_top - int(0.95 * H))
    search_bot = min(H, cap_bottom + int(0.20 * H))
    region = safe_crop_px(page_img, (0, search_top, W, search_bot))
    if region is None:
        return None

    region_w, region_h = region.size
    # connected_components_boxes sorts by area, largest first, so the first
    # candidate that passes is the largest one. Stop there rather than run
    # the expensive is_real_plot on every smaller component.
    for _area, box in connected_components_boxes(region)[:35]:
        x0, y0, x1, y1 = box
        if x1 - x0 < 220 or y1 - y0 < 180:
            continue
        cand = safe_crop_px(region, expand_box(box, region_w, region_h))
        if cand is not None and is_real_plot(cand):
            return cand
    return None
|
|
|
|
|
|
|
|
|
def run_extraction(pdf_path, paper_id="uploaded_paper"):
    """Extract axis-bearing plots for every figure caption in a PDF.

    Pages with caption blocks (``find_caption_blocks``) are rendered at
    ``DPI``, and ``crop_plot_from_caption`` runs for each caption. When
    ``KEEP_ONLY_STRESS_STRAIN`` is set, only captions matching ``SS_KW`` are
    used. Each crop is trimmed by 2 px per side (when larger than 8 px), and
    crops with an already-seen ``dhash64`` value are skipped. The rest are
    saved as PNGs under ``OUT_DIR/<paper_id>/plots_with_axes/``.

    Args:
        pdf_path: path of the PDF to process.
        paper_id: name of the per-paper output directory under ``OUT_DIR``.

    Returns:
        ``(results, out_json)``. ``results`` is a list of
        ``{"page": int (1-based), "caption": str, "image": path}`` dicts;
        ``out_json`` is the path of the JSON file they were written to.
    """
    out_paper = os.path.join(OUT_DIR, paper_id)
    out_imgs = os.path.join(out_paper, "plots_with_axes")
    os.makedirs(out_imgs, exist_ok=True)

    results = []
    seen = set()
    saved = 0

    # Close the document deterministically: main() deletes the temporary
    # directory holding pdf_path right after this returns, and some
    # platforms (e.g. Windows) refuse to remove files that are still open.
    with fitz.open(pdf_path) as doc:
        for p, page in enumerate(doc):
            caps = find_caption_blocks(page)
            if not caps:
                continue

            page_img, mat = render_page(page, dpi=DPI)

            for cap in caps:
                cap_text = cap["text"]
                if KEEP_ONLY_STRESS_STRAIN and not SS_KW.search(cap_text):
                    continue

                fig = crop_plot_from_caption(page_img, cap["bbox"], mat)
                if fig is None:
                    continue

                # Shave a 2 px border off the padded crop (presumably to
                # drop edge artefacts; TODO confirm intent).
                if fig.size[0] > 8 and fig.size[1] > 8:
                    fig = fig.crop((2, 2, fig.size[0] - 2, fig.size[1] - 2))

                try:
                    h = dhash64(fig)
                except Exception:
                    # Best-effort: skip a crop that cannot be hashed rather
                    # than abort the whole document.
                    continue

                # Exact-hash de-duplication: identical hashes are treated as
                # the same figure (e.g. several captions picking one plot).
                if h in seen:
                    continue
                seen.add(h)

                img_name = f"p{p + 1:02d}_{saved:04d}.png"
                img_path = os.path.join(out_imgs, img_name)
                fig.save(img_path)

                results.append({
                    "page": p + 1,
                    "caption": cap_text,
                    "image": img_path,
                })
                saved += 1

    out_json = os.path.join(out_paper, "plots_with_axes.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    return results, out_json
|
|
|
def main():
    """Streamlit UI: upload a PDF, extract its plots, and show or download them.

    The uploaded bytes are written to a fixed filename inside a temporary
    directory and ``run_extraction`` is run on that file. The page does the
    following with the results:

    * shows each extracted plot with its page number and caption;
    * shows the raw JSON;
    * offers the JSON file for download.
    """
    st.set_page_config(page_title="Research Paper Plot Extractor", layout="wide")
    st.title(" Plot Extractor (Upload PDF)")

    uploaded = st.file_uploader("Upload a research paper PDF", type=["pdf"])
    if not uploaded:
        st.info("Upload a PDF to extract plots.")
        return

    # The upload name is client-supplied. Keep only its final path component
    # so it cannot steer the output directory outside OUT_DIR.
    base_name = os.path.basename(uploaded.name.replace("\\", "/"))
    paper_id = os.path.splitext(base_name)[0].replace(" ", "_")
    if paper_id in ("", ".", ".."):
        paper_id = "uploaded_paper"

    with tempfile.TemporaryDirectory() as tmpdir:
        # Fixed filename: joining tmpdir with a client-supplied absolute name
        # would discard tmpdir and write wherever that name points.
        pdf_path = os.path.join(tmpdir, "input.pdf")
        with open(pdf_path, "wb") as f:
            # getvalue() returns the whole upload regardless of the stream's
            # current read position.
            f.write(uploaded.getvalue())

        with st.spinner("Extracting plots..."):
            results, out_json = run_extraction(pdf_path, paper_id=paper_id)

    st.success(f"Extracted {len(results)} plots.")

    # Images and JSON live under OUT_DIR, so they remain readable after the
    # temporary PDF directory has been removed.
    for r in results:
        st.markdown(f"**Page {r['page']}** — {r['caption']}")
        st.image(r["image"], use_container_width=True)
        st.divider()

    st.subheader("JSON Output")
    st.json(results)

    with open(out_json, "rb") as f:
        st.download_button(
            "Download JSON",
            data=f,
            file_name=os.path.basename(out_json),
            mime="application/json",
        )


if __name__ == "__main__":
    main()
|
|
|