# Source: src/pages/categorized/page6.py (Hugging Face Space, commit 111bd64, by gvlktejaswi)
import os
import re
import json
import tempfile
import zipfile
from io import BytesIO
import base64
from typing import Dict, Any, Optional
from collections import defaultdict
import fitz # PyMuPDF
import cv2
import numpy as np
import streamlit as st
import pandas as pd
import requests
# ============================================================
# GEMINI CONFIG (HF-safe)
# - DO NOT hardcode keys in repo
# - Put GEMINI_API_KEY in Hugging Face Space → Settings → Secrets
# - This code prints full 403/4xx body to debug permission/quota
# ============================================================
# Gemini model is configurable via env var; "gemini-2.5-flash" is a stable default.
MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")  # stable default
# REST endpoint for single-shot structured generation with the chosen model.
API_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent"
# Schema you provided (kept same).
# Structured-output schema passed as generationConfig.responseSchema: one
# material plus a flat list of property rows. NOTE: despite the key name,
# "mechanical_properties" carries properties from ALL categories (see prompt).
SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "material_name": {"type": "STRING"},
        "material_abbreviation": {"type": "STRING"},
        "mechanical_properties": {
            "type": "ARRAY",
            "items": {
                "type": "OBJECT",
                "properties": {
                    "section": {"type": "STRING"},
                    "property_name": {"type": "STRING"},
                    "value": {"type": "STRING"},
                    "unit": {"type": "STRING"},
                    "english": {"type": "STRING"},
                    "test_condition": {"type": "STRING"},
                    "comments": {"type": "STRING"},
                },
                # "unit" and "test_condition" are intentionally optional.
                "required": ["section", "property_name", "value", "english", "comments"],
            },
        },
    },
}
# ============================================================
# Helpers
# ============================================================
def make_abbreviation(name: str) -> str:
    """Build an uppercase abbreviation from the first letter of each word.

    Words whose first character is not alphabetic are skipped. If no
    initials survive, fall back to the first six characters of the name,
    uppercased. An empty or None name yields "UNKNOWN".
    """
    if not name:
        return "UNKNOWN"
    initials = [word[0] for word in name.split() if word and word[0].isalpha()]
    if initials:
        return "".join(initials).upper()
    return name[:6].upper()
