# gvlktejaswi's picture
# Upload 24 files
# bad95eb verified
import os
import re
import json
import math
import tempfile
import fitz # PyMuPDF
import cv2
import numpy as np
from PIL import Image
import streamlit as st
import pandas as pd
import requests
import base64
from typing import Dict, Any, Optional
# SECURITY: the Gemini API key must not live in source control. It is read from
# the environment (GEMINI_API_KEY, falling back to GOOGLE_API_KEY); when neither
# is set the key is empty and requests to API_URL will be rejected by the API.
# Any key that was previously committed in this file should be revoked.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY", "")

# REST endpoint used by call_gemini_from_bytes(); the key travels as a query parameter.
API_URL = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    "gemini-2.5-flash-preview-09-2025:generateContent"
    f"?key={API_KEY}"
)
# Response schema sent to Gemini as generationConfig.responseSchema.
# The top-level object carries the material name/abbreviation plus one flat
# list ("mechanical_properties") whose items are all-string property records.
SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "material_name": {"type": "STRING"},
        "material_abbreviation": {"type": "STRING"},
        "mechanical_properties": {
            "type": "ARRAY",
            "items": {
                "type": "OBJECT",
                # Every per-property field is a plain string.
                "properties": {
                    field: {"type": "STRING"}
                    for field in (
                        "section",
                        "property_name",
                        "value",
                        "unit",
                        "english",
                        "test_condition",
                        "comments",
                    )
                },
                # "unit" and "test_condition" are the only optional fields.
                "required": ["section", "property_name", "value", "english", "comments"],
            },
        },
    },
}
def make_abbreviation(name: str) -> str:
    """Create a simple abbreviation from the material name.

    The abbreviation is the upper-cased first letter of every word that
    starts with an alphabetic character. If no word starts with a letter,
    the first six characters of the trimmed name (upper-cased) are used.

    Args:
        name: Material name; may be empty, ``None`` or whitespace only.

    Returns:
        A non-blank abbreviation, or ``"UNKNOWN"`` when the name is empty
        or consists only of whitespace.
    """
    # Trim first so a whitespace-only name cannot yield a blank abbreviation.
    cleaned = (name or "").strip()
    if not cleaned:
        return "UNKNOWN"
    initials = "".join(word[0] for word in cleaned.split() if word[0].isalpha())
    return initials.upper() or cleaned[:6].upper()
# Resolution used when rasterising PDF pages (see render_page / extract_images).
DPI = 300
# Root directory under which extract_images() writes the cropped plots.
OUT_DIR = "outputs"
# When True, extract_images() skips captions that do not match SS_KW.
KEEP_ONLY_STRESS_STRAIN = False
# Text block that starts with "Fig. N" / "Figure N" (case-insensitive).
CAP_RE = re.compile(r"^(Fig\.?\s*\d+|Figure\s*\d+)\b", flags=re.I)
# Keywords used to recognise stress/strain related captions (case-insensitive).
SS_KW = re.compile(
    r"(stress\s*[-–]?\s*strain|stress|strain|tensile|MPa|GPa|kN|yield|elongation)",
    flags=re.I,
)
def call_gemini_from_bytes(pdf_bytes: bytes, filename: str) -> Optional[Dict[str, Any]]:
    """Send a PDF to the Gemini endpoint and return the parsed JSON answer.

    Args:
        pdf_bytes: Raw content of the PDF; sent base64-encoded as inline data.
        filename: Name of the source file (accepted but not used here).

    Returns:
        The decoded JSON object from the first candidate's first text part
        that starts with ``{``. ``None`` is returned when encoding fails, the
        response has no candidates, no such text part exists, or the
        request/parsing raises (the error is shown via ``st.error``).
    """
    try:
        pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8")
    except Exception as exc:
        st.error(f"Error encoding PDF: {exc}")
        return None

    instructions = (
        "You are an expert materials scientist. From the attached PDF, extract the material name, "
        "abbreviation, and ALL properties across categories (Mechanical, Thermal, Electrical, Physical, "
        "Optical, Rheological, etc.). Return them as 'mechanical_properties' (a single list). "
        "For each property, you MUST extract:\n"
        "- property_name\n- value (or range)\n- unit\n"
        "- english (converted or alternate units, e.g., psi, °F, inches; write '' if not provided)\n"
        "- test_condition\n- comments (include any notes, footnotes, standards, remarks; write '' if none)\n"
        "All fields including english and comments are REQUIRED. Respond ONLY with valid JSON following the schema."
    )
    pdf_part = {"inlineData": {"mimeType": "application/pdf", "data": pdf_b64}}
    request_body = {
        "contents": [{"parts": [{"text": instructions}, pdf_part]}],
        "generationConfig": {
            "temperature": 0,
            "responseMimeType": "application/json",
            "responseSchema": SCHEMA,
        },
    }

    try:
        response = requests.post(API_URL, json=request_body, timeout=300)
        response.raise_for_status()
        candidates = response.json().get("candidates", [])
        if not candidates:
            return None
        parts = candidates[0].get("content", {}).get("parts", [])
        # Take the first part whose text looks like a JSON object.
        texts = (part.get("text", "") for part in parts)
        json_text = next((text for text in texts if text.strip().startswith("{")), None)
        if not json_text:
            return None
        return json.loads(json_text)
    except Exception as exc:
        # UI boundary: report any network/HTTP/parsing failure instead of crashing.
        st.error(f"Gemini API Error: {exc}")
        return None
