| import streamlit as st |
| import torch |
| import re |
| import os |
| import glob |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
# Directory that holds the fine-tuned checkpoints ("checkpoint-<step>" and
# "final" sub-directories are looked up inside it). The SPAN_MODEL_DIR
# environment variable overrides the location; when it is unset or empty the
# original relative default is used, so existing setups behave as before.
MODEL_DIR = os.environ.get("SPAN_MODEL_DIR") or "./span_model_generative"
|
|
|
|
def discover_checkpoints(model_dir, prefix):
    """Map display names to checkpoint paths found under *model_dir*.

    Entries named ``checkpoint-<number>`` are listed in ascending numeric
    order, followed by ``final`` when that path exists. Entries that match
    ``checkpoint-*`` but do not end in a plain integer (for example
    ``checkpoint-best``) are skipped instead of aborting discovery with a
    ``ValueError``.

    Args:
        model_dir: Directory to scan.
        prefix: Label placed in front of every display name.

    Returns:
        A dict ``{"<prefix> / <entry name>": path}`` in display order; empty
        when nothing is found.
    """
    numbered = []
    for path in glob.glob(os.path.join(model_dir, "checkpoint-*")):
        entry_name = os.path.basename(path)
        match = re.fullmatch(r"checkpoint-(\d+)", entry_name)
        if match is None:
            # Not a numbered checkpoint; it cannot be ordered by step.
            continue
        numbered.append((int(match.group(1)), entry_name, path))

    found = {}
    for _step, entry_name, path in sorted(numbered):
        found[f"{prefix} / {entry_name}"] = path

    final_path = os.path.join(model_dir, "final")
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found
|
|
|
|
# Collect every selectable checkpoint once per script run. Without at least
# one checkpoint there is nothing to load, so report the problem and halt
# the script before any later code runs.
CHECKPOINTS = discover_checkpoints(model_dir=MODEL_DIR, prefix="Gemma-Gen")
if not CHECKPOINTS:
    st.error("No checkpoints found.")
    st.stop()
|
|
|
|
@st.cache_resource(show_spinner=False)
def _persistent_model_slot():
    """Return the dict that remembers which checkpoint is currently loaded.

    Streamlit re-executes this script from the top on every interaction, so
    a plain module-level dict would be recreated empty each time and the
    model would be reloaded on every click. ``st.cache_resource`` hands back
    the same dict object on every rerun, so the loaded model survives.

    NOTE(review): cached resources are shared by all sessions of the server
    process; concurrent users switching checkpoints would contend for this
    single slot — confirm that single-user use is the intended deployment.
    """
    return {"path": None, "model": None, "tokenizer": None}


# Single-slot cache: at most one model/tokenizer pair is kept alive.
_current_model = _persistent_model_slot()


def load_model(checkpoint_path):
    """Load the tokenizer and causal LM stored at *checkpoint_path*.

    The most recently loaded pair is kept in ``_current_model`` and returned
    directly when the same path is requested again. When a different path is
    requested, the previous pair is released before the new one is loaded.

    Args:
        checkpoint_path: Path passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is in eval mode, loaded as
        bfloat16, and moved to CUDA when CUDA is available.
    """
    if (
        _current_model.get("path") == checkpoint_path
        and _current_model.get("model") is not None
    ):
        return _current_model["tokenizer"], _current_model["model"]

    # Drop the old references before loading. The whole slot is reset
    # (rather than deleting keys) so that a failed load below leaves a
    # consistent "nothing loaded" state instead of a stale path with missing
    # entries, which would raise KeyError on the next call.
    had_model = _current_model.get("model") is not None
    _current_model.update(path=None, model=None, tokenizer=None)
    if had_model and torch.cuda.is_available():
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    # Publish the new pair only after everything above succeeded.
    _current_model.update(path=checkpoint_path, model=model, tokenizer=tokenizer)
    return tokenizer, model
|
|
|
|
def strip_md(text):
    """Remove basic markdown markup from *text*, keeping the visible words.

    Three rewrites are applied in this order: ``[label](target)`` becomes
    ``label``, ``**bold**`` becomes ``bold``, and ``*italic*`` becomes
    ``italic``.
    """
    rewrites = (
        (r'\[([^\]]*)\]\([^)]*\)', r'\1'),  # links -> link label
        (r'\*\*([^*]*)\*\*', r'\1'),        # bold -> inner text
        (r'\*([^*]*)\*', r'\1'),            # italic -> inner text
    )
    cleaned = text
    for pattern, replacement in rewrites:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned
