# Author: AsyncBuilds
# Commit message: Added main file
# Commit: 3d0b285 (verified)
import time
from io import BytesIO
import pandas as pd
import streamlit as st
from gliner2 import GLiNER2
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Entity labels describing an individual person.
PERSONAL_FIELDS = [
    "Person Name",
    "Email Address",
    "Phone Number",
    "Street Address",
    "City",
    "Country",
    "Date of Birth",
]

# Entity labels describing someone's employment or education.
PROFESSIONAL_FIELDS = [
    "Company Name",
    "Department",
    "Job Title",
    "Office Location",
    "Employee ID",
    "Skills",
    "University",
]

# Entity labels found in contracts, invoices and project documents.
BUSINESS_FIELDS = [
    "Counterparty",
    "Contract Value",
    "Effective Date",
    "Jurisdiction",
    "Governing Law",
    "Invoice Number",
    "Product Name",
    "Project Name",
]

# Full label library, in the order personal -> professional -> business.
ALL_PREDEFINED_FIELDS = [*PERSONAL_FIELDS, *PROFESSIONAL_FIELDS, *BUSINESS_FIELDS]

# Checkpoint identifier handed to GLiNER2.from_pretrained.
MODEL_ID = "fastino/gliner2-base-v1"

# Value passed as the `threshold` argument of model.extract_entities.
EXTRACTION_THRESHOLD = 0.4
# ---------------------------------------------------------------------------
# Page config & styles
# ---------------------------------------------------------------------------
# Browser-tab title/icon and a centered layout for the whole app.
st.set_page_config(page_title="AI Excel Entity Extractor", page_icon="🔍", layout="centered")

# Global CSS: page background, full-width blue buttons with a darker hover
# state, and the `.footer` class used by the footer markup.
_APP_STYLES = """
<style>
.stApp { background-color: #fcfcfc; }
div.stButton > button:first-child {
width: 100%;
border-radius: 8px;
height: 3.5em;
background-color: #2563eb;
color: white;
font-weight: bold;
border: none;
}
div.stButton > button:hover { background-color: #1d4ed8; border: none; }
.footer { text-align: center; color: #64748b; font-size: 0.85rem; margin-top: 50px; }
</style>
"""
st.html(_APP_STYLES)
# ---------------------------------------------------------------------------
# Cached resources & helpers
# ---------------------------------------------------------------------------
@st.cache_resource(show_spinner="Loading AI model…")
def load_model() -> GLiNER2:
    """Load the GLiNER2 checkpoint identified by ``MODEL_ID``.

    Wrapped in ``st.cache_resource`` so the loaded model is reused instead of
    being re-created on every script rerun; a spinner is shown while loading.

    Returns:
        The model instance produced by ``GLiNER2.from_pretrained``.
    """
    extractor = GLiNER2.from_pretrained(MODEL_ID)
    return extractor
@st.cache_data(show_spinner=False)
def load_excel(file) -> pd.DataFrame:
    """Parse an uploaded workbook into a DataFrame.

    Args:
        file: Anything ``pandas.read_excel`` accepts; the app passes the
            object returned by ``st.file_uploader``.

    Returns:
        The DataFrame produced by ``pandas.read_excel`` with default options.
    """
    frame = pd.read_excel(file)
    return frame
def to_excel_bytes(df: pd.DataFrame) -> bytes:
    """Serialize *df* to an in-memory ``.xlsx`` workbook.

    Args:
        df: Table to export. Its index is not written to the sheet.

    Returns:
        The raw bytes of the workbook, written with the ``openpyxl`` engine.
    """
    stream = BytesIO()
    writer = pd.ExcelWriter(stream, engine="openpyxl")
    # Leaving the context finalizes the workbook into the buffer.
    with writer:
        df.to_excel(writer, index=False)
    payload = stream.getvalue()
    return payload