# def convert_to_dataframe(data: Dict[str, Any]) -> pd.DataFrame:
# """Convert extracted JSON to DataFrame"""
# rows = []
# for item in data.get("mechanical_properties", []):
# rows.append({
# "material_name": data.get("material_name", ""),
# "material_abbreviation": data.get("material_abbreviation", ""),
# "section": item.get("section", ""),
# "property_name": item.get("property_name", ""),
# "value": item.get("value", ""),
# "unit": item.get("unit", ""),
# "english": item.get("english", ""),
# "test_condition": item.get("test_condition", ""),
# "comments": item.get("comments", "")
# })
# return pd.DataFrame(rows)
def convert_to_dataframe(data: Dict[str, Any]) -> pd.DataFrame:
    """Convert extracted JSON to a DataFrame, ensuring abbreviation is not empty.

    Args:
        data: Parsed model output with ``material_name``,
            ``material_abbreviation`` and a ``mechanical_properties`` list.

    Returns:
        One row per property with the material name/abbreviation repeated on
        every row. Missing or ``None`` fields fall back to defaults
        (``"Mechanical"`` section, ``"Unknown property"``, ``"N/A"`` value,
        empty strings otherwise). The frame is empty when there are no
        properties.
    """
    mat_name = data.get("material_name") or ""
    # Derive an abbreviation from the name when the model left it blank.
    mat_abbr = data.get("material_abbreviation") or make_abbreviation(mat_name)

    # `or []` guards against an explicit null list; non-dict items are skipped
    # rather than raising, matching the None-tolerance applied to every field.
    properties = data.get("mechanical_properties") or []
    rows = [
        {
            "material_name": mat_name,
            "material_abbreviation": mat_abbr,
            "section": item.get("section") or "Mechanical",
            "property_name": item.get("property_name") or "Unknown property",
            "value": item.get("value") or "N/A",
            "unit": item.get("unit") or "",
            "english": item.get("english") or "",
            "test_condition": item.get("test_condition") or "",
            "comments": item.get("comments") or "",
        }
        for item in properties
        if isinstance(item, dict)
    ]
    return pd.DataFrame(rows)
def render_page(page, dpi=DPI):
    """Rasterise a PDF page to an RGB PIL image.

    Args:
        page: Page object providing ``get_pixmap``.
        dpi: Target resolution; converted to a zoom factor of ``dpi / 72``.

    Returns:
        ``(image, matrix)`` - the rendered PIL image and the ``fitz.Matrix``
        used, so callers can map PDF coordinates to pixels.
    """
    zoom = dpi / 72
    transform = fitz.Matrix(zoom, zoom)
    pixmap = page.get_pixmap(matrix=transform, alpha=False)
    rendered = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
    return rendered, transform
