File size: 7,063 Bytes
41e2a73 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 | import streamlit as st
import torch
import re
import os
import glob
from transformers import AutoTokenizer, AutoModelForCausalLM
# Root directory searched for model checkpoints ("checkpoint-*" entries plus an
# optional "final" entry). Being a relative path, it is resolved against the
# working directory of the running process.
MODEL_DIR = "./span_model_generative"
def discover_checkpoints(model_dir, prefix):
    """Map display names to the checkpoint paths found under *model_dir*.

    Entries named ``checkpoint-<step>`` are listed in ascending numeric step
    order, followed by ``final`` when that path exists. Entries matching
    ``checkpoint-*`` whose suffix is not an integer (for example
    ``checkpoint-best``) are skipped because they cannot be ordered by step.

    Args:
        model_dir: Directory that contains the checkpoint entries.
        prefix: Label prepended to every display name, giving
            ``"<prefix> / <entry name>"``.

    Returns:
        dict: Display name -> path, in the order described above. Empty when
        nothing matches.
    """
    numbered = []
    for path in glob.glob(f"{model_dir}/checkpoint-*"):
        # basename() handles whichever separator glob returned, unlike a
        # manual split on "/".
        name = os.path.basename(path)
        try:
            step = int(name.rsplit("-", 1)[-1])
        except ValueError:
            # No numeric step to sort by; skip rather than crash the app.
            continue
        numbered.append((step, name, path))

    found = {
        f"{prefix} / {name}": path
        for _, name, path in sorted(numbered, key=lambda item: item[0])
    }

    final_path = f"{model_dir}/final"
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found
# Build the display-name -> path table offered in the checkpoint selector.
# With nothing to load the app cannot do anything useful, so report the
# problem and halt this script run.
CHECKPOINTS = discover_checkpoints(MODEL_DIR, "Gemma-Gen")
if not CHECKPOINTS:
    st.error("No checkpoints found.")
    st.stop()
@st.cache_resource(show_spinner=False)
def _model_slot():
    """Return the single slot that holds the currently loaded model.

    Streamlit re-executes the script on every widget interaction, so a plain
    module-level dict would be rebuilt (and the model reloaded from disk) on
    each rerun. ``st.cache_resource`` hands back the same dict object every
    time, so mutations made by ``load_model`` survive reruns.
    """
    return {"path": None, "model": None, "tokenizer": None}


def load_model(checkpoint_path):
    """Load the tokenizer and causal LM for *checkpoint_path*, keeping one in memory.

    A repeated request for the checkpoint that is already loaded returns the
    cached objects. Requesting a different checkpoint drops the previous model
    first, so that two models are not held at the same time.

    Args:
        checkpoint_path: Path passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        tuple: ``(tokenizer, model)``. The model is in eval mode, loaded as
        bfloat16, and moved to CUDA when CUDA is available.
    """
    slot = _model_slot()
    if slot["path"] == checkpoint_path and slot["model"] is not None:
        return slot["tokenizer"], slot["model"]

    had_model = slot["model"] is not None
    # Clear the whole slot, path included, before loading. If loading fails
    # the slot is then simply empty, instead of advertising a path whose
    # model and tokenizer are gone.
    slot.update(path=None, model=None, tokenizer=None)
    if had_model and torch.cuda.is_available():
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    slot.update(path=checkpoint_path, model=model, tokenizer=tokenizer)
    return tokenizer, model
def strip_md(text):
    """Remove basic Markdown markup from *text*, keeping the visible words.

    The substitutions run in a fixed order: ``[label](target)`` links collapse
    to ``label``, then ``**bold**`` markers are dropped, then ``*emphasis*``
    markers.
    """
    unwrap_rules = (
        r'\[([^\]]*)\]\([^)]*\)',  # [label](target) -> label
        r'\*\*([^*]*)\*\*',        # **bold** -> bold
        r'\*([^*]*)\*',            # *emphasis* -> emphasis
    )
    for pattern in unwrap_rules:
        text = re.sub(pattern, r'\1', text)
    return text
