import glob
import os
import re

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_DIR = "./span_model_generative"

# Markdown patterns removed by strip_md(); compiled once at import time.
_MD_LINK_RE = re.compile(r'\[([^\]]*)\]\([^)]*\)')
_MD_BOLD_RE = re.compile(r'\*\*([^*]*)\*\*')
_MD_ITALIC_RE = re.compile(r'\*([^*]*)\*')


def discover_checkpoints(model_dir, prefix):
    """Return an ordered mapping of display name -> checkpoint directory.

    Directories matching ``{model_dir}/checkpoint-<N>`` are listed in
    ascending numeric order of ``<N>``; a ``{model_dir}/final`` directory, if
    it exists, is appended last. Entries whose suffix after the last ``-`` is
    not an integer are skipped instead of aborting discovery.

    Args:
        model_dir: Directory that contains the checkpoint sub-directories.
        prefix: Label prepended to every display name.

    Returns:
        dict mapping ``"{prefix} / <dir name>"`` to the directory path.
    """
    numbered = []
    for path in glob.glob(f"{model_dir}/checkpoint-*"):
        try:
            step = int(path.rsplit("-", 1)[-1])
        except ValueError:
            # e.g. "checkpoint-best" or "checkpoint-100.tmp": not a step dir.
            continue
        numbered.append((step, path))
    numbered.sort()

    found = {}
    for _, path in numbered:
        # basename() copes with both "/" and the platform separator.
        dir_name = os.path.basename(os.path.normpath(path))
        found[f"{prefix} / {dir_name}"] = path

    final_path = f"{model_dir}/final"
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found


CHECKPOINTS = discover_checkpoints(MODEL_DIR, "Gemma-Gen")
if not CHECKPOINTS:
    # NOTE(review): this can run before the st.set_page_config() call further
    # down the script; some Streamlit versions require set_page_config to be
    # the first Streamlit command -- confirm against the deployed version.
    st.error("No checkpoints found.")
    st.stop()


@st.cache_resource(show_spinner=False)
def _model_slot():
    """Return the holder dict for the currently loaded model.

    Streamlit re-executes this script on every interaction, which would reset
    a plain module-level dict and force a model reload each time. Keeping the
    holder in ``st.cache_resource`` lets it survive reruns.
    """
    return {"path": None, "model": None, "tokenizer": None}


_current_model = _model_slot()


def load_model(checkpoint_path):
    """Load (or reuse) the tokenizer and model stored at *checkpoint_path*.

    Only one model is kept at a time: when a different checkpoint is
    requested, the previous model is released before the new one is loaded.

    Args:
        checkpoint_path: Directory passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``; the model is in eval mode and moved to CUDA
        when CUDA is available.
    """
    if (
        _current_model["path"] == checkpoint_path
        and _current_model["model"] is not None
    ):
        return _current_model["tokenizer"], _current_model["model"]

    if _current_model["model"] is not None:
        # Reset rather than `del` the keys so the holder stays well-formed
        # even if the load below raises.
        _current_model.update(path=None, model=None, tokenizer=None)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path, torch_dtype=torch.bfloat16
    )
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    _current_model.update(path=checkpoint_path, model=model, tokenizer=tokenizer)
    return tokenizer, model


def strip_md(text):
    """Remove markdown links, bold and italic markers, keeping the inner text."""
    text = _MD_LINK_RE.sub(r'\1', text)
    text = _MD_BOLD_RE.sub(r'\1', text)
    text = _MD_ITALIC_RE.sub(r'\1', text)
    return text