def pdf_to_px_bbox(bbox_pdf, mat):
    """Scale a PDF-space bounding box to integer pixel coordinates.

    Args:
        bbox_pdf: ``(x0, y0, x1, y1)`` in PDF units.
        mat: Matrix-like object; ``mat.a`` is the x scale, ``mat.d`` the y scale.

    Returns:
        ``(x0, y0, x1, y1)`` in pixels, each truncated with ``int()``.
    """
    left, top, right, bottom = bbox_pdf
    scale_x = mat.a
    scale_y = mat.d
    return (
        int(float(left) * scale_x),
        int(float(top) * scale_y),
        int(float(right) * scale_x),
        int(float(bottom) * scale_y),
    )
def safe_crop_px(pil_img, box):
    """Crop ``pil_img`` to ``box`` after validating and clamping the box.

    Args:
        pil_img: Image exposing ``size`` (width, height) and ``crop``.
        box: ``(x0, y0, x1, y1)`` as a tuple/list, optionally wrapped in a
            one-element tuple/list.

    Returns:
        The cropped image, or ``None`` when the box is malformed, cannot be
        converted to integers, or is empty after clamping to the image.
    """
    sequence_types = (tuple, list)
    if not isinstance(box, sequence_types):
        return None

    # Accept a box that was accidentally wrapped once: ((x0, y0, x1, y1),).
    if len(box) == 1 and isinstance(box[0], sequence_types) and len(box[0]) == 4:
        box = box[0]
    if len(box) != 4:
        return None
    if any(isinstance(coord, sequence_types) for coord in box):
        return None

    try:
        left, top, right, bottom = (int(coord) for coord in box)
    except (TypeError, ValueError):
        return None

    # Normalise corner order, then clamp into the image rectangle.
    left, right = sorted((left, right))
    top, bottom = sorted((top, bottom))
    width, height = pil_img.size
    left = max(0, min(width, left))
    right = max(0, min(width, right))
    top = max(0, min(height, top))
    bottom = max(0, min(height, bottom))

    if right <= left or bottom <= top:
        return None
    return pil_img.crop((left, top, right, bottom))
def find_caption_blocks(page):
    """Return the text blocks of ``page`` that look like figure captions.

    A block qualifies when its whitespace-normalised text matches ``CAP_RE``.

    Returns:
        A list of ``{"bbox": (x0, y0, x1, y1), "text": normalised_text}``
        dicts, in the order ``page.get_text("blocks")`` yields them.
    """
    captions = []
    for block in page.get_text("blocks"):
        bbox = (block[0], block[1], block[2], block[3])
        # Collapse runs of whitespace/newlines into single spaces.
        normalised = " ".join(str(block[4]).split())
        if CAP_RE.match(normalised):
            captions.append({"bbox": bbox, "text": normalised})
    return captions
def dhash64(pil_img):
    """Compute a 64-bit difference hash of an image.

    The image is converted to grayscale and resized to 9x8; each of the
    8x8 horizontally adjacent pixel pairs contributes one bit
    (1 when the left pixel is brighter than the right one).

    Returns:
        The hash as a Python int, most significant bit first.
    """
    thumb = pil_img.convert("L").resize((9, 8), Image.LANCZOS)
    values = list(thumb.getdata())
    fingerprint = 0
    for row_start in range(0, 72, 9):
        row = values[row_start:row_start + 9]
        for left, right in zip(row, row[1:]):
            fingerprint = (fingerprint << 1) | int(left > right)
    return fingerprint
