# Author: rasmodev
# Last commit: bde4cfe "Update app.py" (verified)
"""
ValuationAI β€” Nairobi Valuation Sheet OCR
Model: rasmodev/Handwriting_trocr_model
"""
import html
import io
import logging
import os
import tempfile
import time

import pandas as pd
import streamlit as st
from PIL import Image
# Configure the browser tab and page layout before any content is rendered.
st.set_page_config(
    page_title="ValuationAI",
    page_icon="📋",  # was mis-encoded as "πŸ“‹" (UTF-8 bytes read as cp1253)
    layout="wide",
    initial_sidebar_state="collapsed",
)
logging.basicConfig(level=logging.INFO)
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Cormorant+Garamond:ital,wght@0,600;0,700;1,600&family=Inter:wght@300;400;500;600&display=swap');
html, body, [class*="css"], .stApp {
font-family: 'Inter', sans-serif;
background: #F8F7F4;
color: #1A1A2E;
}
.block-container {
padding: 3rem 4rem !important;
max-width: 1100px !important;
}
#MainMenu, footer, header { visibility: hidden; }
.topbar {
display: flex; align-items: flex-end;
justify-content: space-between;
padding-bottom: 2rem; margin-bottom: 3rem;
border-bottom: 2px solid #1A1A2E;
}
.logo { font-family: 'Cormorant Garamond', serif; font-size: 1.8rem; font-weight: 700; color: #1A1A2E; letter-spacing: -0.02em; line-height: 1; }
.logo span { color: #2563EB; }
.logo-sub { font-size: 0.68rem; font-weight: 500; letter-spacing: 0.15em; text-transform: uppercase; color: #9CA3AF; margin-top: 0.3rem; }
.model-ref { font-size: 0.7rem; color: #9CA3AF; font-weight: 400; letter-spacing: 0.04em; text-align: right; }
.model-ref strong { color: #2563EB; font-weight: 600; }
.headline { font-family: 'Cormorant Garamond', serif; font-size: 3.4rem; font-weight: 700; line-height: 1.08; letter-spacing: -0.03em; color: #1A1A2E; margin-bottom: 1rem; max-width: 700px; }
.headline em { font-style: italic; color: #2563EB; }
.subline { font-size: 0.95rem; font-weight: 300; color: #6B7280; line-height: 1.7; max-width: 500px; margin-bottom: 3rem; }
.step { font-size: 0.65rem; font-weight: 700; letter-spacing: 0.18em; text-transform: uppercase; color: #2563EB; margin-bottom: 0.5rem; }
[data-testid="stFileUploader"] section {
background: #fff !important;
border: 2px dashed #D1D5DB !important;
border-radius: 12px !important;
padding: 2.5rem !important;
transition: all 0.2s ease !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.04) !important;
}
[data-testid="stFileUploader"] section:hover {
border-color: #2563EB !important;
box-shadow: 0 0 0 4px rgba(37,99,235,0.06) !important;
}
.fchip { display: inline-flex; align-items: center; gap: 5px; background: #EFF6FF; border: 1px solid #BFDBFE; color: #1D4ED8; padding: 0.25rem 0.7rem; border-radius: 6px; font-size: 0.73rem; font-weight: 500; margin: 2px; }
.stButton > button {
background: #1A1A2E !important; color: #fff !important; border: none !important;
border-radius: 8px !important; padding: 0.85rem 2.5rem !important;
font-family: 'Inter', sans-serif !important; font-size: 0.88rem !important;
font-weight: 600 !important; letter-spacing: 0.04em !important;
text-transform: uppercase !important; transition: all 0.2s !important;
box-shadow: 0 2px 8px rgba(26,26,46,0.2) !important; width: 100% !important;
}
.stButton > button:hover { background: #2563EB !important; box-shadow: 0 4px 16px rgba(37,99,235,0.3) !important; transform: translateY(-1px) !important; }
.stButton > button:disabled { background: #E5E7EB !important; color: #9CA3AF !important; box-shadow: none !important; transform: none !important; }
.stProgress > div > div > div { background: #2563EB !important; border-radius: 4px !important; }
.stProgress > div > div { background: #E5E7EB !important; border-radius: 4px !important; height: 4px !important; }
.stats-strip { display: flex; background: #1A1A2E; border-radius: 12px; overflow: hidden; margin: 2.5rem 0 2rem; }
.stat-item { flex: 1; padding: 1.6rem 2rem; border-right: 1px solid rgba(255,255,255,0.08); }
.stat-item:last-child { border-right: none; }
.stat-n { font-family: 'Cormorant Garamond', serif; font-size: 2.6rem; font-weight: 700; color: #fff; line-height: 1; margin-bottom: 0.3rem; }
.stat-l { font-size: 0.68rem; font-weight: 500; letter-spacing: 0.12em; text-transform: uppercase; color: #6B7280; }
.section-head { display: flex; align-items: center; justify-content: space-between; margin-bottom: 1rem; padding-bottom: 0.75rem; border-bottom: 1px solid #E5E7EB; }
.section-title { font-family: 'Cormorant Garamond', serif; font-size: 1.5rem; font-weight: 600; color: #1A1A2E; }
div[data-testid="stDownloadButton"] > button {
background: #fff !important; border: 1.5px solid #1A1A2E !important; color: #1A1A2E !important;
border-radius: 8px !important; padding: 0.6rem 1.4rem !important;
font-family: 'Inter', sans-serif !important; font-weight: 600 !important;
font-size: 0.82rem !important; letter-spacing: 0.04em !important;
text-transform: uppercase !important; transition: all 0.2s !important; width: auto !important;
}
div[data-testid="stDownloadButton"] > button:hover { background: #1A1A2E !important; color: #fff !important; }
[data-testid="stDataFrame"] { border-radius: 10px !important; border: 1px solid #E5E7EB !important; overflow: hidden !important; box-shadow: 0 1px 4px rgba(0,0,0,0.05) !important; }
</style>
""", unsafe_allow_html=True)
# ═══════════════════════════════════════════════════════════
# MODEL
# ═══════════════════════════════════════════════════════════
@st.cache_resource(show_spinner="Loading recognition model…")
def load_model():
    """Load the TrOCR processor and fine-tuned model, cached via ``st.cache_resource``.

    Returns:
        tuple: ``(processor, model, device)``. ``model`` has already been
        moved to ``device`` (CUDA when available, otherwise CPU) and put in
        eval mode.
    """
    # Heavy imports are deferred so the page renders before torch loads.
    import torch
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    MODEL_ID = "rasmodev/Handwriting_trocr_model"
    BASE_ID = "microsoft/trocr-base-handwritten"
    # The processor comes from the base checkpoint, which has all required
    # config files; the trained weights come from the fine-tuned checkpoint.
    processor = TrOCRProcessor.from_pretrained(BASE_ID)
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return processor, model, device
# ═══════════════════════════════════════════════════════════
# OCR
# ═══════════════════════════════════════════════════════════
def ocr_page(
    img: Image.Image,
    *,
    max_new_tokens: int = 64,
    num_beams: int = 1,
) -> str:
    """Run the recognition model on one page image and return the decoded text.

    Args:
        img: Page image; converted to RGB before preprocessing.
        max_new_tokens: Upper bound on generated tokens passed to
            ``model.generate``. Output longer than this is cut off, so raise
            it if labelled records come back truncated. Defaults to 64, the
            previously hard-coded value.
        num_beams: Beam width passed to ``model.generate``. The default of 1
            keeps the previous behaviour.

    Returns:
        The first decoded sequence, with special tokens removed and
        surrounding whitespace stripped.
    """
    import torch

    processor, model, device = load_model()
    pixel_values = processor(
        images=img.convert("RGB"),
        return_tensors="pt",
    ).pixel_values.to(device)
    with torch.no_grad():
        generated = model.generate(
            pixel_values=pixel_values,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )
    return processor.batch_decode(generated, skip_special_tokens=True)[0].strip()
# ═══════════════════════════════════════════════════════════
# PARSE LABEL
# Format: PLOT: ... | LOC: ... | AREA: ... | AMT: ... | DATE: ... | VOS: ...
# ═══════════════════════════════════════════════════════════
def parse_label(raw_text: str, filename: str) -> dict:
    """Turn one page's OCR output into a record dict.

    The model is expected to emit ``KEY: value`` segments separated by
    ``|`` (PLOT, LOC, AREA, AMT, DATE, VOS). Keys are matched
    case-insensitively after stripping; segments without a colon or with
    an unknown key are ignored, and a repeated key keeps its last value.

    The AMT value becomes an ``int`` once commas and spaces are removed.
    If that conversion fails, the stripped text is kept as-is. A missing
    AMT leaves ``None``; other missing fields stay ``""``.

    Args:
        raw_text: Decoded model output for one page.
        filename: Source file name, stored in the "File" column.

    Returns:
        Dict with the keys File, Plot Number, Location, Area, Amount (KES),
        Date, VOS and Raw Output, in that order.
    """
    # Label -> output column for fields that are stored verbatim.
    text_columns = {
        "PLOT": "Plot Number",
        "LOC": "Location",
        "AREA": "Area",
        "DATE": "Date",
        "VOS": "VOS",
    }
    record = {
        "File": filename,
        "Plot Number": "",
        "Location": "",
        "Area": "",
        "Amount (KES)": None,
        "Date": "",
        "VOS": "",
        "Raw Output": raw_text,
    }
    for segment in raw_text.split("|"):
        label, colon, value = segment.strip().partition(":")
        if not colon:
            continue
        label = label.strip().upper()
        value = value.strip()
        if label == "AMT":
            digits = value.replace(",", "").replace(" ", "")
            try:
                record["Amount (KES)"] = int(digits)
            except ValueError:
                record["Amount (KES)"] = value
        elif label in text_columns:
            record[text_columns[label]] = value
    return record
# ═══════════════════════════════════════════════════════════
# EXCEL EXPORT
# ═══════════════════════════════════════════════════════════
def make_excel(records: list) -> bytes:
    """Export records to a styled .xlsx workbook and return its bytes.

    Each record (a dict such as ``parse_label`` returns) becomes one row on
    the "Valuation Data" sheet. The "Raw Output" key is left out. Styling:
    - bold white-on-navy header row, 30 pt tall;
    - every column 26 wide;
    - wrapped, vertically centred body cells;
    - a light fill on even-numbered rows;
    - panes frozen below the header.

    Args:
        records: List of dicts; their keys become the column headers.

    Returns:
        The workbook serialised as bytes.
    """
    from openpyxl import load_workbook
    from openpyxl.styles import Font, PatternFill, Alignment
    from openpyxl.utils import get_column_letter

    rows = [
        {key: value for key, value in rec.items() if key != "Raw Output"}
        for rec in records
    ]

    # pandas lays out the table; openpyxl then reopens it for styling.
    staging = io.BytesIO()
    pd.DataFrame(rows).to_excel(staging, index=False, sheet_name="Valuation Data")
    staging.seek(0)
    workbook = load_workbook(staging)
    sheet = workbook.active

    header_font = Font(name="Calibri", bold=True, color="FFFFFF", size=11)
    header_fill = PatternFill("solid", start_color="1A1A2E")
    header_align = Alignment(horizontal="center", vertical="center")
    for header_cell in sheet[1]:
        header_cell.font = header_font
        header_cell.fill = header_fill
        header_cell.alignment = header_align
        sheet.column_dimensions[get_column_letter(header_cell.column)].width = 26
    sheet.row_dimensions[1].height = 30

    body_align = Alignment(vertical="center", wrap_text=True)
    stripe_fill = PatternFill("solid", start_color="F0F4FF")
    for data_row in sheet.iter_rows(min_row=2):
        for body_cell in data_row:
            body_cell.alignment = body_align
            if body_cell.row % 2 == 0:
                body_cell.fill = stripe_fill

    sheet.freeze_panes = "A2"
    result = io.BytesIO()
    workbook.save(result)
    return result.getvalue()
# ═══════════════════════════════════════════════════════════
# SESSION STATE
# ═══════════════════════════════════════════════════════════
# Per-session defaults. Only keys that are not yet present are initialised,
# so values written on an earlier run of the script are kept.
_SESSION_DEFAULTS = {"records": [], "excel": None, "done": False, "errors": []}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ═══════════════════════════════════════════════════════════
# UI
# ═══════════════════════════════════════════════════════════
st.markdown("""
<div class="topbar">
<div>
<div class="logo">Valuation<span>AI</span></div>
<div class="logo-sub">Nairobi City County β€” Document Intelligence</div>
</div>
<div class="model-ref">
Recognition model<br>
<strong>rasmodev/Handwriting_trocr_model</strong>
</div>
</div>
""", unsafe_allow_html=True)
st.markdown("""
<div class="headline">
Digitise handwritten<br>valuation sheets <em>instantly.</em>
</div>
<div class="subline">
Upload one or more scanned PDF valuation sheets.
The system reads every handwritten field and delivers
a structured Excel file β€” ready for records management.
</div>
""", unsafe_allow_html=True)
# Step 1: file upload, with a chip per selected file.
st.markdown('<div class="step">Step 1 — Upload Documents</div>', unsafe_allow_html=True)
uploaded = st.file_uploader(
    "Drag and drop valuation sheet PDFs here, or click to browse",
    type=["pdf", "png", "jpg", "jpeg", "tiff", "bmp"],
    accept_multiple_files=True,
    label_visibility="collapsed",
)
if uploaded:
    # File names come from the user and are rendered with unsafe_allow_html,
    # so escape them (after truncating to 35 characters for the chip).
    chips = "".join(
        f'<span class="fchip">📄 {html.escape(f.name[:35])}'
        f'{"…" if len(f.name) > 35 else ""}</span>'
        for f in uploaded
    )
    st.markdown(f'<div style="margin-top:0.6rem">{chips}</div>', unsafe_allow_html=True)
st.markdown('<div style="height:1.2rem"></div>', unsafe_allow_html=True)

# Step 2: extraction trigger, disabled until at least one file is uploaded.
st.markdown('<div class="step">Step 2 — Extract & Download</div>', unsafe_allow_html=True)
run = st.button(
    "Extract Data from Documents",
    disabled=not uploaded,
    use_container_width=True,
)
# ═══════════════════════════════════════════════════════════
# PROCESSING
# ═══════════════════════════════════════════════════════════
if run and uploaded:
    import fitz, traceback

    # Start every run from a clean slate, so a workbook from an earlier run
    # is not kept after a run that extracts nothing.
    st.session_state.records = []
    st.session_state.errors = []
    st.session_state.excel = None
    st.session_state.done = False

    bar = st.progress(0.0)
    status = st.empty()
    t0 = time.time()
    n_files = len(uploaded)

    for fi, uf in enumerate(uploaded):
        fname = uf.name
        # getvalue() returns the whole upload regardless of the stream position.
        raw = uf.getvalue()
        bar.progress(fi / n_files, text=f"Reading {fname}…")
        st.write(f"📄 **{fname}** — {len(raw):,} bytes")
        try:
            ext = fname.lower().rsplit(".", 1)[-1]
            if ext == "pdf":
                # Write to temp file — same as training
                with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
                    tmp.write(raw)
                    tmp_path = tmp.name
                imgs = []
                try:
                    # The context manager closes the document even if a page fails.
                    with fitz.open(tmp_path) as doc:
                        st.write(f"   ✅ PDF opened — {len(doc)} page(s) found")
                        # Render at 200 DPI; PDF user space has 72 units per inch.
                        mat = fitz.Matrix(200 / 72, 200 / 72)
                        for page in doc:
                            pix = page.get_pixmap(matrix=mat, alpha=False)
                            imgs.append(
                                Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
                            )
                            pix = None  # drop the pixmap before rendering the next page
                finally:
                    # delete=False above, so the temp file must be removed here,
                    # including when opening or rendering raised.
                    os.unlink(tmp_path)
                st.write(f"   ✅ Rasterized {len(imgs)} page image(s)")
            else:
                imgs = [Image.open(io.BytesIO(raw)).convert("RGB")]
                st.write("   ✅ Loaded image")

            if not imgs:
                st.error(f"   ❌ No pages extracted from {fname}")
                st.session_state.errors.append(f"{fname}: no pages extracted")
            else:
                for pi, img in enumerate(imgs, 1):
                    status.caption(f"Running OCR on **{fname}** — page {pi} of {len(imgs)}")
                    raw_text = ocr_page(img)
                    st.write(f"   📝 Page {pi} OCR output: `{raw_text}`")
                    st.session_state.records.append(parse_label(raw_text, fname))
        except Exception as e:
            # Per-file boundary: report the failure and continue with the next file.
            st.error(f"❌ Error on {fname}: {e}")
            st.code(traceback.format_exc())
            st.session_state.errors.append(f"{fname}: {e}")
        # Always advance, including for files that yielded no pages.
        bar.progress((fi + 1) / n_files)

    bar.empty()
    status.empty()
    if st.session_state.records:
        st.session_state.excel = make_excel(st.session_state.records)
        st.session_state.done = True
        elapsed = time.time() - t0
        st.success(
            f"Processed {len(st.session_state.records)} page(s) "
            f"from {n_files} document(s) in {elapsed:.1f}s."
        )
    else:
        st.warning("No records were extracted from the uploaded document(s).")
# ═══════════════════════════════════════════════════════════
# RESULTS
# ═══════════════════════════════════════════════════════════
if st.session_state.done and st.session_state.records:
    records = st.session_state.records
    df = pd.DataFrame(records)
    display_cols = [c for c in df.columns if c != "Raw Output"]
    df_display = df[display_cols]

    # Summary counts: non-empty text fields, and amounts that parse as numbers.
    n_plots = df["Plot Number"].astype(bool).sum()
    n_amounts = pd.to_numeric(df["Amount (KES)"], errors="coerce").notna().sum()
    n_dates = df["Date"].astype(bool).sum()
    st.markdown(f"""
<div class="stats-strip">
<div class="stat-item"><div class="stat-n">{len(records)}</div><div class="stat-l">Pages processed</div></div>
<div class="stat-item"><div class="stat-n">{n_plots}</div><div class="stat-l">Plot numbers</div></div>
<div class="stat-item"><div class="stat-n">{n_amounts}</div><div class="stat-l">Amounts extracted</div></div>
<div class="stat-item"><div class="stat-n">{n_dates}</div><div class="stat-l">Dates captured</div></div>
</div>
""", unsafe_allow_html=True)

    col_t, col_d = st.columns([5, 1])
    with col_t:
        st.markdown('<div class="section-head"><div class="section-title">Extracted Records</div></div>', unsafe_allow_html=True)
    with col_d:
        st.markdown('<div style="padding-top:0.3rem"></div>', unsafe_allow_html=True)
        if st.session_state.excel:
            st.download_button(
                "⬇ Export Excel",
                data=st.session_state.excel,
                file_name="valuation_records.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            )
    st.dataframe(df_display, use_container_width=True,
                 height=min(80 + len(df) * 38, 560), hide_index=True)

    with st.expander("🔍 Raw model output (for verification)"):
        for r in records:
            # File names (user-supplied) and model output are arbitrary text;
            # escape them before injecting them into raw HTML.
            st.markdown(
                f'<div style="font-family:monospace;font-size:0.78rem;'
                f'padding:0.5rem 0;border-bottom:1px solid #E5E7EB;color:#374151">'
                f'<strong>{html.escape(r["File"])}</strong><br>'
                f'{html.escape(r.get("Raw Output", ""))}</div>',
                unsafe_allow_html=True,
            )

# Rendered outside the results section so the failure summary also appears
# when no page was extracted at all.
if st.session_state.errors:
    with st.expander(f"⚠ {len(st.session_state.errors)} file(s) could not be processed"):
        for e in st.session_state.errors:
            st.caption(e)