latent-entity / app.py
dejanseo's picture
Upload 3 files
41e2a73 verified
import streamlit as st
import torch
import re
import os
import glob
from transformers import AutoTokenizer, AutoModelForCausalLM
# Directory scanned by discover_checkpoints() for ``checkpoint-*`` folders
# and an optional ``final`` folder. Relative to the process working directory.
MODEL_DIR = "./span_model_generative"
def discover_checkpoints(model_dir, prefix):
    """Return an ordered mapping of display name -> checkpoint path.

    Looks for ``checkpoint-<step>`` entries directly under ``model_dir`` and
    orders them by their numeric step. If ``<model_dir>/final`` exists it is
    appended last.

    Args:
        model_dir: Directory that contains the checkpoint folders.
        prefix: Label prepended to every display name (``"<prefix> / <folder>"``).

    Returns:
        A dict (insertion-ordered) mapping display names to paths. Empty when
        nothing is found.
    """
    numbered = []
    for path in glob.glob(os.path.join(model_dir, "checkpoint-*")):
        # basename() instead of split('/') so the folder name is extracted
        # correctly regardless of the platform's path separator.
        folder = os.path.basename(os.path.normpath(path))
        step = folder.rsplit("-", 1)[-1]
        # Entries such as ``checkpoint-best`` have no numeric step; skip them
        # rather than letting int() raise and take the whole app down.
        if not step.isdecimal():
            continue
        numbered.append((int(step), folder, path))

    found = {}
    for _, folder, path in sorted(numbered):
        found[f"{prefix} / {folder}"] = path

    final_path = os.path.join(model_dir, "final")
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found
# Display name -> checkpoint path, discovered on every run of the script.
CHECKPOINTS = discover_checkpoints(MODEL_DIR, "Gemma-Gen")
if not CHECKPOINTS:
    # Nothing to load: report it in the UI and halt this run of the script.
    st.error("No checkpoints found.")
    st.stop()


# Streamlit re-executes this script from the top on every user interaction,
# so a plain module-level dict would be recreated (empty) each time and the
# model would be reloaded on every rerun. st.cache_resource stores the returned
# object itself (no copy), so the same dict -- and whatever load_model() puts
# in it -- survives across reruns.
# show_spinner=False keeps this call from emitting a UI element before
# st.set_page_config() runs further down.
# NOTE(review): the slot is shared by all sessions of the process and is not
# synchronized; concurrent users switching checkpoints will contend for it.
@st.cache_resource(show_spinner=False)
def _model_cache():
    """Create the single mutable slot that holds the currently loaded model."""
    return {"path": None, "model": None, "tokenizer": None}


_current_model = _model_cache()
def load_model(checkpoint_path):
    """Load the tokenizer and causal LM at ``checkpoint_path``, reusing the cached pair.

    Only one model is kept in ``_current_model`` at a time. Requesting a
    different checkpoint releases the previous one before loading the next.

    Args:
        checkpoint_path: Path passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        ``(tokenizer, model)``. The model is in eval mode, loaded in bfloat16,
        and moved to CUDA when available.
    """
    cached_model = _current_model["model"]
    if cached_model is not None and _current_model["path"] == checkpoint_path:
        return _current_model["tokenizer"], cached_model

    if cached_model is not None:
        # Reset the slot to None instead of deleting keys: if the load below
        # raises, the cache stays well-formed (no KeyError on the next call and
        # no stale path pointing at a model that is gone).
        _current_model["path"] = None
        _current_model["model"] = None
        _current_model["tokenizer"] = None
        del cached_model
        if torch.cuda.is_available():
            # Hand the freed blocks back to the driver before the next load.
            torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    # Publish only after everything succeeded.
    _current_model["path"] = checkpoint_path
    _current_model["model"] = model
    _current_model["tokenizer"] = tokenizer
    return tokenizer, model
def strip_md(text):
    """Remove simple inline Markdown, keeping only the visible text.

    Applied in order: ``[label](target)`` becomes ``label``, ``**bold**``
    becomes ``bold``, then ``*italic*`` becomes ``italic``.
    """
    patterns = (
        r'\[([^\]]*)\]\([^)]*\)',  # links: keep the label, drop the target
        r'\*\*([^*]*)\*\*',        # bold
        r'\*([^*]*)\*',            # italic (must run after bold)
    )
    stripped = text
    for pattern in patterns:
        stripped = re.sub(pattern, r'\1', stripped)
    return stripped