def has_colorbar_like_strip(pil_img):
    """Tell whether the right edge of the image holds many distinct colours.

    The rightmost strip (7% of the width, at least 18 px) is quantised by
    integer division by 24; more than 70 distinct quantised colours counts
    as a colour-bar-like strip. Non-3D arrays and images smaller than
    250x150 (W x H) always return ``False``.
    """
    pixels = np.array(pil_img)
    if pixels.ndim != 3:
        return False
    height, width = pixels.shape[:2]
    if width < 250 or height < 150:
        return False

    strip_width = max(18, int(0.07 * width))
    right_strip = pixels[:, width - strip_width:, :]
    quantised = (right_strip // 24).reshape(-1, 3)
    distinct_colours = np.unique(quantised, axis=0)
    return len(distinct_colours) > 70
def texture_score(pil_img):
    """Return the variance of the Laplacian of the grayscale image as a float."""
    grayscale = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    laplacian = cv2.Laplacian(grayscale, cv2.CV_64F)
    return float(laplacian.var())
def is_mostly_legend(pil_img):
    """Heuristic: small image with a sparse amount of foreground pixels.

    The image is Otsu-thresholded (inverted) and median-blurred; it is
    flagged when the foreground fraction lies strictly between 3% and 18%
    and the shorter side is below 260 px.
    """
    grayscale = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    _, foreground = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    foreground = cv2.medianBlur(foreground, 3)
    height, width = foreground.shape
    fill_fraction = float(np.count_nonzero(foreground)) / float(height * width)
    if min(height, width) >= 260:
        return False
    return 0.03 < fill_fraction < 0.18
def detect_axes_lines(pil_img):
    """Find the longest near-horizontal and near-vertical line segments.

    Segments come from a probabilistic Hough transform on Canny edges. A
    segment is horizontal when it deviates < 18 px vertically and spans more
    than 35% of the width; vertical is defined symmetrically.

    Returns:
        ``(x_axis, y_axis)`` as ``(x1, y1, x2, y2)`` tuples, or
        ``(None, None)`` when either orientation has no candidate.
    """
    grayscale = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(grayscale, 50, 150)
    height, width = grayscale.shape
    segments = cv2.HoughLinesP(
        edge_map,
        1,
        np.pi / 180,
        threshold=90,
        minLineLength=int(0.28 * min(height, width)),
        maxLineGap=14,
    )
    if segments is None:
        return None, None

    horizontal, vertical = [], []
    for xa, ya, xb, yb in segments[:, 0]:
        run, rise = abs(xb - xa), abs(yb - ya)
        entry = (math.hypot(run, rise), (xa, ya, xb, yb))
        if rise < 18 and run > 0.35 * width:
            horizontal.append(entry)
        if run < 18 and rise > 0.35 * height:
            vertical.append(entry)

    if not (horizontal and vertical):
        return None, None

    # Longest segment of each orientation; the first one wins on ties.
    def by_length(entry):
        return entry[0]

    return max(horizontal, key=by_length)[1], max(vertical, key=by_length)[1]
def axis_intersection_ok(x_axis, y_axis, W, H):
    """Check that the detected axes sit at a plausible place in the image.

    Args:
        x_axis: Horizontal segment ``(x1, y1, x2, y2)``; its mean y is used.
        y_axis: Vertical segment ``(x1, y1, x2, y2)``; its mean x is used.
        W: Image width in pixels.
        H: Image height in pixels.

    Returns:
        ``False`` when either position falls outside the image, the y-axis
        lies in the rightmost 5% or the x-axis in the topmost 5%;
        ``True`` otherwise.
    """
    axis_row = int(round((x_axis[1] + x_axis[3]) / 2))
    axis_col = int(round((y_axis[0] + y_axis[2]) / 2))

    inside_image = 0 <= axis_row < H and 0 <= axis_col < W
    if not inside_image:
        return False

    too_far_right = axis_col > int(0.95 * W)
    too_high = axis_row < int(0.05 * H)
    return not (too_far_right or too_high)
def tick_text_presence_score(pil_img, x_axis, y_axis):
    """Count small connected components near the two axes.

    Two regions of the Otsu-thresholded (inverted) image are inspected: a
    horizontal band from 40 px above to 110 px below the x-axis, and a
    vertical band from 180 px left to 50 px right of the y-axis. Components
    4-150 px wide and tall with an area of 20-5000 px are counted.

    Returns:
        The total number of such components in both bands.
    """
    grayscale = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    _, foreground = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    foreground = cv2.medianBlur(foreground, 3)
    height, width = grayscale.shape

    axis_row = int(round((x_axis[1] + x_axis[3]) / 2))
    axis_col = int(round((y_axis[0] + y_axis[2]) / 2))

    band_top = max(0, axis_row - 40)
    band_bottom = min(height, axis_row + 110)
    x_band = foreground[band_top:band_bottom, 0:width]

    band_left = max(0, axis_col - 180)
    band_right = min(width, axis_col + 50)
    y_band = foreground[0:height, band_left:band_right]

    def count_small_components(mask):
        # Row 0 of `stats` is skipped, as in the label loop starting at 1.
        _, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
        return sum(
            1
            for _x, _y, w, h, area in stats[1:]
            if 4 <= w <= 150 and 4 <= h <= 150 and 20 <= area <= 5000
        )

    return count_small_components(x_band) + count_small_components(y_band)
def is_real_plot(pil_img):
    """Decide whether a cropped image looks like a chart with axes.

    The crop is rejected when it has a colour-bar-like strip, is mostly a
    legend, lacks a horizontal/vertical axis pair, has axes in an
    implausible position, or has a texture score above 2200. Otherwise it is
    accepted when at least 18 small components are found near the axes.
    """
    if has_colorbar_like_strip(pil_img) or is_mostly_legend(pil_img):
        return False

    x_axis, y_axis = detect_axes_lines(pil_img)
    if x_axis is None or y_axis is None:
        return False

    height, width = np.array(pil_img).shape[:2]
    if not axis_intersection_ok(x_axis, y_axis, width, height):
        return False

    if texture_score(pil_img) > 2200:
        return False

    return tick_text_presence_score(pil_img, x_axis, y_axis) >= 18
def connected_components_boxes(pil_img):
    """List bounding boxes of the non-white regions of an image.

    Pixels darker than 245 are treated as foreground, closed morphologically
    with a 7x7 kernel (two iterations) and labelled with 8-connectivity.

    Returns:
        ``[(area, (x0, y0, x1, y1)), ...]`` sorted by area, largest first;
        the first (index 0) component reported by OpenCV is excluded.
    """
    bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    grayscale = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    foreground = np.where(grayscale < 245, 255, 0).astype(np.uint8)
    kernel = np.ones((7, 7), np.uint8)
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, kernel, iterations=2)

    _, _, stats, _ = cv2.connectedComponentsWithStats(foreground, connectivity=8)
    boxes = [
        (int(area), (int(x), int(y), int(x + w), int(y + h)))
        for x, y, w, h, area in stats[1:]
    ]
    boxes.sort(key=lambda entry: entry[0], reverse=True)
    return boxes
