File size: 7,063 Bytes
41e2a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import streamlit as st
import torch
import re
import os
import glob
from transformers import AutoTokenizer, AutoModelForCausalLM


# Root directory holding the `checkpoint-<step>` and/or `final` model
# directories. A relative path resolves against the current working
# directory; set SPAN_MODEL_DIR to point the app elsewhere without editing
# the source.
MODEL_DIR = os.environ.get("SPAN_MODEL_DIR", "./span_model_generative")


def discover_checkpoints(model_dir, prefix):
    """Map display labels to checkpoint directories found under *model_dir*.

    ``checkpoint-<step>`` sub-directories are listed in ascending numeric step
    order (so ``checkpoint-10`` sorts after ``checkpoint-2``), followed by a
    ``final`` sub-directory when one exists. Labels have the form
    ``"<prefix> / <directory name>"``.

    Entries that are not directories, or whose suffix after the last ``-`` is
    not an integer (e.g. ``checkpoint-best``), are skipped instead of aborting
    the whole scan with ``ValueError``.

    Args:
        model_dir: directory that holds the training outputs.
        prefix: label prefix shown in the checkpoint picker.

    Returns:
        Insertion-ordered dict ``{label: path}``; empty when nothing is found.
    """
    numbered = []
    # Escape model_dir so characters like "[" in the path are matched literally.
    pattern = os.path.join(glob.escape(model_dir), "checkpoint-*")
    for path in glob.glob(pattern):
        if not os.path.isdir(path):
            continue
        try:
            step = int(os.path.basename(path).rsplit("-", 1)[1])
        except ValueError:
            continue  # not a numbered checkpoint; ignore rather than crash
        numbered.append((step, path))

    found = {f"{prefix} / {os.path.basename(path)}": path for _, path in sorted(numbered)}

    final_path = os.path.join(model_dir, "final")
    if os.path.isdir(final_path):
        found[f"{prefix} / final"] = final_path
    return found


CHECKPOINTS = discover_checkpoints(MODEL_DIR, "Gemma-Gen")
if not CHECKPOINTS:
    # Without any checkpoint there is nothing to serve: report it and end
    # this script run before the page is built.
    st.error("No checkpoints found.")
    st.stop()


# Module-level holder for one loaded checkpoint (its path, model and
# tokenizer), all initially None. NOTE(review): Streamlit re-executes the
# whole script on every interaction, which re-creates this dict each time —
# confirm nothing relies on it persisting across reruns.
_current_model = dict.fromkeys(("path", "model", "tokenizer"))


@st.cache_resource
def _model_slot():
    """Return the process-wide single-checkpoint holder used by ``load_model``.

    Streamlit re-executes this script top to bottom on every interaction, so a
    plain module-level dict is re-created on each rerun and can never act as a
    cache. ``st.cache_resource`` hands back the same object across reruns and
    sessions.
    """
    import threading

    # The slot is shared by every session, so loading/swapping is serialized.
    return {"lock": threading.Lock(), "path": None, "model": None, "tokenizer": None}


def load_model(checkpoint_path):
    """Return ``(tokenizer, model)`` for *checkpoint_path*, reusing it if already loaded.

    At most one checkpoint is kept in the shared slot. Selecting a different
    path drops the slot's references to the previous model (and, on CUDA,
    releases cached GPU blocks) *before* loading the new one, so the old
    weights can be freed first.

    Args:
        checkpoint_path: directory passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``; the model is loaded in bfloat16, put in eval
        mode, and moved to the GPU when CUDA is available.
    """
    import gc

    slot = _model_slot()
    with slot["lock"]:
        if slot["path"] == checkpoint_path:
            return slot["tokenizer"], slot["model"]
        if slot["model"] is not None:
            # Clear "path" together with the objects so a failed load below
            # cannot leave the slot claiming a model it no longer holds.
            slot.update(path=None, model=None, tokenizer=None)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16)
        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()
        slot.update(path=checkpoint_path, model=model, tokenizer=tokenizer)
        return tokenizer, model