def predict_spans(tokenizer, model, title, text, max_new_tokens=512):
    """Run generative inference over sliding windows of *text* and collect spans.

    The text is stripped of Markdown with ``strip_md``, tokenized, and cut
    into overlapping windows. Each window is wrapped in the prompt
    ``"Title: <title>\\n\\nText: <window>\\n\\nHooks:\\n"`` and decoded
    greedily. Every non-empty generated line other than ``[NONE]`` counts as
    one span.

    Args:
        tokenizer: Tokenizer used for encoding the prompt and decoding output.
        model: Causal LM exposing ``parameters()`` and ``generate()``.
        title: Article title inserted into the prompt.
        text: Article body; may contain Markdown.
        max_new_tokens: Generation cap per window.

    Returns:
        list[str]: Unique spans in first-seen order. Empty when the text is
        empty or the fixed prompt leaves no room for any text tokens.
    """
    device = next(model.parameters()).device
    MAX_LEN = 1024
    STRIDE = 256
    COMPLETION_RESERVE = 256
    # NOTE(review): max_new_tokens (default 512) can exceed COMPLETION_RESERVE,
    # so prompt + generation may go past MAX_LEN -- confirm this is intended.

    clean_text = strip_md(text)
    prompt_prefix = f"Title: {title}\n\nText: "
    prompt_suffix = "\n\nHooks:\n"
    prefix_ids = tokenizer(prompt_prefix, add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer(prompt_suffix, add_special_tokens=False)["input_ids"]
    text_ids = tokenizer(clean_text, add_special_tokens=False)["input_ids"]

    # Leave BOS out when the tokenizer defines none, rather than putting
    # None into the id list.
    bos_ids = [] if tokenizer.bos_token_id is None else [tokenizer.bos_token_id]
    head_ids = bos_ids + prefix_ids

    fixed_overhead = len(head_ids) + len(suffix_ids)
    text_budget = MAX_LEN - fixed_overhead - COMPLETION_RESERVE
    if text_budget <= 0:
        return []

    # Never advance further than one window holds. With a long title the
    # budget can drop below STRIDE, and stepping by STRIDE would then skip
    # the tokens between consecutive windows.
    step = min(STRIDE, text_budget)

    all_spans = []
    start = 0
    while start < len(text_ids):
        end = min(start + text_budget, len(text_ids))
        input_ids = head_ids + text_ids[start:end] + suffix_ids
        input_tensor = torch.tensor([input_ids], device=device)
        with torch.no_grad():
            output = model.generate(
                input_tensor,
                # Single unpadded sequence: every position is a real token.
                attention_mask=torch.ones_like(input_tensor),
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; keep the continuation.
        generated_ids = output[0][len(input_ids):]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        for line in generated_text.split("\n"):
            line = line.strip()
            if line and line != "[NONE]":
                all_spans.append(line)
        if end >= len(text_ids):
            break
        start += step

    # Overlapping windows repeat spans; dict keys deduplicate while keeping
    # first-seen order.
    return list(dict.fromkeys(all_spans))
def highlight_spans_in_text(text, spans):
    """Split *text* into ``(fragment, is_span)`` pieces for rendering.

    Each span is located at its first occurrence in *text*; spans that do not
    occur are ignored. Matches are processed left to right. A match that
    starts inside an already accepted match is dropped, and among matches
    sharing a start position the longest one wins. The fragments, joined in
    order, reproduce *text*.
    """
    if not spans:
        return [(text, False)]

    # (begin, stop) of the first occurrence of every span present in text.
    hits = []
    for candidate in spans:
        begin = text.find(candidate)
        if begin != -1:
            hits.append((begin, begin + len(candidate)))
    if not hits:
        return [(text, False)]

    # Earliest start first; on equal starts the longer match sorts first.
    hits.sort(key=lambda hit: (hit[0], hit[0] - hit[1]))
    kept = []
    for begin, stop in hits:
        if not kept or begin >= kept[-1][1]:
            kept.append((begin, stop))

    # Interleave plain gaps with the highlighted matches.
    pieces = []
    cursor = 0
    for begin, stop in kept:
        if begin > cursor:
            pieces.append((text[cursor:begin], False))
        pieces.append((text[begin:stop], True))
        cursor = stop
    if cursor < len(text):
        pieces.append((text[cursor:], False))
    return pieces
# ---- Streamlit page: pick a checkpoint, enter an article, show predicted spans ----
st.set_page_config(page_title="Span Extractor Generative (Gemma)", layout="wide")
st.title("Span Extractor Inference (Gemma) — Generative")

checkpoint_names = list(CHECKPOINTS.keys())
# Preselect the last entry of the checkpoint table.
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])
max_tokens = st.slider("Max generation tokens", 64, 1024, 512, 64)

if st.button("Load Random"):
    import sqlite3

    row = None
    try:
        conn = sqlite3.connect("../attention_hooks.db")
        try:
            row = conn.execute("""
                SELECT sc.title, sc.markdown
                FROM scraped_content sc
                JOIN extractive_spans es ON sc.url = es.url
                WHERE es.span_text != '' AND length(sc.markdown) > 500
                ORDER BY RANDOM()
                LIMIT 1
            """).fetchone()
        finally:
            # Close the connection even when the query raises.
            conn.close()
    except sqlite3.Error as exc:
        st.error(f"Could not load a random article: {exc}")
    if row:
        # Stored in session state so the sample survives the rerun that
        # follows the click and pre-fills the inputs below.
        st.session_state["random_title"] = row[0]
        st.session_state["random_text"] = row[1]

default_title = st.session_state.get("random_title", "")
default_text = st.session_state.get("random_text", "")
title = st.text_input("Title", value=default_title, placeholder="Enter article title...")
text = st.text_area("Text", value=default_text, height=300, placeholder="Enter article text...")

extract_clicked = st.button("Extract Spans")
if extract_clicked and not (title and text):
    # Tell the user why nothing happened instead of ignoring the click.
    st.warning("Enter both a title and some text first.")
elif extract_clicked:
    with st.spinner("Generating..."):
        spans = predict_spans(tokenizer, model, title, text, max_new_tokens=max_tokens)
    if not spans:
        st.warning("No spans predicted.")
    else:
        from html import escape as escape_html

        st.caption(f"{len(spans)} span(s) detected")
        # predict_spans feeds the model the Markdown-stripped text, so spans
        # are matched against that same text rather than the raw input.
        display_text = strip_md(text)
        segments = highlight_spans_in_text(display_text, spans)
        html_parts = []
        for seg, is_span in segments:
            # The result is rendered with unsafe_allow_html=True, so the text
            # must be escaped before the <br> and <span> markup is added.
            escaped = escape_html(seg, quote=False).replace("\n", "<br>")
            if is_span:
                html_parts.append(f'<span style="background-color: #22c55e; color: white; padding: 1px 3px; border-radius: 3px;">{escaped}</span>')
            else:
                html_parts.append(escaped)
        joined_parts = "".join(html_parts)
        rendered_html = f'<div style="font-size: 16px; line-height: 1.8; font-family: Georgia, serif;">{joined_parts}</div>'
        st.markdown(rendered_html, unsafe_allow_html=True)
        st.divider()
        st.subheader("Extracted Spans")
        import pandas as pd

        df = pd.DataFrame({"span": spans})
        st.dataframe(df, use_container_width=True, hide_index=True)
|