def expand_box(box, W, H, left=0.10, right=0.06, top=0.06, bottom=0.18):
    """Grow a box by a fraction of its own size on each side, clamped to W x H.

    Args:
        box: ``(x0, y0, x1, y1)``.
        W: Maximum x coordinate (image width).
        H: Maximum y coordinate (image height).
        left, right, top, bottom: Margins as fractions of the box width
            (left/right) or height (top/bottom).

    Returns:
        The expanded ``(x0, y0, x1, y1)`` with integer coordinates.
    """
    x0, y0, x1, y1 = box
    width, height = x1 - x0, y1 - y0
    return (
        max(0, int(x0 - left * width)),
        max(0, int(y0 - top * height)),
        min(W, int(x1 + right * width)),
        min(H, int(y1 + bottom * height)),
    )
def crop_plot_from_caption(page_img, cap_bbox_pdf, mat):
    """Locate the plot that belongs to a caption and return its crop.

    The search window spans the full page width, from up to 95% of the page
    height above the caption to 20% below it. Within it, up to 35 connected
    components (at least 220x180 px) are expanded, cropped and tested with
    ``is_real_plot``.

    Returns:
        The accepted crop with the largest component area, or ``None`` when
        the window is empty or no candidate passes.
    """
    _, cap_top, _, cap_bottom = pdf_to_px_bbox(cap_bbox_pdf, mat)
    page_w, page_h = page_img.size
    window = (
        0,
        max(0, cap_top - int(0.95 * page_h)),
        page_w,
        min(page_h, cap_bottom + int(0.20 * page_h)),
    )
    region = safe_crop_px(page_img, window)
    if region is None:
        return None

    region_w, region_h = region.size
    winner, winner_area = None, -1
    for area, box in connected_components_boxes(region)[:35]:
        box_w = box[2] - box[0]
        box_h = box[3] - box[1]
        if box_w < 220 or box_h < 180:
            continue
        candidate = safe_crop_px(region, expand_box(box, region_w, region_h))
        if candidate is None or not is_real_plot(candidate):
            continue
        if area > winner_area:
            winner, winner_area = candidate, area
    return winner