def call_gemini_from_bytes(pdf_bytes: bytes, filename: str) -> Optional[Dict[str, Any]]:
    """
    Send a PDF to the Gemini generateContent endpoint as inline base64 data.

    Surfaces the full HTTP error body (especially 403) in the Streamlit UI so
    the exact PERMISSION_DENIED / billing / quota reason shows up in HF logs.
    Returns the parsed JSON dict on success, None on any failure.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        st.error(" GEMINI_API_KEY not found. Add it in Hugging Face Space → Settings → Secrets, then restart Space.")
        st.stop()

    # Inline the PDF as base64 (Gemini "inlineData" part).
    try:
        encoded_file = base64.b64encode(pdf_bytes).decode("utf-8")
        mime_type = "application/pdf"
    except Exception as e:
        st.error(f"Error encoding PDF: {e}")
        return None

    prompt = (
        "You are an expert materials scientist. From the attached PDF, extract the material name, "
        "abbreviation, and ALL properties across categories (Mechanical, Thermal, Electrical, Physical, "
        "Optical, Rheological, etc.). Return them as 'mechanical_properties' (a single list). "
        "For each property, you MUST extract:\n"
        "- section (category)\n- property_name\n- value (or range)\n- unit\n"
        "- english (converted or alternate units, e.g., psi, °F, inches; write '' if not provided)\n"
        "- test_condition\n- comments (include any notes, footnotes, standards, remarks; write '' if none)\n"
        "All fields including english and comments are REQUIRED. Respond ONLY with valid JSON following the schema."
    )

    request_body = {
        "contents": [{
            "parts": [
                {"text": prompt},
                {"inlineData": {"mimeType": mime_type, "data": encoded_file}},
            ]
        }],
        "generationConfig": {
            # Deterministic output, JSON-only response constrained by SCHEMA.
            "temperature": 0,
            "responseMimeType": "application/json",
            "responseSchema": SCHEMA,
        },
    }

    try:
        resp = requests.post(
            API_URL,
            params={"key": api_key},
            json=request_body,
            timeout=300,
        )
        if not resp.ok:
            # Full body contains the exact PERMISSION_DENIED / BILLING / QUOTA reason.
            st.error(f"Gemini HTTP {resp.status_code}")
            st.code(resp.text)
            return None

        data = resp.json()
        candidates = data.get("candidates", [])
        if not candidates:
            st.warning("Gemini returned no candidates.")
            return None

        parts = candidates[0].get("content", {}).get("parts", [])
        # First part whose text looks like a JSON object wins.
        json_text = next(
            (part.get("text", "") for part in parts
             if part.get("text", "").strip().startswith("{")),
            None,
        )
        if not json_text:
            st.warning("Gemini response didn't contain JSON text.")
            st.code(json.dumps(data, indent=2)[:5000])
            return None
        return json.loads(json_text)
    except Exception as e:
        st.error(f"Gemini API exception: {e}")
        return None
def convert_to_dataframe(data: Dict[str, Any]) -> pd.DataFrame:
    """Flatten a Gemini extraction dict into one DataFrame row per property.

    Falls back to make_abbreviation() when the model omitted the
    abbreviation, and substitutes display defaults for blank fields
    ("Mechanical" section, "Unknown property", "N/A" value).
    """
    material = data.get("material_name", "") or ""
    abbrev = data.get("material_abbreviation", "") or ""
    if not abbrev:
        abbrev = make_abbreviation(material)

    records = [
        {
            "material_name": material,
            "material_abbreviation": abbrev,
            "section": prop.get("section", "") or "Mechanical",
            "property_name": prop.get("property_name", "") or "Unknown property",
            "value": prop.get("value", "") or "N/A",
            "unit": prop.get("unit", "") or "",
            "english": prop.get("english", "") or "",
            "test_condition": prop.get("test_condition", "") or "",
            "comments": prop.get("comments", "") or "",
        }
        for prop in data.get("mechanical_properties", [])
    ]
    return pd.DataFrame(records)
# ============================================================
# Plot Extraction (offline)
# ============================================================
# Render resolution (dots per inch) used when rasterizing PDF pages.
DPI = 300
# Matches figure captions such as "Fig. 3", "Fig 3" or "Figure 12" at line start.
CAP_RE = re.compile(r"^(Fig\.?\s*\d+|Figure\s*\d+)\b", re.IGNORECASE)
def get_page_image(page):
    """Rasterize a PyMuPDF page at DPI and return it as a BGR (OpenCV) array."""
    zoom = DPI / 72  # PDF user space is 72 dpi, so scale up to the target DPI
    pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
    rgb = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.h, pix.w, 3)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
def is_valid_plot_geometry(binary_crop):
    """Heuristic check that a binarized crop looks like a plot/figure.

    Rejects crops smaller than 100 px on a side and crops that are mostly
    ink (>35% — likely a photo or dense text), then requires at least one
    long horizontal or vertical line (an axis) to survive a morphological
    erosion with a thin kernel spanning a quarter of the crop.
    """
    height, width = binary_crop.shape
    if height < 100 or width < 100:
        return False
    # Fraction of "on" pixels; dense regions are usually photos or text blocks.
    if cv2.countNonZero(binary_crop) / float(width * height) > 0.35:
        return False
    # Only lines at least width//4 (resp. height//4) long survive these erosions.
    horiz_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(10, width // 4), 1))
    vert_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(10, height // 4)))
    has_horiz_line = cv2.countNonZero(cv2.erode(binary_crop, horiz_kernel, iterations=1)) > 0
    has_vert_line = cv2.countNonZero(cv2.erode(binary_crop, vert_kernel, iterations=1)) > 0
    return has_horiz_line or has_vert_line
def merge_boxes(rects):
    """De-duplicate candidate boxes, keeping a rect only when it is not
    already covered (within a 15 px slack) by a larger rect kept earlier.

    Rects are (x, y, w, h) tuples; the result is ordered largest-area first.
    """
    if not rects:
        return []
    slack = 15
    by_area = sorted(rects, key=lambda r: r[2] * r[3], reverse=True)
    kept = []
    for rect in by_area:
        rx, ry, rw, rh = rect
        for kx, ky, kw, kh in kept:
            # rect lies inside an already-kept box (allowing the slack margin)?
            if (rx >= kx - slack and ry >= ky - slack
                    and rx + rw <= kx + kw + slack
                    and ry + rh <= ky + kh + slack):
                break
        else:
            kept.append(rect)
    return kept
def extract_images(pdf_doc):
    """Detect plot-like regions on every page and crop them, grouped by caption.

    Pipeline per page: rasterize → binarize (light background inverted) →
    dilate to fuse nearby ink → external contours → filter by area and
    is_valid_plot_geometry() → merge_boxes() → pair each crop with the
    nearest "Fig./Figure N" text block below it.
    Returns a list of {"caption", "page", "image_data": [{filename, bytes, array}]}.
    """
    # Group crops by caption so multi-panel figures land under one entry.
    grouped_data = defaultdict(lambda: {"page": 0, "image_data": []})
    PADDING = 30  # pixels of context kept around each detected box
    for page_num, page in enumerate(pdf_doc, start=1):
        img_bgr = get_page_image(page)
        gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
        # Anything darker than 225 counts as ink (inverted: ink -> white).
        _, binary = cv2.threshold(gray, 225, 255, cv2.THRESH_BINARY_INV)
        kernel = np.ones((10, 10), np.uint8)
        # Dilate so separate plot elements merge into one contour blob.
        dilated = cv2.dilate(binary, kernel, iterations=1)
        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        candidates = []
        page_h, page_w = gray.shape
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            # Keep boxes covering 3%-80% of the page that pass the axis heuristic.
            if 0.03 < (w * h) / float(page_w * page_h) < 0.8:
                if is_valid_plot_geometry(binary[y:y + h, x:x + w]):
                    candidates.append((x, y, w, h))
        final_rects = merge_boxes(candidates)
        blocks = page.get_text("blocks")
        for (cx, cy, cw, ch) in final_rects:
            best_caption = f"Figure on Page {page_num} (Unlabeled)"
            min_dist = float("inf")
            # Find the closest caption-looking text block BELOW the box,
            # no further than 30% of the page height away.
            for b in blocks:
                if len(b) < 5:
                    continue
                text = (b[4] or "").strip()
                if CAP_RE.match(text):
                    # Text blocks are in PDF points; scale to pixel space.
                    cap_y = b[1] * (DPI / 72)
                    dist = cap_y - (cy + ch)
                    if 0 < dist < (page_h * 0.3) and dist < min_dist:
                        best_caption = text.replace("\n", " ")
                        min_dist = dist
            # Pad the crop, clamped to the page bounds.
            x1, y1 = max(0, cx - PADDING), max(0, cy - PADDING)
            x2, y2 = min(page_w, cx + cw + PADDING), min(page_h, cy + ch + PADDING)
            crop = img_bgr[int(y1):int(y2), int(x1):int(x2)]
            ok, buffer = cv2.imencode(".png", crop)
            if not ok:
                continue
            img_bytes = buffer.tobytes()
            fname = f"pg{page_num}_{cx}_{cy}.png"
            grouped_data[best_caption]["page"] = page_num
            grouped_data[best_caption]["image_data"].append({
                "filename": fname,
                "bytes": img_bytes,
                "array": crop
            })
    return [{"caption": k, "page": v["page"], "image_data": v["image_data"]} for k, v in grouped_data.items()]
def create_zip(results, include_json=True):
    """Bundle extracted plot images (and optionally a metadata JSON) into a zip.

    results is the list produced by extract_images(). Returns the archive
    as raw bytes, ready to hand to st.download_button.
    """
    archive = BytesIO()
    with zipfile.ZipFile(archive, "w") as zf:
        if include_json:
            metadata = [
                {"caption": group["caption"], "page": group["page"],
                 "image_count": len(group["image_data"])}
                for group in results
            ]
            zf.writestr("plot_data.json", json.dumps(metadata, indent=4))
        for group in results:
            for img in group["image_data"]:
                zf.writestr(img["filename"], img["bytes"])
    archive.seek(0)
    return archive.getvalue()
# ============================================================
# UI: Your input form (kept same)
# ============================================================
def input_form():
    """Render a manual property-entry form (class → category → property).

    Selectboxes cascade: each appears only after the previous choice is made.
    Each submitted row is appended to st.session_state["user_uploaded_data"].
    """
    # Valid property categories per material class.
    PROPERTY_CATEGORIES = {
        "Polymer": ["Thermal", "Mechanical", "Processing", "Physical", "Descriptive"],
        "Fiber": ["Mechanical", "Physical", "Thermal", "Descriptive"],
        "Composite": [
            "Mechanical", "Thermal", "Processing", "Physical", "Descriptive",
            "Composition / Reinforcement", "Architecture / Structure"
        ],
    }
    # Known property names per (material class, category) pair.
    PROPERTY_NAMES = {
        "Polymer": {
            "Thermal": [
                "Glass transition temperature (Tg)",
                "Melting temperature (Tm)",
                "Crystallization temperature (Tc)",
                "Degree of crystallinity",
                "Decomposition temperature",
            ],
            "Mechanical": [
                "Tensile modulus",
                "Tensile strength",
                "Elongation at break",
                "Flexural modulus",
                "Impact strength",
            ],
            "Processing": [
                "Melt flow index (MFI)",
                "Processing temperature",
                "Cooling rate",
                "Mold shrinkage",
            ],
            "Physical": ["Density", "Specific gravity"],
            "Descriptive": ["Material grade", "Manufacturer"],
        },
        "Fiber": {
            "Mechanical": ["Tensile modulus", "Tensile strength", "Strain to failure"],
            "Physical": ["Density", "Fiber diameter"],
            "Thermal": ["Decomposition temperature"],
            "Descriptive": ["Fiber type", "Surface treatment"],
        },
        "Composite": {
            "Mechanical": [
                "Longitudinal modulus (E1)",
                "Transverse modulus (E2)",
                "Shear modulus (G12)",
                "Poissons ratio (V12)",
                "Tensile strength (fiber direction)",
                "Interlaminar shear strength",
            ],
            "Thermal": [
                "Glass transition temperature (matrix)",
                "Coefficient of thermal expansion (CTE)",
            ],
            "Processing": ["Curing temperature", "Curing pressure"],
            "Physical": ["Density"],
            "Descriptive": ["Laminate type"],
            "Composition / Reinforcement": [
                "Fiber volume fraction",
                "Fiber weight fraction",
                "Fiber type",
                "Matrix type",
            ],
            "Architecture / Structure": [
                "Weave type",
                "Ply orientation",
                "Number of plies",
                "Stacking sequence",
            ],
        },
    }
    st.title("Materials Property Input Form")
    material_class = st.selectbox(
        "Select Material Class",
        ("Polymer", "Fiber", "Composite"),
        index=None,
        placeholder="Choose material class",
    )
    property_category = None
    if material_class:
        property_category = st.selectbox(
            "Select Property Category",
            PROPERTY_CATEGORIES[material_class],
            index=None,
            placeholder="Choose property category",
        )
    property_name = None
    if material_class and property_category:
        property_name = st.selectbox(
            "Select Property",
            PROPERTY_NAMES[material_class][property_category],
            index=None,
            placeholder="Choose property",
        )
    # The data-entry form appears only once all three selections are made.
    if material_class and property_category and property_name:
        with st.form("user_input"):
            st.subheader("Enter Data")
            material_name = st.text_input("Material Name")
            material_abbr = st.text_input("Material Abbreviation")
            value = st.text_input("Value")
            unit = st.text_input("Unit (SI)")
            english = st.text_input("English Units")
            test_condition = st.text_input("Test Condition")
            comments = st.text_area("Comments")
            submitted = st.form_submit_button("Submit")
            if submitted:
                if not (material_name and value):
                    st.error("Material name and value are required.")
                else:
                    df_row = pd.DataFrame([{
                        "material_class": material_class,
                        "material_name": material_name,
                        "material_abbreviation": material_abbr,
                        "section": property_category,
                        "property_name": property_name,
                        "value": value,
                        "unit": unit,
                        "english_units": english,
                        "test_condition": test_condition,
                        "comments": comments
                    }])
                    st.success("Property added successfully")
                    st.dataframe(df_row)
                    # Accumulate rows across submissions in session state.
                    if "user_uploaded_data" not in st.session_state:
                        st.session_state["user_uploaded_data"] = df_row
                    else:
                        st.session_state["user_uploaded_data"] = pd.concat(
                            [st.session_state["user_uploaded_data"], df_row],
                            ignore_index=True
                        )
    return
# ============================================================
# Main App
# ============================================================
def main():
    """Streamlit entry point: manual input form plus per-PDF extraction tabs
    (Tab 1: Gemini property extraction, Tab 2: offline plot detection)."""
    st.set_page_config(page_title="PDF Data & Image Extractor", layout="wide")
    # Sidebar image width (fixes your img_width undefined bug)
    img_width = st.sidebar.slider("Image Width", 120, 900, 320, key="img_width")
    # State init
    if "image_results" not in st.session_state:
        st.session_state.image_results = []
    if "pdf_processed" not in st.session_state:
        st.session_state.pdf_processed = False
    if "current_pdf_name" not in st.session_state:
        st.session_state.current_pdf_name = None
    if "form_submitted" not in st.session_state:
        st.session_state.form_submitted = False
    if "pdf_data_extracted" not in st.session_state:
        st.session_state.pdf_data_extracted = False
    if "pdf_extracted_df" not in st.session_state:
        st.session_state.pdf_extracted_df = pd.DataFrame()
    if "pdf_extracted_meta" not in st.session_state:
        st.session_state.pdf_extracted_meta = {}
    # Track form submission changes: compare the uploaded-row count before
    # and after rendering the form.
    prev_uploaded_count = len(st.session_state.get("user_uploaded_data", pd.DataFrame()))
    input_form()
    curr_uploaded_count = len(st.session_state.get("user_uploaded_data", pd.DataFrame()))
    if curr_uploaded_count > prev_uploaded_count:
        st.session_state.form_submitted = True
    st.title("PDF Material Data & Plot Extractor")
    uploaded_file = st.file_uploader("Upload PDF (Material Datasheet or Research Paper)", type=["pdf"])
    if not uploaded_file:
        # No file: clear all per-PDF state and stop rendering here.
        st.info("Upload a PDF to extract material data and plots")
        st.session_state.pdf_processed = False
        st.session_state.current_pdf_name = None
        st.session_state.image_results = []
        st.session_state.form_submitted = False
        st.session_state.pdf_data_extracted = False
        st.session_state.pdf_extracted_df = pd.DataFrame()
        st.session_state.pdf_extracted_meta = {}
        return
    # Filename stem (spaces -> underscores), used to name all downloads.
    paper_id = os.path.splitext(uploaded_file.name)[0].replace(" ", "_")
    # Reset per-PDF
    if st.session_state.current_pdf_name != uploaded_file.name:
        st.session_state.pdf_processed = False
        st.session_state.current_pdf_name = uploaded_file.name
        st.session_state.image_results = []
        st.session_state.form_submitted = False
        st.session_state.pdf_data_extracted = False
        st.session_state.pdf_extracted_df = pd.DataFrame()
        st.session_state.pdf_extracted_meta = {}
    # Tabs
    tab1, tab2 = st.tabs([" Material Data", " Extracted Plots"])
    with tempfile.TemporaryDirectory() as tmpdir:
        # Persist the upload to disk so both PyMuPDF and the Gemini call can read it.
        pdf_path = os.path.join(tmpdir, uploaded_file.name)
        with open(pdf_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        # ---------------- Tab 1: Gemini extraction ----------------
        with tab1:
            st.subheader("Material Properties Data")
            # Run Gemini extraction once per PDF; results are cached in session state.
            if not st.session_state.pdf_data_extracted:
                with st.spinner("Extracting material data..."):
                    with open(pdf_path, "rb") as f:
                        pdf_bytes = f.read()
                    data = call_gemini_from_bytes(pdf_bytes, uploaded_file.name)
                    if data:
                        df = convert_to_dataframe(data)
                        if not df.empty:
                            st.session_state.pdf_extracted_df = df
                            st.session_state.pdf_extracted_meta = data
                            st.session_state.pdf_data_extracted = True
                        else:
                            st.warning("No data extracted (empty dataframe).")
                    else:
                        st.error("Failed to extract data from PDF (see error details above).")
            df = st.session_state.pdf_extracted_df
            meta = st.session_state.get("pdf_extracted_meta", {})
            if not df.empty:
                st.success(f"Extracted {len(df)} properties")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Material", meta.get("material_name", "N/A"))
                with col2:
                    st.metric("Abbreviation", meta.get("material_abbreviation", "N/A"))
                st.dataframe(df, use_container_width=True, height=400)
                st.subheader("Assign Material Category")
                extracted_material_class = st.selectbox(
                    "Select category for this material",
                    ["Polymer", "Fiber", "Composite"],
                    index=None,
                    placeholder="Required before adding to database"
                )
                if st.button("Add to Database"):
                    if not extracted_material_class:
                        st.error("Please select a material category before adding.")
                    else:
                        df2 = df.copy()
                        df2["material_class"] = extracted_material_class
                        df2["material_type"] = extracted_material_class
                        # Append to the shared session-state table used by input_form().
                        if "user_uploaded_data" not in st.session_state:
                            st.session_state["user_uploaded_data"] = df2
                        else:
                            st.session_state["user_uploaded_data"] = pd.concat(
                                [st.session_state["user_uploaded_data"], df2],
                                ignore_index=True
                            )
                        st.success(f"Added to {extracted_material_class} database!")
                csv = df.to_csv(index=False)
                st.download_button(
                    " Download CSV",
                    data=csv,
                    file_name=f"{paper_id}_data.csv",
                    mime="text/csv"
                )
        # ---------------- Tab 2: Plot extraction ----------------
        with tab2:
            st.subheader("Extracted Plot Images")
            # Run plot detection once per PDF; results are cached in session state.
            if not st.session_state.pdf_processed:
                with st.spinner("Extracting plots from PDF..."):
                    doc = fitz.open(pdf_path)
                    st.session_state.image_results = extract_images(doc)
                    doc.close()
                    st.session_state.pdf_processed = True
            if st.session_state.image_results:
                subtab1, subtab2 = st.tabs([" Images", " JSON Preview"])
                with subtab1:
                    st.success(f"Extracted {len(st.session_state.image_results)} plot groups")
                    col_img, col_json, col_all = st.columns(3)
                    with col_img:
                        img_zip = create_zip(st.session_state.image_results, include_json=False)
                        st.download_button(
                            "Download Images Only",
                            data=img_zip,
                            file_name=f"{paper_id}_images.zip",
                            mime="application/zip",
                            use_container_width=True,
                            key="download_images"
                        )
                    with col_json:
                        json_data = [{"caption": r["caption"], "page": r["page"],
                                      "image_count": len(r["image_data"])} for r in st.session_state.image_results]
                        st.download_button(
                            "Download JSON",
                            data=json.dumps(json_data, indent=4),
                            file_name=f"{paper_id}_metadata.json",
                            mime="application/json",
                            use_container_width=True,
                            key="download_json_top"
                        )
                    with col_all:
                        full_zip = create_zip(st.session_state.image_results, include_json=True)
                        st.download_button(
                            "Download All",
                            data=full_zip,
                            file_name=f"{paper_id}_complete.zip",
                            mime="application/zip",
                            use_container_width=True,
                            key="download_all"
                        )
                    st.divider()
                    # Render groups + allow delete
                    results_copy = st.session_state.image_results.copy()
                    for idx in range(len(results_copy)):
                        # Guard: deletions shrink the live list mid-iteration.
                        if idx >= len(st.session_state.image_results):
                            break
                        r = st.session_state.image_results[idx]
                        with st.container(border=True):
                            col_cap, col_btn = st.columns([0.85, 0.15])
                            col_cap.markdown(f"**Page {r['page']}** {r['caption']}")
                            if col_btn.button("Delete", key=f"del_g_{idx}_{r['page']}"):
                                del st.session_state.image_results[idx]
                                st.rerun()
                            image_data_list = r["image_data"]
                            if image_data_list:
                                cols = st.columns(len(image_data_list))
                                for p_idx in range(len(image_data_list)):
                                    # Same guard for per-image removals.
                                    if p_idx >= len(st.session_state.image_results[idx]["image_data"]):
                                        break
                                    img_data = st.session_state.image_results[idx]["image_data"][p_idx]
                                    with cols[p_idx]:
                                        st.image(img_data["array"], width=img_width, channels="BGR")
                                        if st.button("Remove", key=f"del_s_{idx}_{p_idx}_{r['page']}"):
                                            del st.session_state.image_results[idx]["image_data"][p_idx]
                                            # Drop the whole group when its last image is removed.
                                            if len(st.session_state.image_results[idx]["image_data"]) == 0:
                                                del st.session_state.image_results[idx]
                                            st.rerun()
                with subtab2:
                    st.subheader("Metadata Preview")
                    json_data = [{
                        "caption": r["caption"],
                        "page": r["page"],
                        "image_count": len(r["image_data"]),
                        "images": [img["filename"] for img in r["image_data"]],
                    } for r in st.session_state.image_results]
                    st.download_button(
                        "Download JSON",
                        data=json.dumps(json_data, indent=4),
                        file_name=f"{paper_id}_metadata.json",
                        mime="application/json",
                        key="download_json_bottom"
                    )
                    st.json(json_data)
            else:
                st.warning("No plots found in PDF")
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()