# latent-entity / inference.py
# dejanseo's picture
# Upload inference.py
# 5b9c2ac verified
import streamlit as st
import torch
import re
import os
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Tag names of the token-classification output, keyed to their class index.
LABEL2ID = {
    "O": 0,
    "B-SPAN": 1,
    "I-SPAN": 2,
}
# Reverse lookup: class index -> tag name.
ID2LABEL = {index: tag for tag, index in LABEL2ID.items()}
import glob
# Runs to scan for checkpoints: display prefix -> directory of that run.
MODEL_DIRS = dict(
    CE="./span_model_ce",
    Focal="./span_model_focal",
)
def discover_checkpoints(model_dir, prefix):
    """Collect the checkpoints stored under *model_dir*.

    ``checkpoint-<step>`` entries are listed by ascending numeric step,
    followed by a ``final`` entry when one exists. Entries whose suffix is
    not an integer (for example ``checkpoint-best``) are ignored instead of
    aborting the scan.

    Args:
        model_dir: Directory holding the ``checkpoint-*`` / ``final`` entries.
        prefix: Label prepended to every display name.

    Returns:
        A dict mapping ``"<prefix> / <entry name>"`` to the entry's path, in
        display order. Empty when nothing is found.
    """
    # Escape the directory so glob metacharacters in it are taken literally.
    pattern = os.path.join(glob.escape(model_dir), "checkpoint-*")

    numbered = []
    for path in sorted(glob.glob(pattern)):
        try:
            step = int(path.rpartition("-")[2])
        except ValueError:
            # Not a numbered checkpoint; leave it out rather than crash.
            continue
        numbered.append((step, path))
    # Stable sort: equal steps keep the lexicographic order from above.
    numbered.sort(key=lambda item: item[0])

    found = {}
    for _step, path in numbered:
        # basename copes with the platform separator, unlike split('/').
        found[f"{prefix} / {os.path.basename(path)}"] = path

    final_path = os.path.join(model_dir, "final")
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found
# One display-name -> path table covering every configured run, in run order.
CHECKPOINTS = {
    label: location
    for run_label, run_dir in MODEL_DIRS.items()
    for label, location in discover_checkpoints(run_dir, run_label).items()
}

# Nothing to load: report it in the page and halt this script run.
if len(CHECKPOINTS) == 0:
    st.error("No checkpoints found.")
    st.stop()
@st.cache_resource(show_spinner=False)
def _shared_model_slot():
    """Create the single-slot model cache once per server process."""
    return {"path": None, "model": None, "tokenizer": None}


# Streamlit re-executes this script on every interaction, so a plain
# module-level dict would be rebuilt (and the model reloaded) on each rerun.
# Fetching the dict through st.cache_resource keeps one instance alive
# across reruns. NOTE(review): the slot is shared by all sessions.
_current_model = _shared_model_slot()


def load_model(checkpoint_path):
    """Return ``(tokenizer, model)`` for *checkpoint_path*, loading on demand.

    Only one checkpoint is kept in memory. Asking for the checkpoint that is
    already loaded returns the cached objects; asking for a different one
    releases the old model first and then loads the new one.

    Args:
        checkpoint_path: Directory passed to ``from_pretrained``.

    Returns:
        The tokenizer and the model. The model is in eval mode and is moved
        to CUDA when ``torch.cuda.is_available()``.

    Raises:
        Whatever ``from_pretrained`` raises. The cache is left empty but
        well-formed in that case, so a later call can retry.
    """
    already_loaded = (
        _current_model["path"] == checkpoint_path
        and _current_model["model"] is not None
    )
    if already_loaded:
        return _current_model["tokenizer"], _current_model["model"]

    # Release the previous model before loading, so both never need to fit in
    # memory together. Entries are reset to None (not deleted) and the path is
    # cleared up front, so a failed load cannot leave a half-populated slot.
    had_model = _current_model["model"] is not None
    _current_model["path"] = None
    _current_model["model"] = None
    _current_model["tokenizer"] = None
    if had_model and torch.cuda.is_available():
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    _current_model["model"] = model
    _current_model["tokenizer"] = tokenizer
    # Publish the path last: it is what marks the slot as valid.
    _current_model["path"] = checkpoint_path
    return tokenizer, model