def extract_images(pdf_path, paper_id="uploaded_paper"):
    """Extract plot images from a PDF and save them as PNG files.

    For every page that contains figure captions, the page is rendered once
    and each caption is used to locate and crop its plot. Crops whose
    difference hash was already seen are skipped.

    Args:
        pdf_path: Path of the PDF to open.
        paper_id: Sub-directory name under ``OUT_DIR`` for this document.

    Returns:
        A list of ``{"page": 1-based page number, "caption": text,
        "image": path of the saved PNG}`` dicts, in extraction order.
    """
    out_imgs = os.path.join(OUT_DIR, paper_id, "plots_with_axes")
    os.makedirs(out_imgs, exist_ok=True)

    results = []
    seen = set()
    # The context manager closes the document even if extraction raises, so
    # the file handle is released before the caller removes the source file.
    with fitz.open(pdf_path) as doc:
        for page_no, page in enumerate(doc, start=1):
            caps = find_caption_blocks(page)
            if not caps:
                continue
            # Rendering is expensive; only do it for pages that have captions.
            page_img, mat = render_page(page, dpi=DPI)
            for cap in caps:
                cap_text = cap["text"]
                if KEEP_ONLY_STRESS_STRAIN and not SS_KW.search(cap_text):
                    continue
                fig = crop_plot_from_caption(page_img, cap["bbox"], mat)
                if fig is None:
                    continue
                # Shave a 2 px border off crops that are large enough.
                if fig.size[0] > 8 and fig.size[1] > 8:
                    fig = fig.crop((2, 2, fig.size[0] - 2, fig.size[1] - 2))
                try:
                    fingerprint = dhash64(fig)
                except Exception:
                    # Best effort: a crop that cannot be hashed is skipped.
                    continue
                if fingerprint in seen:
                    continue
                seen.add(fingerprint)
                # The running index equals the number of images saved so far.
                img_name = f"p{page_no:02d}_{len(results):04d}.png"
                img_path = os.path.join(out_imgs, img_name)
                fig.save(img_path)
                results.append({
                    "page": page_no,
                    "caption": cap_text,
                    "image": img_path,
                })
    return results