|
|
|
|
def predict_spans(tokenizer, model, title, text, max_new_tokens=512, *,
                  max_len=1024, stride=256, completion_reserve=256):
    """Run generative inference and return the extracted spans.

    The markdown-stripped text is tokenized and processed in overlapping
    windows. Each window is wrapped in the prompt
    ``"Title: <title>\\n\\nText: <window>\\n\\nHooks:\\n"`` and completed
    greedily. Every non-empty generated line other than ``[NONE]`` counts as
    one span.

    Args:
        tokenizer: Called as ``tokenizer(str, add_special_tokens=False)`` and
            read via ``["input_ids"]``; must also provide ``bos_token_id``,
            ``eos_token_id`` and ``decode`` (presumably a Hugging Face
            tokenizer — confirm against the caller).
        model: Object providing ``parameters()`` and ``generate(...)``.
        title: Article title inserted into the prompt.
        text: Article text; markdown is stripped with ``strip_md`` first.
        max_new_tokens: Generation limit per window.
        max_len: Token budget from which the per-window text budget is
            derived (prompt overhead and ``completion_reserve`` are
            subtracted from it).
        stride: Number of text tokens the window advances per step.
        completion_reserve: Tokens subtracted from ``max_len`` when sizing
            the text window.

    Returns:
        Spans in first-seen order with duplicates removed. Empty when the
        prompt overhead leaves no room for text or the text has no tokens.
    """
    device = next(model.parameters()).device

    clean_text = strip_md(text)
    prompt_prefix = f"Title: {title}\n\nText: "
    prompt_suffix = "\n\nHooks:\n"

    prefix_ids = tokenizer(prompt_prefix, add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer(prompt_suffix, add_special_tokens=False)["input_ids"]
    text_ids = tokenizer(clean_text, add_special_tokens=False)["input_ids"]

    # Some tokenizers define no BOS token; putting None into the id list
    # would make torch.tensor() fail, so the slot is simply omitted then.
    bos_id = tokenizer.bos_token_id
    head_ids = [] if bos_id is None else [bos_id]

    fixed_overhead = len(head_ids) + len(prefix_ids) + len(suffix_ids)
    text_budget = max_len - fixed_overhead - completion_reserve
    if text_budget <= 0:
        return []

    # Never advance by more than one window: with a long title the budget
    # can be smaller than the stride, and a larger step would silently skip
    # the tokens lying between two consecutive windows.
    step = max(1, min(stride, text_budget))

    all_spans = []
    for start in range(0, len(text_ids), step):
        end = min(start + text_budget, len(text_ids))
        input_ids = head_ids + prefix_ids + text_ids[start:end] + suffix_ids
        input_tensor = torch.tensor([input_ids], device=device)

        with torch.no_grad():
            output = model.generate(
                input_tensor,
                # Single unpadded sequence: every position is a real token.
                attention_mask=torch.ones_like(input_tensor),
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + completion; keep the completion only.
        generated_ids = output[0][len(input_ids):]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        for line in generated_text.split("\n"):
            line = line.strip()
            if line and line != "[NONE]":
                all_spans.append(line)

        if end >= len(text_ids):
            # This window reached the end of the text.
            break

    # Overlapping windows can repeat spans; dict keys keep first-seen order.
    return list(dict.fromkeys(all_spans))
|
|
|
|
def highlight_spans_in_text(text, spans):
    """Split *text* into ``(segment, is_span)`` pieces for highlighting.

    Each span is located at its first occurrence in *text*. Spans that do
    not occur, and empty spans, are ignored. When located spans overlap, the
    one starting first wins (the longer one on equal starts) and the others
    are dropped.

    Args:
        text: The text to segment.
        spans: Iterable of span strings; may be empty or ``None``.

    Returns:
        A list of ``(substring, is_span)`` tuples whose substrings
        concatenate back to *text*. ``[(text, False)]`` when nothing can be
        highlighted.
    """
    if not spans:
        return [(text, False)]

    positions = []
    for span in spans:
        if not span:
            # "".find-style matches succeed at index 0 and would yield a
            # zero-width highlighted segment; skip them.
            continue
        idx = text.find(span)
        if idx >= 0:
            positions.append((idx, idx + len(span)))

    if not positions:
        return [(text, False)]

    # Earliest start first; for equal starts the longer match comes first.
    positions.sort(key=lambda bounds: (bounds[0], -bounds[1]))

    kept = []
    for start, end in positions:
        if kept and start < kept[-1][1]:
            continue  # overlaps the span accepted just before it
        kept.append((start, end))

    segments = []
    cursor = 0
    for start, end in kept:
        if cursor < start:
            segments.append((text[cursor:start], False))
        segments.append((text[start:end], True))
        cursor = end
    if cursor < len(text):
        segments.append((text[cursor:], False))

    return segments
|
|
|
|
# --- Page chrome -----------------------------------------------------------
st.set_page_config(
    page_title="Span Extractor Generative (Gemma)",
    layout="wide",
)
st.title("Span Extractor Inference (Gemma) — Generative")

# --- Checkpoint picker -----------------------------------------------------
# The last entry of CHECKPOINTS is pre-selected.
checkpoint_names = list(CHECKPOINTS)
checkpoint = st.selectbox(
    "Checkpoint",
    checkpoint_names,
    index=len(checkpoint_names) - 1,
)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])