def strip_md(text):
    """Remove inline Markdown markup from *text*, keeping the visible words.

    The rules run in this order: ``[label](target)`` links collapse to
    ``label``, then ``**bold**`` and finally ``*italic*`` markers are removed
    around their content.
    """
    unwrap_rules = (
        r'\[([^\]]*)\]\([^)]*\)',  # link -> its label
        r'\*\*([^*]*)\*\*',        # bold -> its content
        r'\*([^*]*)\*',            # italic -> its content
    )
    stripped = text
    for rule in unwrap_rules:
        stripped = re.sub(rule, r'\1', stripped)
    return stripped
def build_clean_to_original_map(original, cleaned):
    """Map every index of *cleaned* to an index of *original*.

    Alignment is greedy and left to right: each cleaned character is paired
    with its first occurrence in *original* at or after the position that
    follows the previous match. A character with no remaining occurrence gets
    an index at or past ``len(original)``, which callers must range-check.

    Returns:
        A list with one original-text index per character of *cleaned*.
    """
    positions = []
    cursor = 0
    limit = len(original)
    for char in cleaned:
        hit = original.find(char, cursor)
        if hit < 0:
            # No match left: land on the end, or stay past it if already there.
            hit = max(cursor, limit)
        positions.append(hit)
        cursor = hit + 1
    return positions
def _expand_labels_to_words(text, char_labels, char_probs):
    """Grow labels in place so a partly labelled word becomes fully labelled."""
    pos, size = 0, len(text)
    while pos < size:
        if text[pos].isspace():
            pos += 1
            continue
        # A "word" is a maximal run of non-whitespace characters.
        word_start = pos
        while pos < size and not text[pos].isspace():
            pos += 1
        if any(char_labels[word_start:pos]):
            top = max(char_probs[word_start:pos])
            for c in range(word_start, pos):
                char_labels[c] = 1
                char_probs[c] = top


def _build_segments(text, char_labels, char_probs):
    """Split *text* into maximal runs of identically labelled characters."""
    segments = []
    size = len(text)
    run_start = 0
    for end in range(1, size + 1):
        if end < size and char_labels[end] == char_labels[run_start]:
            continue
        is_span = char_labels[run_start] == 1
        if is_span:
            conf = sum(char_probs[run_start:end]) / (end - run_start)
        else:
            conf = 0.0
        segments.append((text[run_start:end], is_span, conf))
        run_start = end
    return segments


def predict_spans(tokenizer, model, title, text, threshold=0.5):
    """Find the parts of *text* that the model tags as span content.

    Markdown is stripped before tokenisation, the model scores
    ``[CLS] title [SEP] text [SEP]``, and the per-token span probability
    (B-SPAN + I-SPAN) is projected back onto the characters of the original
    text. Any word with at least one selected character is selected whole.

    Args:
        tokenizer: Tokenizer for the checkpoint; it must support
            ``return_offsets_mapping``.
        model: Token-classification model whose logits follow ``LABEL2ID``.
        title: Title used as the first segment of the input.
        text: Text to annotate (may contain inline Markdown).
        threshold: Minimum span probability for a token to be selected.

    Returns:
        ``(segments, all_char_probs)``. ``segments`` is a list of
        ``(chunk, is_span, confidence)`` tuples whose chunks concatenate back
        to *text*; ``confidence`` is the mean character probability of a span
        chunk and ``0.0`` otherwise. ``all_char_probs`` holds, for each
        character of *text*, the highest span probability of any token
        covering it (``0.0`` where no token maps). Both are empty for empty
        *text*.
    """
    # Same tuple shape as every other path, and no forward pass for no input.
    if not text:
        return [], []

    device = next(model.parameters()).device

    # Strip markdown for model input, keep original for display.
    clean_text = strip_md(text)
    title_ids = tokenizer(title, add_special_tokens=False)["input_ids"]
    text_enc = tokenizer(clean_text, add_special_tokens=False, return_offsets_mapping=True)

    max_len = tokenizer.model_max_length
    if max_len > 10000:
        # Very large values are treated as "no real limit"; fall back to 512.
        max_len = 512

    # Trim the pieces, not the assembled sequence: this keeps the closing
    # [SEP] in place and stops the last text token being mistaken for it.
    title_ids = title_ids[: max(0, max_len - 3)]
    text_budget = max(0, max_len - 3 - len(title_ids))
    text_ids = text_enc["input_ids"][:text_budget]
    text_offsets = text_enc["offset_mapping"][:text_budget]

    # Build input: [CLS] title [SEP] text [SEP]
    input_ids = (
        [tokenizer.cls_token_id]
        + title_ids
        + [tokenizer.sep_token_id]
        + text_ids
        + [tokenizer.sep_token_id]
    )
    text_start = len(title_ids) + 2  # CLS + title + SEP
    inputs = {
        "input_ids": torch.tensor([input_ids], device=device),
        "attention_mask": torch.tensor([[1] * len(input_ids)], device=device),
    }
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (seq_len, num_labels)
    probs = torch.softmax(logits, dim=-1)
    # One transfer for the whole sequence instead of one .item() per token.
    span_probs = (probs[:, LABEL2ID["B-SPAN"]] + probs[:, LABEL2ID["I-SPAN"]]).tolist()
    token_probs = span_probs[text_start:text_start + len(text_ids)]

    # Map token probabilities from the cleaned text back to the original.
    clean_to_orig = build_clean_to_original_map(text, clean_text)
    n_chars = len(text)
    char_labels = [0] * n_chars
    char_probs = [0.0] * n_chars
    all_char_probs = [0.0] * n_chars
    for (clean_start, clean_end), span_prob in zip(text_offsets, token_probs):
        selected = span_prob >= threshold
        for cc in range(clean_start, min(clean_end, len(clean_to_orig))):
            oc = clean_to_orig[cc]
            if oc >= n_chars:
                # Unmatched cleaned character; nothing to project onto.
                continue
            all_char_probs[oc] = max(all_char_probs[oc], span_prob)
            if selected:
                char_labels[oc] = 1
                char_probs[oc] = max(char_probs[oc], span_prob)

    # Expand labelled chars to cover full words (fixes subword splits).
    _expand_labels_to_words(text, char_labels, char_probs)

    return _build_segments(text, char_labels, char_probs), all_char_probs
