"""Streamlit app: highlight model-predicted spans in an article.

Run with ``streamlit run <this file>``.  The app discovers fine-tuned
token-classification checkpoints on disk, lets the user pick one, runs it on
``[CLS] title [SEP] text [SEP]`` and renders the text with predicted spans
highlighted plus a per-word probability tooltip.
"""

import glob
import html
import os
import re
from itertools import groupby

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

LABEL2ID = {"O": 0, "B-SPAN": 1, "I-SPAN": 2}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}

# Display prefix -> directory that holds ``checkpoint-<step>`` / ``final``.
MODEL_DIRS = {
    "CE": "./span_model_ce",
    "Focal": "./span_model_focal",
}

# Tokenizers with no configured limit report a huge sentinel as
# ``model_max_length``; anything above this is treated as "unset".
_UNSET_MAX_LEN_THRESHOLD = 10000
_FALLBACK_MAX_LEN = 512

_WORD_RE = re.compile(r"\S+")
_WORD_OR_SPACE_RE = re.compile(r"(\s+)")


def discover_checkpoints(model_dir, prefix):
    """Return ``{display_name: path}`` for the checkpoints under *model_dir*.

    ``checkpoint-<step>`` directories are listed in ascending step order,
    followed by ``final`` when it exists.  Entries whose suffix is not a
    number (e.g. ``checkpoint-best``) are skipped instead of raising.
    """
    numbered = []
    for path in glob.glob(f"{model_dir}/checkpoint-*"):
        # basename copes with both "/" and "\\" separators returned by glob.
        step = os.path.basename(path).rpartition("-")[2]
        if step.isdecimal():
            numbered.append((int(step), path))

    found = {}
    for _, path in sorted(numbered):
        found[f"{prefix} / {os.path.basename(path)}"] = path

    final_path = f"{model_dir}/final"
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found


# Must run before any other Streamlit command emits an element.
st.set_page_config(page_title="Span Extractor", layout="wide")

CHECKPOINTS = {}
for prefix, model_dir in MODEL_DIRS.items():
    CHECKPOINTS.update(discover_checkpoints(model_dir, prefix))

if not CHECKPOINTS:
    st.error("No checkpoints found.")
    st.stop()


@st.cache_resource(show_spinner=False)
def _model_slot():
    """Return the single mutable slot that holds the currently loaded model.

    Streamlit re-executes this script on every interaction, so a plain
    module-level dict would be rebuilt (and the model reloaded) each time.
    ``st.cache_resource`` hands back the same dict object on every rerun.
    NOTE: the slot is shared by all sessions of the server process.
    """
    return {"path": None, "model": None, "tokenizer": None}


_current_model = _model_slot()


def load_model(checkpoint_path):
    """Load (or reuse) the tokenizer and model stored at *checkpoint_path*.

    Only one model is kept in memory.  The previous one is released before
    the new one is loaded so both never have to fit on the device at once.

    Returns:
        ``(tokenizer, model)``; the model is in eval mode and on CUDA when
        available.
    """
    cached = _current_model
    if cached["path"] == checkpoint_path and cached["model"] is not None:
        return cached["tokenizer"], cached["model"]

    # Drop the old model first.  Resetting to None (rather than deleting the
    # keys) keeps the slot consistent if the load below raises.
    cached.update(path=None, model=None, tokenizer=None)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    cached.update(path=checkpoint_path, model=model, tokenizer=tokenizer)
    return tokenizer, model


def strip_md(text):
    """Remove markdown links, bold and italic markers, keeping their text."""
    text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
    text = re.sub(r'\*\*([^*]*)\*\*', r'\1', text)
    text = re.sub(r'\*([^*]*)\*', r'\1', text)
    return text


def build_clean_to_original_map(original, cleaned):
    """Map each index of *cleaned* to an index in *original*.

    Uses greedy forward matching: every cleaned character is paired with the
    next equal character in *original*.  This is exact only as far as the
    greedy choice matches the real deletion pattern, so treat the result as a
    best-effort alignment.  A character that cannot be found maps to
    ``len(original)``; callers must bounds-check.
    """
    mapping = []
    cursor = 0
    limit = len(original)
    for ch in cleaned:
        while cursor < limit and original[cursor] != ch:
            cursor += 1
        mapping.append(cursor)
        cursor += 1
    return mapping