# (pattern, replacement) pairs applied in order: links, then bold, then italic.
_MD_SUBSTITUTIONS = (
    (re.compile(r'\[([^\]]*)\]\([^)]*\)'), r'\1'),  # [label](url) -> label
    (re.compile(r'\*\*([^*]*)\*\*'), r'\1'),  # **bold** -> bold
    (re.compile(r'\*([^*]*)\*'), r'\1'),  # *italic* -> italic
)


def strip_md(text):
    """Remove markdown link syntax and ``*``/``**`` emphasis markers from *text*.

    Link targets are dropped and only the label is kept; emphasised text keeps
    its content. Other markdown is left as-is.
    """
    for pattern, replacement in _MD_SUBSTITUTIONS:
        text = pattern.sub(replacement, text)
    return text


def predict_spans(tokenizer, model, title, text, max_new_tokens=512):
    """Run generative inference and return the extracted spans.

    The markdown-stripped body (``strip_md``) is tokenised once and shown to
    the model in windows of up to ``text_budget`` tokens. Each window is
    framed as BOS (when the tokenizer has one), the tokens of
    ``"Title: <title>"`` plus a blank line and ``"Text: "``, the window's text
    tokens, and finally a blank line plus ``"Hooks:"`` and a newline. Each
    window is decoded without sampling, and every non-blank output line other
    than ``[NONE]`` is kept as a span.

    Args:
        tokenizer: Hugging Face tokenizer paired with *model*.
        model: causal LM exposing ``generate`` and ``parameters``.
        title: article title inserted into every window's prompt.
        text: article body; markdown links/emphasis are stripped first.
        max_new_tokens: generation cap per window.

    Returns:
        Span strings in first-seen order with exact duplicates removed; ``[]``
        for empty text or when the title leaves no room for any text.
    """
    device = next(model.parameters()).device
    MAX_LEN = 1024  # prompt + completion tokens per window
    STRIDE = 256  # tokens between consecutive window starts
    COMPLETION_RESERVE = 256  # window tokens held back for the generated answer

    clean_text = strip_md(text)

    prefix_ids = tokenizer(f"Title: {title}\n\nText: ", add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer("\n\nHooks:\n", add_special_tokens=False)["input_ids"]
    text_ids = tokenizer(clean_text, add_special_tokens=False)["input_ids"]
    # Not every tokenizer defines BOS; prepending None would make torch.tensor() raise.
    bos_ids = [] if tokenizer.bos_token_id is None else [tokenizer.bos_token_id]

    text_budget = MAX_LEN - COMPLETION_RESERVE - len(bos_ids) - len(prefix_ids) - len(suffix_ids)
    if text_budget <= 0:
        return []
    # A long title can shrink the window below STRIDE; advancing by more than
    # one window would leave the tokens in between unseen by the model.
    step = min(STRIDE, text_budget)

    all_spans = []
    start = 0
    while start < len(text_ids):
        end = min(start + text_budget, len(text_ids))
        input_ids = bos_ids + prefix_ids + text_ids[start:end] + suffix_ids
        input_tensor = torch.tensor([input_ids], device=device)

        with torch.no_grad():
            output = model.generate(
                input_tensor,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # no sampling
                eos_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation; decode only the continuation.
        generated_text = tokenizer.decode(output[0][len(input_ids):], skip_special_tokens=True)
        all_spans.extend(
            line
            for line in (raw.strip() for raw in generated_text.split("\n"))
            if line and line != "[NONE]"
        )

        if end >= len(text_ids):
            break
        start += step

    # dict keeps insertion order: drops repeats while keeping first-seen order.
    return list(dict.fromkeys(all_spans))


def highlight_spans_in_text(text, spans):
    """Split *text* into ``(segment, is_span)`` pieces marking predicted spans.

    Each span is located at its first occurrence in *text* (``str.find``).
    Spans that do not occur, and empty spans, are ignored. Overlapping matches
    are merged into one highlighted segment, so no located span is left partly
    unhighlighted. Matches that merely touch stay separate segments.
    Concatenating the segment strings reproduces *text*.

    Args:
        text: the text to display.
        spans: candidate span strings.

    Returns:
        ``(substring, is_span)`` tuples in document order, or
        ``[(text, False)]`` when no span is located.
    """
    intervals = []
    for span in spans:
        if not span:
            continue  # "" would "match" at 0 and yield an empty highlighted segment
        idx = text.find(span)
        if idx >= 0:
            intervals.append((idx, idx + len(span)))

    if not intervals:
        return [(text, False)]

    # Merge overlapping [start, end) intervals (sorted by start).
    intervals.sort()
    merged = [list(intervals[0])]
    for s, e in intervals[1:]:
        if s < merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s, e])

    segments = []
    pos = 0
    for s, e in merged:
        if pos < s:
            segments.append((text[pos:s], False))
        segments.append((text[s:e], True))
        pos = e
    if pos < len(text):
        segments.append((text[pos:], False))
    return segments