def input_form():
    """Render the manual property-entry form and store submitted rows.

    Three cascading select boxes (material class, property category,
    property) gate a form for the actual values. A valid submission
    (material name and value are mandatory) is shown as a one-row table and
    appended to ``st.session_state["user_uploaded_data"]``.
    """
    # Property categories offered for each material class.
    categories_by_class = {
        "Polymer": [
            "Thermal",
            "Mechanical",
            "Processing",
            "Physical",
            "Descriptive",
        ],
        "Fiber": [
            "Mechanical",
            "Physical",
            "Thermal",
            "Descriptive",
        ],
        "Composite": [
            "Mechanical",
            "Thermal",
            "Processing",
            "Physical",
            "Descriptive",
            "Composition / Reinforcement",
            "Architecture / Structure",
        ],
    }
    # Property names offered for each (material class, category) pair.
    properties_by_class = {
        "Polymer": {
            "Thermal": [
                "Glass transition temperature (Tg)",
                "Melting temperature (Tm)",
                "Crystallization temperature (Tc)",
                "Degree of crystallinity",
                "Decomposition temperature",
            ],
            "Mechanical": [
                "Tensile modulus",
                "Tensile strength",
                "Elongation at break",
                "Flexural modulus",
                "Impact strength",
            ],
            "Processing": [
                "Melt flow index (MFI)",
                "Processing temperature",
                "Cooling rate",
                "Mold shrinkage",
            ],
            "Physical": [
                "Density",
                "Specific gravity",
            ],
            "Descriptive": [
                "Material grade",
                "Manufacturer",
            ],
        },
        "Fiber": {
            "Mechanical": [
                "Tensile modulus",
                "Tensile strength",
                "Strain to failure",
            ],
            "Physical": [
                "Density",
                "Fiber diameter",
            ],
            "Thermal": [
                "Decomposition temperature",
            ],
            "Descriptive": [
                "Fiber type",
                "Surface treatment",
            ],
        },
        "Composite": {
            "Mechanical": [
                "Longitudinal modulus (E1)",
                "Transverse modulus (E2)",
                "Shear modulus (G12)",
                "Poissons ratio (V12)",
                "Tensile strength (fiber direction)",
                "Interlaminar shear strength",
            ],
            "Thermal": [
                "Glass transition temperature (matrix)",
                "Coefficient of thermal expansion (CTE)",
            ],
            "Processing": [
                "Curing temperature",
                "Curing pressure",
            ],
            "Physical": [
                "Density",
            ],
            "Descriptive": [
                "Laminate type",
            ],
            "Composition / Reinforcement": [
                "Fiber volume fraction",
                "Fiber weight fraction",
                "Fiber type",
                "Matrix type",
            ],
            "Architecture / Structure": [
                "Weave type",
                "Ply orientation",
                "Number of plies",
                "Stacking sequence",
            ],
        },
    }

    st.title("Materials Property Input Form")

    # Each selector only appears once the previous one has a value.
    material_class = st.selectbox(
        "Select Material Class",
        ("Polymer", "Fiber", "Composite"),
        index=None,
        placeholder="Choose material class",
    )
    if not material_class:
        return

    property_category = st.selectbox(
        "Select Property Category",
        categories_by_class[material_class],
        index=None,
        placeholder="Choose property category",
    )
    if not property_category:
        return

    property_name = st.selectbox(
        "Select Property",
        properties_by_class[material_class][property_category],
        index=None,
        placeholder="Choose property",
    )
    if not property_name:
        return

    with st.form("user_input"):
        st.subheader("Enter Data")
        material_name = st.text_input("Material Name")
        material_abbr = st.text_input("Material Abbreviation")
        value = st.text_input("Value")
        unit = st.text_input("Unit (SI)")
        english = st.text_input("English Units")
        test_condition = st.text_input("Test Condition")
        comments = st.text_area("Comments")
        submitted = st.form_submit_button("Submit")

        # NOTE(review): submission feedback is rendered inside the form
        # container - confirm this matches the intended layout.
        if not submitted:
            return
        if not (material_name and value):
            st.error("Material name and value are required.")
            return

        # NOTE(review): convert_to_dataframe() names the same field "english";
        # rows from both sources end up in the same session-state table, so
        # the column names probably should agree - confirm with consumers.
        new_row = pd.DataFrame([{
            "material_class": material_class,
            "material_name": material_name,
            "material_abbreviation": material_abbr,
            "section": property_category,
            "property_name": property_name,
            "value": value,
            "unit": unit,
            "english_units": english,
            "test_condition": test_condition,
            "comments": comments,
        }])
        st.success("Property added successfully")
        st.dataframe(new_row)

        # Start the table on first use, append afterwards.
        if "user_uploaded_data" in st.session_state:
            st.session_state["user_uploaded_data"] = pd.concat(
                [st.session_state["user_uploaded_data"], new_row],
                ignore_index=True,
            )
        else:
            st.session_state["user_uploaded_data"] = new_row