def predict_spans(tokenizer, model, title, text, threshold=0.5):
    """Run the model on ``title``/``text`` and label characters of *text*.

    Args:
        tokenizer: fast tokenizer (offset mapping is required).
        model: token-classification model with the ``LABEL2ID`` label set.
        title: article title, placed before the first ``[SEP]``.
        text: article text; markdown is stripped for the model but all
            returned offsets/segments refer to the original string.
        threshold: minimum ``P(B-SPAN) + P(I-SPAN)`` for a token to count as
            part of a span.

    Returns:
        ``(segments, all_char_probs)`` where ``segments`` is a list of
        ``(substring, is_span, confidence)`` tuples that concatenate back to
        *text*, and ``all_char_probs`` holds the span probability for every
        character of *text* (0.0 where no token covered it).  Empty *text*
        yields ``([], [])``.
    """
    if not text:
        return [], []

    device = next(model.parameters()).device
    clean_text = strip_md(text)

    title_ids = tokenizer(title, add_special_tokens=False)["input_ids"]
    text_enc = tokenizer(
        clean_text, add_special_tokens=False, return_offsets_mapping=True
    )
    text_ids = text_enc["input_ids"]
    text_offsets = text_enc["offset_mapping"]

    max_len = tokenizer.model_max_length
    if max_len > _UNSET_MAX_LEN_THRESHOLD:
        max_len = _FALLBACK_MAX_LEN

    # Truncate the pieces *before* assembling so the sequence always ends in
    # [SEP]; slicing the assembled ids would cut the separator off and make
    # the last real text token look like it.
    title_ids = title_ids[: max(0, max_len - 3)]
    text_budget = max(0, max_len - 3 - len(title_ids))
    text_ids = text_ids[:text_budget]
    text_offsets = text_offsets[:text_budget]

    # [CLS] title [SEP] text [SEP]
    input_ids = (
        [tokenizer.cls_token_id]
        + title_ids
        + [tokenizer.sep_token_id]
        + text_ids
        + [tokenizer.sep_token_id]
    )
    text_start = len(title_ids) + 2  # skip CLS + title + SEP

    inputs = {
        "input_ids": torch.tensor([input_ids], device=device),
        "attention_mask": torch.ones((1, len(input_ids)), dtype=torch.long, device=device),
    }
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (seq_len, num_labels)
    probs = torch.softmax(logits, dim=-1)
    token_span_probs = (
        probs[:, LABEL2ID["B-SPAN"]] + probs[:, LABEL2ID["I-SPAN"]]
    )[text_start:text_start + len(text_ids)].tolist()

    # Project token probabilities from cleaned-text offsets onto the
    # original text.
    clean_to_orig = build_clean_to_original_map(text, clean_text)
    n_chars = len(text)
    char_labels = [0] * n_chars
    char_probs = [0.0] * n_chars
    all_char_probs = [0.0] * n_chars

    for (clean_start, clean_end), span_prob in zip(text_offsets, token_span_probs):
        above_threshold = span_prob >= threshold
        for cc in range(clean_start, min(clean_end, len(clean_to_orig))):
            oc = clean_to_orig[cc]
            if oc >= n_chars:  # unmatched character in the alignment
                continue
            all_char_probs[oc] = max(all_char_probs[oc], span_prob)
            if above_threshold:
                char_labels[oc] = 1
                char_probs[oc] = max(char_probs[oc], span_prob)

    # Grow labels to whole whitespace-delimited words so a span never starts
    # or ends in the middle of a word because of sub-word tokenization.
    for match in _WORD_RE.finditer(text):
        start, end = match.span()
        if any(char_labels[start:end]):
            word_max = max(char_probs[start:end])
            char_labels[start:end] = [1] * (end - start)
            char_probs[start:end] = [word_max] * (end - start)

    # Collapse runs of equal labels into segments; a span's confidence is
    # the mean probability of its characters.
    segments = []
    start = 0
    for label, run in groupby(char_labels):
        end = start + sum(1 for _ in run)
        is_span = label == 1
        conf = sum(char_probs[start:end]) / (end - start) if is_span else 0.0
        segments.append((text[start:end], is_span, conf))
        start = end
    return segments, all_char_probs


st.title("Span Extractor Inference")

checkpoint_names = list(CHECKPOINTS.keys())
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])

threshold = st.slider("Span confidence threshold", 0.0, 1.0, 0.5, 0.05)
title = st.text_input("Title", placeholder="Enter article title...")
text = st.text_area("Text", height=300, placeholder="Enter article text...")

clicked = st.button("Extract Spans")
if clicked and not (title and text):
    st.warning("Please enter both a title and text.")
elif clicked:
    segments, all_char_probs = predict_spans(tokenizer, model, title, text, threshold)

    span_count = sum(1 for _, is_span, _ in segments if is_span)
    if span_count == 0:
        st.warning("No spans predicted.")
    else:
        st.caption(f"{span_count} span(s) detected")

    # Render the text with a highlight on spans and a probability tooltip on
    # every word.  Segments concatenate back to `text`, so `pos` is always
    # the offset of the current word inside `text`.
    html_parts = []
    pos = 0
    for seg, is_span, _conf in segments:
        for word in _WORD_OR_SPACE_RE.split(seg):
            if not word:
                continue
            # The text is user input rendered with unsafe_allow_html, so it
            # must be escaped; newlines become <br> to survive markdown.
            escaped = html.escape(word).replace("\n", "<br>")
            word_probs = all_char_probs[pos:pos + len(word)]
            avg_prob = sum(word_probs) / max(1, len(word_probs))
            tooltip = f"{avg_prob:.2f}"
            if is_span:
                html_parts.append(
                    '<span style="background-color: rgba(46, 204, 113, 0.35);" '
                    f'title="{tooltip}">{escaped}</span>'
                )
            else:
                html_parts.append(f'<span title="{tooltip}">{escaped}</span>')
            pos += len(word)

    rendered = f'<div style="line-height: 1.8;">{"".join(html_parts)}</div>'
    st.markdown(rendered, unsafe_allow_html=True)

    if span_count:
        st.divider()
        st.subheader("Extracted Spans")
        span_data = [
            {"span": seg.strip(), "confidence": conf}
            for seg, is_span, conf in segments
            if is_span
        ]
        st.dataframe(
            pd.DataFrame(span_data),
            use_container_width=True,
            hide_index=True,
            column_config={
                "confidence": st.column_config.ProgressColumn(
                    min_value=0, max_value=1, format="%.2f"
                )
            },
        )