def parse_custom_labels(raw: str) -> list[str]:
    """Split free-form user input into a list of clean entity labels.

    The value comes from a multi-line text area, so both commas and line
    breaks are treated as separators; otherwise a user who types one label
    per line would end up with a single label containing newlines.

    Args:
        raw: Text typed by the user. ``None`` or an empty string yields ``[]``.

    Returns:
        Labels in input order with surrounding whitespace removed and blank
        entries dropped. Duplicates are preserved.
    """
    labels = []
    for line in (raw or "").splitlines():
        for piece in line.split(","):
            name = piece.strip()
            if name:
                labels.append(name)
    return labels
def is_valid_text(value: str) -> bool:
    """Return whether *value* holds text worth sending to the model.

    A value is rejected when it is empty or whitespace-only, or when it is
    the string ``"nan"`` in any letter case (what ``str()`` yields for a
    missing pandas cell). The comparison is done on the stripped value so a
    padded placeholder such as ``" nan "`` is rejected as well.

    Args:
        value: Cell content already converted to ``str``.

    Returns:
        ``True`` if the text should be processed, ``False`` otherwise.
    """
    cleaned = value.strip()
    return bool(cleaned) and cleaned.lower() != "nan"
# ---------------------------------------------------------------------------
# UI - Header
# ---------------------------------------------------------------------------
# Page heading followed by a one-sentence description of what the app does.
st.title("🔍 AI Excel Entity Extractor")
intro_text = (
    "Automatically extract specific entities like Name, Email, etc., "
    "from your spreadsheet text using GLiNER2 Zero-Shot AI."
)
st.markdown(intro_text)
# ---------------------------------------------------------------------------
# Step 1: Upload
# ---------------------------------------------------------------------------
st.write("### 1. Source Data")
uploaded_file = st.file_uploader("Upload an Excel file (.xlsx)", type="xlsx")

if not uploaded_file:
    # Nothing uploaded yet: show a three-step guide and halt the script so
    # none of the later steps run without data.
    st.write("### How it works")
    guide_steps = (
        (
            "**1. Upload**",
            "Drop an Excel file with a column of text (e.g., emails, descriptions, or notes).",
        ),
        (
            "**2. Define**",
            "Select from common entities like Names and Dates, or type your own custom fields.",
        ),
        (
            "**3. Extract**",
            "The AI reads every row and creates new columns for every entity it discovers.",
        ),
    )
    for guide_column, (heading, blurb) in zip(st.columns(3), guide_steps):
        with guide_column:
            # Two trailing spaces before the newline form a Markdown hard
            # line break; a bare "\n" is only a soft break and would render
            # the heading and its description on the same line.
            st.markdown(f"{heading}  \n{blurb}")
    st.stop()
# ---------------------------------------------------------------------------
# Step 2: Configure
# ---------------------------------------------------------------------------
# Parse the upload. A file with an .xlsx extension can still be corrupt or
# not a workbook at all, so report the problem in the UI instead of letting
# the script crash with a traceback.
try:
    df = load_excel(uploaded_file)
except Exception as exc:  # UI boundary: any parse failure is shown to the user
    st.error(f"Could not read the uploaded file as an Excel workbook: {exc}")
    st.stop()

if df.empty:
    st.error("The uploaded file appears to be empty. Please upload a file with data.")
    st.stop()

row_count = len(df)

st.divider()
st.write("### 2. Configure Extraction")

with st.spinner("Loading configuration…"), st.container(border=True):
    # Top row: which column to analyze, plus the size of the data set.
    column_picker, row_metric = st.columns([2, 1])
    with column_picker:
        text_column = st.selectbox("Select text column to analyze:", df.columns)
    with row_metric:
        st.metric("Total Rows", f"{row_count:,}")

    st.write("---")

    # Bottom row: predefined labels on the left, free-form labels on the right.
    library_side, custom_side = st.columns(2)
    with library_side:
        selected_labels = st.multiselect(
            "Select Fields to Extract:",
            options=ALL_PREDEFINED_FIELDS,
            default=["Person Name", "Company Name"],
            help="Choose common entities from the library.",
        )
    with custom_side:
        custom_labels_str = st.text_area(
            "Custom Entities (Comma Separated):",
            placeholder="e.g. Case Number, Part ID, Deadline",
            help="Define unique entities specific to your data.",
        )