st.set_page_config(page_title="Span Extractor Generative (Gemma)", layout="wide")
st.title("Span Extractor Inference (Gemma) — Generative")

# Default to the last entry (discover_checkpoints lists "final" / highest step last).
checkpoint_names = list(CHECKPOINTS)
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
with st.spinner("Loading model..."):
    tokenizer, model = load_model(CHECKPOINTS[checkpoint])

max_tokens = st.slider("Max generation tokens", 64, 1024, 512, 64)

SAMPLE_DB_PATH = "../attention_hooks.db"
_RANDOM_ARTICLE_SQL = """
        SELECT sc.title, sc.markdown
        FROM scraped_content sc
        JOIN extractive_spans es ON sc.url = es.url
        WHERE es.span_text != '' AND length(sc.markdown) > 500
        ORDER BY RANDOM()
        LIMIT 1
    """

if st.button("Load Random"):
    import sqlite3
    from contextlib import closing

    if not os.path.exists(SAMPLE_DB_PATH):
        # sqlite3.connect() would silently create an empty file at this path
        # and the query would then fail with "no such table".
        st.error(f"Sample database not found: {SAMPLE_DB_PATH}")
    else:
        # closing() releases the connection even if the query raises.
        with closing(sqlite3.connect(SAMPLE_DB_PATH)) as conn:
            row = conn.execute(_RANDOM_ARTICLE_SQL).fetchone()
        if row:
            st.session_state["random_title"] = row[0]
            st.session_state["random_text"] = row[1]
        else:
            st.warning("No matching article found in the sample database.")

default_title = st.session_state.get("random_title", "")
default_text = st.session_state.get("random_text", "")

title = st.text_input("Title", value=default_title, placeholder="Enter article title...")
text = st.text_area("Text", value=default_text, height=300, placeholder="Enter article text...")

extract_clicked = st.button("Extract Spans")
if extract_clicked and not (title and text):
    st.warning("Enter both a title and some text first.")
elif extract_clicked:
    with st.spinner("Generating..."):
        spans = predict_spans(tokenizer, model, title, text, max_new_tokens=max_tokens)

    if not spans:
        st.warning("No spans predicted.")
    else:
        st.caption(f"{len(spans)} span(s) detected")

        # predict_spans runs the model on strip_md(text), so spans it copies
        # out are presumably substrings of the *stripped* text; searching the
        # raw markdown would miss any span crossing a link or emphasis marker.
        display_text = strip_md(text)
        segments = highlight_spans_in_text(display_text, spans)

        from html import escape

        html_parts = []
        for seg, is_span in segments:
            escaped = escape(seg, quote=False).replace("\n", "<br>")
            if is_span:
                html_parts.append(f'<span style="background-color: #22c55e; color: white; padding: 1px 3px; border-radius: 3px;">{escaped}</span>')
            else:
                html_parts.append(escaped)

        html_doc = f'<div style="font-size: 16px; line-height: 1.8; font-family: Georgia, serif;">{"".join(html_parts)}</div>'
        st.markdown(html_doc, unsafe_allow_html=True)

        st.divider()
        st.subheader("Extracted Spans")
        import pandas as pd
        df = pd.DataFrame({"span": spans})
        st.dataframe(df, use_container_width=True, hide_index=True)