def _render_material_tab(pdf_bytes, source_name, paper_id):
    """Run the Gemini extraction and render its results in the current container."""
    st.subheader("Material Properties Data")
    with st.spinner(" Extracting material data..."):
        data = call_gemini_from_bytes(pdf_bytes, source_name)
    if not data:
        st.error("Failed to extract data from PDF")
        return

    df = convert_to_dataframe(data)
    if df.empty:
        st.warning("No data extracted")
        return

    st.success(f"Extracted {len(df)} properties")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Material", data.get("material_name") or "N/A")
    with col2:
        # Use the table's value so a generated abbreviation is shown
        # consistently when the model returned none.
        st.metric("Abbreviation", df["material_abbreviation"].iloc[0])
    st.dataframe(df, use_container_width=True, height=400)

    st.subheader("Assign Material Category")
    extracted_material_class = st.selectbox(
        "Select category for this material",
        ["Polymer", "Fiber", "Composite"],
        index=None,
        placeholder="Required before adding to database",
    )
    if st.button(" Add to Database"):
        if not extracted_material_class:
            st.error("Please select a material category before adding.")
        else:
            # assign() returns a copy, so the displayed/exported table is
            # not altered by clicking the button.
            tagged = df.assign(material_class=extracted_material_class)
            if "user_uploaded_data" not in st.session_state:
                st.session_state["user_uploaded_data"] = tagged
            else:
                st.session_state["user_uploaded_data"] = pd.concat(
                    [st.session_state["user_uploaded_data"], tagged],
                    ignore_index=True,
                )
            st.success(f"Added to {extracted_material_class} database!")

    st.download_button(
        "Download CSV",
        data=df.to_csv(index=False),
        file_name=f"{paper_id}_data.csv",
        mime="text/csv",
    )


def _render_plots_tab(pdf_path, paper_id):
    """Extract plots from the PDF on disk and render them in the current container."""
    st.subheader("Extracted Plot Images")
    with st.spinner(" Extracting plots from PDF..."):
        image_results = extract_images(pdf_path, paper_id=paper_id)
    if not image_results:
        st.warning("No plots found in PDF")
        return

    st.success(f" Extracted {len(image_results)} plots")
    for r in image_results:
        st.markdown(f"**Page {r['page']}** — {r['caption']}")
        st.image(r["image"], use_container_width=True)
        st.divider()


def main():
    """Build the page: manual input form, PDF upload, data tab and plots tab.

    NOTE: Streamlit re-runs this function on every widget interaction, so
    the Gemini request and the plot extraction are repeated each time.
    """
    # Page configuration has to be issued before any other Streamlit command,
    # therefore it precedes input_form(), which renders widgets.
    st.set_page_config(page_title="PDF Data & Image Extractor", layout="wide")
    input_form()
    st.title("PDF Material Data & Plot Extractor")

    uploaded_file = st.file_uploader(
        "Upload PDF (Material Datasheet or Research Paper)", type=["pdf"]
    )
    if not uploaded_file:
        st.info("Upload a PDF to extract material data and plots")
        return

    # Only the final path component is used for file-system paths, so an
    # upload name cannot point outside the temp/output directories.
    safe_name = os.path.basename(uploaded_file.name) or "uploaded.pdf"
    paper_id = os.path.splitext(safe_name)[0].replace(" ", "_")
    pdf_bytes = uploaded_file.getvalue()

    tab1, tab2 = st.tabs([" Material Data", " Extracted Plots"])
    with tempfile.TemporaryDirectory() as tmpdir:
        # The plot extractor needs a real file; the API call uses the bytes.
        pdf_path = os.path.join(tmpdir, safe_name)
        with open(pdf_path, "wb") as f:
            f.write(pdf_bytes)

        with tab1:
            _render_material_tab(pdf_bytes, uploaded_file.name, paper_id)
        with tab2:
            _render_plots_tab(pdf_path, paper_id)
# Entry point: build the page only when this module is run directly,
# not when it is imported.
if __name__ == "__main__":
    main()