def predict_spans(tokenizer, model, title, text, max_new_tokens=512):
    """Run generative inference over sliding windows of ``text`` and return extracted spans.

    The text is stripped of inline Markdown, tokenized, and cut into
    overlapping windows. Each window is wrapped in the prompt
    ``"Title: <title>\\n\\nText: <window>\\n\\nHooks:\\n"`` and decoded greedily.
    Every non-empty generated line other than the literal ``[NONE]`` is
    collected as a span.

    Args:
        tokenizer: Tokenizer matching ``model``.
        model: Causal LM exposing ``generate`` and ``parameters``.
        title: Article title inserted into the prompt.
        text: Article body (may contain Markdown).
        max_new_tokens: Generation cap per window.

    Returns:
        Unique span strings in first-seen order. Empty if the prompt scaffold
        alone leaves no room for text, or if nothing was generated.
    """
    device = next(model.parameters()).device

    MAX_LEN = 1024            # total prompt + completion token budget per window
    STRIDE = 256              # how far the window advances between iterations
    COMPLETION_RESERVE = 256  # tokens held back from the prompt for the completion

    clean_text = strip_md(text)
    prompt_prefix = f"Title: {title}\n\nText: "
    prompt_suffix = "\n\nHooks:\n"

    prefix_ids = tokenizer(prompt_prefix, add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer(prompt_suffix, add_special_tokens=False)["input_ids"]
    text_ids = tokenizer(clean_text, add_special_tokens=False)["input_ids"]

    # Tokenizers without a BOS token report None; putting None into the id
    # list would make torch.tensor() fail.
    bos_id = tokenizer.bos_token_id
    bos_ids = [] if bos_id is None else [bos_id]

    fixed_overhead = len(bos_ids) + len(prefix_ids) + len(suffix_ids)
    text_budget = MAX_LEN - fixed_overhead - COMPLETION_RESERVE
    if text_budget <= 0:
        return []

    # A long title can shrink the window below STRIDE. Advancing by the full
    # STRIDE would then skip the tokens between consecutive windows, so never
    # step further than one window.
    step = min(STRIDE, text_budget)

    all_spans = []
    start = 0
    total = len(text_ids)
    while start < total:
        end = min(start + text_budget, total)
        input_ids = bos_ids + prefix_ids + text_ids[start:end] + suffix_ids
        input_tensor = torch.tensor([input_ids], device=device)

        with torch.no_grad():
            output = model.generate(
                input_tensor,
                # Single unpadded sequence: every position is a real token.
                attention_mask=torch.ones_like(input_tensor),
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + completion; keep only the completion.
        generated_ids = output[0][len(input_ids):]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        for line in generated_text.split("\n"):
            line = line.strip()
            if line and line != "[NONE]":
                all_spans.append(line)

        if end >= total:
            break
        start += step

    # Overlapping windows repeat spans; dict keys dedupe while keeping order.
    return list(dict.fromkeys(all_spans))
def highlight_spans_in_text(text, spans):
    """Split ``text`` into ``(segment, is_span)`` pieces around the given spans.

    Each span is located by its first occurrence in ``text``; spans that do not
    occur are ignored. When matches overlap, the one that starts first wins
    (the longer one on equal starts). Concatenating the returned segments
    reproduces ``text``.
    """
    if not spans:
        return [(text, False)]

    # (start, end, span) for the first occurrence of every span that is present.
    hits = [
        (at, at + len(needle), needle)
        for needle in spans
        for at in (text.find(needle),)
        if at >= 0
    ]
    if not hits:
        return [(text, False)]

    # Earliest start first; on equal starts the longer match sorts ahead.
    hits = sorted(hits, key=lambda hit: (hit[0], hit[0] - hit[1]))

    pieces = []
    cursor = 0  # end of the last accepted match
    for begin, stop, _ in hits:
        if begin < cursor:
            # Overlaps the match already accepted; drop it.
            continue
        if cursor < begin:
            pieces.append((text[cursor:begin], False))
        pieces.append((text[begin:stop], True))
        cursor = stop

    if cursor < len(text):
        pieces.append((text[cursor:], False))
    return pieces
# ---------------------------------------------------------------------------
# Streamlit UI: pick a checkpoint, enter (or randomly load) an article, run
# span extraction, and render the highlighted text plus a table of spans.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Span Extractor Generative (Gemma)", layout="wide")
st.title("Span Extractor Inference (Gemma) — Generative")

checkpoint_names = list(CHECKPOINTS.keys())
# Default to the last discovered entry.
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])
max_tokens = st.slider("Max generation tokens", 64, 1024, 512, 64)

if st.button("Load Random"):
    import sqlite3
    from contextlib import closing

    row = None
    try:
        # closing() guarantees the connection is released even if the query
        # raises; sqlite3's own context manager only manages transactions.
        with closing(sqlite3.connect("../attention_hooks.db")) as conn:
            row = conn.execute("""
                SELECT sc.title, sc.markdown
                FROM scraped_content sc
                JOIN extractive_spans es ON sc.url = es.url
                WHERE es.span_text != '' AND length(sc.markdown) > 500
                ORDER BY RANDOM()
                LIMIT 1
            """).fetchone()
    except sqlite3.Error as exc:
        # A missing database file or table should not crash the page.
        st.error(f"Could not load a random sample: {exc}")
    if row:
        # Persist across reruns so the inputs below are pre-filled.
        st.session_state["random_title"] = row[0]
        st.session_state["random_text"] = row[1]

default_title = st.session_state.get("random_title", "")
default_text = st.session_state.get("random_text", "")
title = st.text_input("Title", value=default_title, placeholder="Enter article title...")
text = st.text_area("Text", value=default_text, height=300, placeholder="Enter article text...")

if st.button("Extract Spans"):
    if not (title and text):
        # Previously the click was silently ignored when a field was empty.
        st.warning("Please enter both a title and text.")
    else:
        with st.spinner("Generating..."):
            spans = predict_spans(tokenizer, model, title, text, max_new_tokens=max_tokens)
        if not spans:
            st.warning("No spans predicted.")
        else:
            st.caption(f"{len(spans)} span(s) detected")
            # predict_spans() feeds the model strip_md(text), so search for the
            # spans in that same cleaned text; matching against the raw Markdown
            # misses any span that crosses a link or emphasis marker.
            segments = highlight_spans_in_text(strip_md(text), spans)
            html_parts = []
            for seg, is_span in segments:
                # Escape before inserting into HTML; '&' must be replaced first.
                escaped = seg.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>")
                if is_span:
                    html_parts.append(f'<span style="background-color: #22c55e; color: white; padding: 1px 3px; border-radius: 3px;">{escaped}</span>')
                else:
                    html_parts.append(escaped)
            html = f'<div style="font-size: 16px; line-height: 1.8; font-family: Georgia, serif;">{"".join(html_parts)}</div>'
            st.markdown(html, unsafe_allow_html=True)
            st.divider()
            st.subheader("Extracted Spans")
            import pandas as pd
            df = pd.DataFrame({"span": spans})
            st.dataframe(df, use_container_width=True, hide_index=True)