# --- Generation length: 64..1024 in steps of 64, starting at 512 -----------
max_tokens = st.slider(
    "Max generation tokens",
    min_value=64,
    max_value=1024,
    value=512,
    step=64,
)
|
|
# "Load Random": pull one random article that has at least one non-empty
# extractive span and more than 500 characters of markdown, and stash it in
# the session state so the input widgets below pick it up as their defaults.
if st.button("Load Random"):
    import sqlite3
    from contextlib import closing

    db_path = "../attention_hooks.db"
    row = None
    if not os.path.isfile(db_path):
        # sqlite3.connect() would silently create an empty database file
        # here and the query would then fail with "no such table".
        st.error(f"Database not found: {db_path}")
    else:
        try:
            # closing() guarantees the connection is released even when the
            # query raises.
            with closing(sqlite3.connect(db_path)) as conn:
                row = conn.execute("""
                    SELECT sc.title, sc.markdown
                    FROM scraped_content sc
                    JOIN extractive_spans es ON sc.url = es.url
                    WHERE es.span_text != '' AND length(sc.markdown) > 500
                    ORDER BY RANDOM()
                    LIMIT 1
                """).fetchone()
        except sqlite3.Error as exc:
            st.error(f"Could not load a random article: {exc}")

    if row:
        # A NULL title would otherwise reach the text input as None.
        st.session_state["random_title"] = row[0] or ""
        st.session_state["random_text"] = row[1]
|
|
# Pre-fill the inputs with whatever "Load Random" stored in the session
# state; both fall back to an empty string when nothing was stored.
default_title = st.session_state.get("random_title", "")
default_text = st.session_state.get("random_text", "")

title = st.text_input(
    "Title",
    value=default_title,
    placeholder="Enter article title...",
)
text = st.text_area(
    "Text",
    value=default_text,
    height=300,
    placeholder="Enter article text...",
)
|
|
# "Extract Spans": run the model on the current title/text and render the
# text with the predicted spans highlighted, followed by a table of spans.
# Nothing happens unless both title and text are non-empty.
if st.button("Extract Spans") and title and text:
    # Imported under another name because `html` is used as a variable below.
    from html import escape as escape_html

    with st.spinner("Generating..."):
        spans = predict_spans(tokenizer, model, title, text, max_new_tokens=max_tokens)

    if not spans:
        st.warning("No spans predicted.")
    else:
        st.caption(f"{len(spans)} span(s) detected")

        # predict_spans() feeds the model the markdown-stripped text, so the
        # spans are located in that same stripped text; searching the raw
        # markdown would miss any span that crosses link/bold/italic markup.
        shown_text = strip_md(text)
        segments = highlight_spans_in_text(shown_text, spans)

        html_parts = []
        for seg, is_span in segments:
            # The text is user/database supplied and is rendered with
            # unsafe_allow_html=True, so &, < and > must be escaped before
            # newlines are turned into <br>.
            escaped = escape_html(seg, quote=False).replace("\n", "<br>")
            if is_span:
                html_parts.append(f'<span style="background-color: #22c55e; color: white; padding: 1px 3px; border-radius: 3px;">{escaped}</span>')
            else:
                html_parts.append(escaped)

        html = f'<div style="font-size: 16px; line-height: 1.8; font-family: Georgia, serif;">{"".join(html_parts)}</div>'
        st.markdown(html, unsafe_allow_html=True)

        st.divider()
        st.subheader("Extracted Spans")
        import pandas as pd
        df = pd.DataFrame({"span": spans})
        st.dataframe(df, use_container_width=True, hide_index=True)
|
|