def predict_spans(tokenizer, model, title, text, max_new_tokens=512):
    """Run generative inference and return extracted spans.

    The markdown-stripped text is tokenized and processed in overlapping
    windows. Each window is wrapped in the ``Title: ... Text: ... Hooks:``
    prompt and decoded greedily; every non-empty generated line other than
    ``[NONE]`` is collected as a span.

    Args:
        tokenizer: Tokenizer matching *model*.
        model: Causal LM exposing ``generate`` and ``parameters``.
        title: Article title inserted into the prompt.
        text: Article body (markdown allowed).
        max_new_tokens: Generation cap per window.

    Returns:
        List of unique span strings in first-seen order. Empty when the
        prompt scaffolding leaves no token budget for text, or when the text
        tokenizes to nothing.
    """
    device = next(model.parameters()).device

    MAX_LEN = 1024
    STRIDE = 256
    COMPLETION_RESERVE = 256

    clean_text = strip_md(text)
    prompt_prefix = f"Title: {title}\n\nText: "
    prompt_suffix = "\n\nHooks:\n"

    prefix_ids = tokenizer(prompt_prefix, add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer(prompt_suffix, add_special_tokens=False)["input_ids"]
    text_ids = tokenizer(clean_text, add_special_tokens=False)["input_ids"]

    # A tokenizer without a BOS token must not contribute a None id.
    bos_id = tokenizer.bos_token_id
    bos_ids = [] if bos_id is None else [bos_id]

    fixed_overhead = len(bos_ids) + len(prefix_ids) + len(suffix_ids)
    text_budget = MAX_LEN - fixed_overhead - COMPLETION_RESERVE
    if text_budget <= 0:
        return []

    # Never advance further than one window, otherwise text between
    # consecutive windows would be skipped when the budget is below STRIDE.
    step = min(STRIDE, text_budget)

    all_spans = []
    for start in range(0, len(text_ids), step):
        end = min(start + text_budget, len(text_ids))
        input_ids = bos_ids + prefix_ids + text_ids[start:end] + suffix_ids
        input_tensor = torch.tensor([input_ids], device=device)

        with torch.no_grad():
            output = model.generate(
                input_tensor,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + completion; keep only the completion.
        generated_ids = output[0][len(input_ids):]
        generated_text = tokenizer.decode(
            generated_ids, skip_special_tokens=True
        ).strip()

        for line in generated_text.split("\n"):
            line = line.strip()
            if line and line != "[NONE]":
                all_spans.append(line)

        if end >= len(text_ids):
            break

    # Deduplicate while preserving order.
    return list(dict.fromkeys(all_spans))


def highlight_spans_in_text(text, spans):
    """Build (text, is_span) segments by finding span occurrences in the original text.

    Only the first occurrence of each span is located. When located spans
    overlap, the one that starts first wins (the longer one on a tie) and the
    others are dropped. Spans that are empty or not found verbatim in *text*
    are ignored.

    Args:
        text: The text to segment.
        spans: Iterable of span strings to look for.

    Returns:
        List of ``(segment_text, is_span)`` tuples that concatenate back to
        *text*; ``[(text, False)]`` when nothing matches.
    """
    if not spans:
        return [(text, False)]

    # Find the first position of every span in text.
    positions = []
    for span in spans:
        if not span:
            # An empty span would produce a zero-length highlighted segment.
            continue
        idx = text.find(span)
        if idx >= 0:
            positions.append((idx, idx + len(span)))

    if not positions:
        return [(text, False)]

    # Sort by start position, longer span first on ties, then drop overlaps.
    positions.sort(key=lambda p: (p[0], -(p[1] - p[0])))
    merged = []
    for s, e in positions:
        if merged and s < merged[-1][1]:
            continue
        merged.append((s, e))

    # Build alternating plain / highlighted segments.
    segments = []
    pos = 0
    for s, e in merged:
        if pos < s:
            segments.append((text[pos:s], False))
        segments.append((text[s:e], True))
        pos = e
    if pos < len(text):
        segments.append((text[pos:], False))

    return segments
# ---------------------------------------------------------------------------
# Streamlit page: pick a checkpoint, enter (or load) an article, extract spans.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Span Extractor Generative (Gemma)", layout="wide")
st.title("Span Extractor Inference (Gemma) — Generative")

checkpoint_names = list(CHECKPOINTS.keys())
# Default to the last entry in the checkpoint mapping.
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])

max_tokens = st.slider("Max generation tokens", 64, 1024, 512, 64)

if st.button("Load Random"):
    import sqlite3
    from contextlib import closing

    # closing() guarantees the connection is released even if the query raises.
    with closing(sqlite3.connect("../attention_hooks.db")) as conn:
        row = conn.execute("""
            SELECT sc.title, sc.markdown
            FROM scraped_content sc
            JOIN extractive_spans es ON sc.url = es.url
            WHERE es.span_text != '' AND length(sc.markdown) > 500
            ORDER BY RANDOM() LIMIT 1
        """).fetchone()
    if row:
        st.session_state["random_title"] = row[0]
        st.session_state["random_text"] = row[1]

default_title = st.session_state.get("random_title", "")
default_text = st.session_state.get("random_text", "")

title = st.text_input("Title", value=default_title, placeholder="Enter article title...")
text = st.text_area("Text", value=default_text, height=300, placeholder="Enter article text...")

if st.button("Extract Spans") and title and text:
    with st.spinner("Generating..."):
        spans = predict_spans(tokenizer, model, title, text, max_new_tokens=max_tokens)

    if not spans:
        st.warning("No spans predicted.")
    else:
        from html import escape

        st.caption(f"{len(spans)} span(s) detected")

        segments = highlight_spans_in_text(text, spans)
        html_parts = []
        for seg, is_span in segments:
            # The text is rendered with unsafe_allow_html=True, so it must be
            # HTML-escaped first; newlines become <br> to keep line breaks.
            escaped = escape(seg, quote=False).replace("\n", "<br>")
            if is_span:
                html_parts.append(f"<mark>{escaped}</mark>")
            else:
                html_parts.append(escaped)

        html = f'<div>{"".join(html_parts)}</div>'
        st.markdown(html, unsafe_allow_html=True)

        st.divider()
        st.subheader("Extracted Spans")
        import pandas as pd

        df = pd.DataFrame({"span": spans})
        st.dataframe(df, use_container_width=True, hide_index=True)