# Merge both sources; dict.fromkeys drops duplicates while keeping first-seen order.
active_labels = list(dict.fromkeys(selected_labels + parse_custom_labels(custom_labels_str)))
# ---------------------------------------------------------------------------
# Step 3: Extract
# ---------------------------------------------------------------------------
# Fingerprint of every input that shapes the output. Results kept in the
# session are only displayed while the current inputs still match it.
run_signature = (
    getattr(uploaded_file, "file_id", None),
    uploaded_file.name,
    uploaded_file.size,
    str(text_column),
    tuple(active_labels),
)

if st.button("🚀 Extract Fields"):
    if not active_labels:
        st.warning("⚠️ Please select or define at least one entity to extract.")
        st.stop()

    model = load_model()

    # Choose an output column for each label. A label that matches an existing
    # column (possibly the text column itself) gets a suffixed name so the
    # source data is never overwritten.
    taken = set(df.columns)
    output_columns = {}
    for label in active_labels:
        column = label if label not in taken else f"{label} (extracted)"
        attempt = 2
        while column in taken:
            column = f"{label} (extracted {attempt})"
            attempt += 1
        taken.add(column)
        output_columns[label] = column

    # Results are collected positionally, so the loop does not depend on the
    # DataFrame index being 0..n-1.
    extracted = {label: [""] * row_count for label in active_labels}
    failures = []

    status = st.empty()
    progress_bar = st.progress(0)
    start_time = time.perf_counter()

    for position, cell in enumerate(df[text_column], start=1):
        text = str(cell)
        if is_valid_text(text):
            try:
                results = model.extract_entities(
                    text, active_labels, threshold=EXTRACTION_THRESHOLD
                )
                for label, found_list in results.get("entities", {}).items():
                    # Ignore labels that were not requested.
                    if label in extracted:
                        extracted[label][position - 1] = ", ".join(found_list)
            except Exception as exc:
                # Best effort: one bad row must not abort the whole run.
                failures.append(f"row {position}: {exc}")
        progress_bar.progress(position / row_count)
        status.text(f"Extracting fields from row {position} of {row_count}…")

    duration = round(time.perf_counter() - start_time, 1)
    progress_bar.empty()
    status.empty()

    processed_df = df.copy()
    for label, column in output_columns.items():
        processed_df[column] = extracted[label]

    # One summary instead of one warning per failed row.
    skipped_note = None
    if failures:
        shown = "; ".join(failures[:5])
        more = "; …" if len(failures) > 5 else ""
        skipped_note = f"{len(failures):,} row(s) skipped due to errors: {shown}{more}"

    # Persist the outcome: clicking the download button reruns the script, and
    # on that rerun st.button() returns False, which would otherwise make the
    # preview and the download button disappear.
    st.session_state["extraction_result"] = {
        "signature": run_signature,
        "preview": processed_df.head(10),
        "xlsx": to_excel_bytes(processed_df),
        "duration": duration,
        "skipped_note": skipped_note,
    }

result = st.session_state.get("extraction_result")
if result is None or result["signature"] != run_signature:
    # No extraction yet, or the file/column/labels changed since the last run.
    st.stop()

if result["skipped_note"]:
    st.warning(result["skipped_note"])
st.success(
    f"✅ Extraction complete - {row_count:,} rows processed in {result['duration']}s."
)
st.write("### 3. Extraction Preview")
st.dataframe(result["preview"], use_container_width=True)
st.download_button(
    label="📥 Download Enriched Excel File",
    data=result["xlsx"],
    file_name="AI_Extracted_Report.xlsx",
    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
# ---------------------------------------------------------------------------
# Footer
# ---------------------------------------------------------------------------
# Horizontal rule, then an attribution line styled by the `.footer` CSS class.
st.markdown("---")
project_link = '<a href="https://github.com/fastino-ai/GLiNER2" target="_blank">GLiNER2</a>'
footer_html = (
    f'<div class="footer">Powered by {project_link}'
    " • Open-source Zero-Shot Named Entity Recognition</div>"
)
# Raw HTML is required here so the <div>/<a> markup is rendered, not escaped.
st.markdown(footer_html, unsafe_allow_html=True)