def _render_segments_html(segments, all_char_probs):
    """Render *segments* as one HTML block with per-word probability tooltips.

    Span segments get a green background. Every whitespace-delimited chunk,
    whitespace runs included, carries a tooltip with the mean of
    *all_char_probs* over its characters.
    """
    html_parts = []
    pos = 0  # offset of the current chunk within the full text
    for seg, is_span, _conf in segments:
        # The capturing group keeps the whitespace runs, so offsets stay aligned.
        for word in re.split(r'(\s+)', seg):
            if not word:
                continue
            escaped = word.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>")
            word_probs = all_char_probs[pos:pos + len(word)]
            avg_prob = sum(word_probs) / max(1, len(word_probs))
            tooltip = f"{avg_prob:.2f}"
            if is_span:
                html_parts.append(f'<span title="{tooltip}" style="background-color: #22c55e; color: white; padding: 1px 3px; border-radius: 3px; cursor: help;">{escaped}</span>')
            else:
                html_parts.append(f'<span title="{tooltip}" style="cursor: help;">{escaped}</span>')
            pos += len(word)
    return f'<div style="font-size: 16px; line-height: 1.8; font-family: Georgia, serif;">{"".join(html_parts)}</div>'


st.set_page_config(page_title="Span Extractor", layout="wide")
st.title("Span Extractor Inference")

# The last entry is preselected in the checkpoint picker.
checkpoint_names = list(CHECKPOINTS.keys())
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])

threshold = st.slider("Span confidence threshold", 0.0, 1.0, 0.5, 0.05)
title = st.text_input("Title", placeholder="Enter article title...")
text = st.text_area("Text", height=300, placeholder="Enter article text...")

if st.button("Extract Spans"):
    if not (title and text):
        # Tell the user why nothing happened instead of ignoring the click.
        st.warning("Enter both a title and text before extracting spans.")
    else:
        segments, all_char_probs = predict_spans(tokenizer, model, title, text, threshold)
        span_rows = [
            {"span": seg.strip(), "confidence": conf}
            for seg, is_span, conf in segments
            if is_span
        ]
        if not span_rows:
            st.warning("No spans predicted.")
        else:
            st.caption(f"{len(span_rows)} span(s) detected")

            # Green background for spans, tooltips on all words.
            st.markdown(_render_segments_html(segments, all_char_probs), unsafe_allow_html=True)

            # Show extracted spans as dataframe.
            st.divider()
            st.subheader("Extracted Spans")
            import pandas as pd

            df = pd.DataFrame(span_rows)
            st.dataframe(
                df,
                use_container_width=True,
                hide_index=True,
                column_config={"confidence": st.column_config.ProgressColumn(min_value=0, max_value=1, format="%.2f")},
            )