| import streamlit as st |
| import torch |
| import re |
| import os |
| from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
|
# BIO tag set used by the token classifier; a tag's position in this tuple
# is its class id.
_LABELS = ("O", "B-SPAN", "I-SPAN")

# Tag name -> class id, and the inverse lookup.
LABEL2ID = {label: idx for idx, label in enumerate(_LABELS)}
ID2LABEL = dict(enumerate(_LABELS))
|
|
| import glob |
|
|
# Display prefix -> training output directory that is scanned for checkpoints.
MODEL_DIRS = dict(
    CE="./span_model_ce",
    Focal="./span_model_focal",
)
|
|
def discover_checkpoints(model_dir, prefix):
    """Collect the checkpoints stored under one training output directory.

    Args:
        model_dir: Directory that holds ``checkpoint-*`` entries and,
            optionally, a ``final`` entry.
        prefix: Label prepended to every display name.

    Returns:
        A dict mapping ``"<prefix> / <entry name>"`` to the entry's path.
        Entries with a numeric suffix come first, ordered by that number;
        entries with any other suffix follow, ordered by suffix; ``final``
        is always last when it exists.
    """

    def sort_key(path):
        suffix = os.path.basename(path).rpartition("-")[2]
        # A non-numeric suffix (e.g. "checkpoint-best") must not abort
        # discovery, so it is ranked after the numbered steps instead of
        # being passed to int().
        if suffix.isdecimal():
            return (0, int(suffix), "")
        return (1, 0, suffix)

    # Escape the directory so glob metacharacters in its name are literal.
    pattern = os.path.join(glob.escape(model_dir), "checkpoint-*")

    found = {}
    for path in sorted(glob.glob(pattern), key=sort_key):
        # basename() handles whichever separator the platform's glob emits.
        found[f"{prefix} / {os.path.basename(path)}"] = path

    final_path = os.path.join(model_dir, "final")
    if os.path.exists(final_path):
        found[f"{prefix} / final"] = final_path
    return found
|
|
# Flatten every model directory's checkpoints into one display-name -> path
# table; a name produced by a later directory overrides an earlier one.
CHECKPOINTS = {
    name: path
    for prefix, model_dir in MODEL_DIRS.items()
    for name, path in discover_checkpoints(model_dir, prefix).items()
}

# With nothing to load, report the problem and halt this script run.
if not CHECKPOINTS:
    st.error("No checkpoints found.")
    st.stop()
|
|
|
|
def _new_model_slot():
    """Create the empty single-entry record that holds the loaded model."""
    return {"path": None, "model": None, "tokenizer": None}


def _shared_model_slot():
    """Return the model slot that survives Streamlit script reruns.

    Streamlit re-executes this script on every widget interaction, so a
    plain module-level dict would be rebuilt (and the model reloaded) each
    time. ``st.cache_resource`` hands back the same dict object on every
    run. The wrapper is applied at call time so that merely importing this
    module issues no Streamlit call.
    """
    return st.cache_resource(show_spinner=False)(_new_model_slot)()


def load_model(checkpoint_path):
    """Load the tokenizer and model for a checkpoint, keeping one in memory.

    Args:
        checkpoint_path: Directory accepted by ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``. The model is in eval mode and is moved to
        CUDA when ``torch.cuda.is_available()`` is true.

    Raises:
        Whatever ``from_pretrained`` raises for an unreadable checkpoint;
        in that case the slot is left empty rather than half-populated.
    """
    slot = _shared_model_slot()
    if slot["path"] == checkpoint_path and slot["model"] is not None:
        return slot["tokenizer"], slot["model"]

    if slot["model"] is not None:
        # Drop the previous model before loading the next one so both are
        # not resident at once. Fields are reset to None (not deleted) so a
        # failed load below cannot leave the slot with missing keys.
        slot["path"] = None
        slot["model"] = None
        slot["tokenizer"] = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    slot["path"] = checkpoint_path
    slot["model"] = model
    slot["tokenizer"] = tokenizer
    return tokenizer, model
|
|
|
|
def strip_md(text):
    """Remove basic Markdown markup, keeping only the visible text.

    Applied in order: ``[label](target)`` becomes ``label``, ``**bold**``
    becomes ``bold``, then ``*emphasis*`` becomes ``emphasis``.
    """
    unwrap_patterns = (
        r'\[([^\]]*)\]\([^)]*\)',
        r'\*\*([^*]*)\*\*',
        r'\*([^*]*)\*',
    )
    for pattern in unwrap_patterns:
        text = re.sub(pattern, r'\1', text)
    return text
|
|
|
|
def build_clean_to_original_map(original, cleaned):
    """Map each index of ``cleaned`` to an index of ``original``.

    ``original`` is scanned left to right; every character of ``cleaned``
    is paired with the next unused position in ``original`` holding the
    same character. When no such position remains, the recorded index is
    ``len(original)`` or greater, so callers must bounds-check it.
    """
    positions = []
    cursor = 0
    for ch in cleaned:
        hit = original.find(ch, cursor)
        if hit == -1:
            # Nothing left to match: behave like a scan that ran off the end.
            hit = max(cursor, len(original))
        positions.append(hit)
        cursor = hit + 1
    return positions
|
|
|
|
def _snap_labels_to_words(text, char_labels, char_probs):
    """Extend labels in place so a partly labelled word becomes fully labelled."""
    n = len(text)
    i = 0
    while i < n:
        if text[i].isspace():
            i += 1
            continue
        word_start = i
        while i < n and not text[i].isspace():
            i += 1
        word = range(word_start, i)
        if any(char_labels[c] for c in word):
            # The whole word takes the highest probability seen inside it.
            peak = max(char_probs[c] for c in word)
            for c in word:
                char_labels[c] = 1
                char_probs[c] = peak


def _group_segments(text, char_labels, char_probs):
    """Split non-empty ``text`` into runs of equal label with a mean confidence."""
    segments = []
    start = 0
    for end in range(1, len(text) + 1):
        if end < len(text) and char_labels[end] == char_labels[start]:
            continue
        is_span = char_labels[start] == 1
        conf = sum(char_probs[start:end]) / (end - start) if is_span else 0.0
        segments.append((text[start:end], is_span, conf))
        start = end
    return segments


def predict_spans(tokenizer, model, title, text, threshold=0.5):
    """Run the span model on ``title`` + ``text`` and return render data.

    The text is stripped of Markdown before tokenization; predictions are
    projected back onto the characters of the original ``text``.

    Args:
        tokenizer: Tokenizer called with ``return_offsets_mapping=True``.
        model: Token-classification model whose output has ``.logits``.
        title: Title string, encoded as the first sequence.
        text: Body string to label.
        threshold: Minimum B-SPAN + I-SPAN probability for a token to be
            marked as part of a span.

    Returns:
        ``(segments, all_char_probs)``. ``segments`` is a list of
        ``(segment_text, is_span, confidence)`` tuples that concatenate
        back to ``text``; ``confidence`` is the mean character probability
        for span segments and ``0.0`` otherwise. ``all_char_probs`` holds,
        for each character of ``text``, the highest span probability of
        any token mapped onto it (``0.0`` where none was). Both are empty
        lists when ``text`` is empty.
    """
    if not text:
        # Same two-value shape as the normal path, without a model call.
        return [], []

    device = next(model.parameters()).device
    clean_text = strip_md(text)

    title_ids = tokenizer(title, add_special_tokens=False)["input_ids"]
    text_enc = tokenizer(clean_text, add_special_tokens=False, return_offsets_mapping=True)
    text_ids = text_enc["input_ids"]
    text_offsets = text_enc["offset_mapping"]

    max_len = tokenizer.model_max_length
    if max_len > 10000:
        # Implausibly large limit: fall back to 512 tokens.
        max_len = 512

    # Trim the content rather than the assembled sequence, so the closing
    # [SEP] is never cut off. Three slots are reserved for [CLS] and the
    # two [SEP] tokens.
    title_ids = title_ids[:max(0, max_len - 3)]
    text_budget = max(0, max_len - 3 - len(title_ids))
    text_ids = text_ids[:text_budget]
    text_offsets = text_offsets[:text_budget]

    input_ids = (
        [tokenizer.cls_token_id]
        + title_ids
        + [tokenizer.sep_token_id]
        + text_ids
        + [tokenizer.sep_token_id]
    )
    attention_mask = [1] * len(input_ids)
    text_start = len(title_ids) + 2  # skip [CLS], the title and the first [SEP]

    inputs = {
        "input_ids": torch.tensor([input_ids], device=device),
        "attention_mask": torch.tensor([attention_mask], device=device),
    }
    with torch.no_grad():
        logits = model(**inputs).logits[0]
        probs = torch.softmax(logits, dim=-1)
        # One transfer for all tokens instead of an .item() call per token.
        span_probs = (probs[:, LABEL2ID["B-SPAN"]] + probs[:, LABEL2ID["I-SPAN"]]).tolist()

    clean_to_orig = build_clean_to_original_map(text, clean_text)

    n_chars = len(text)
    char_labels = [0] * n_chars
    char_probs = [0.0] * n_chars
    all_char_probs = [0.0] * n_chars

    text_span_probs = span_probs[text_start:text_start + len(text_ids)]
    for (clean_start, clean_end), span_prob in zip(text_offsets, text_span_probs):
        above_threshold = span_prob >= threshold
        for cc in range(clean_start, min(clean_end, len(clean_to_orig))):
            oc = clean_to_orig[cc]
            if oc >= n_chars:
                # The mapping ran past the original text; nothing to mark.
                continue
            all_char_probs[oc] = max(all_char_probs[oc], span_prob)
            if above_threshold:
                char_labels[oc] = 1
                char_probs[oc] = max(char_probs[oc], span_prob)

    _snap_labels_to_words(text, char_labels, char_probs)
    return _group_segments(text, char_labels, char_probs), all_char_probs
|
|
|
|
# ---- Streamlit page: pick a checkpoint, enter a title/text, show spans ----
st.set_page_config(page_title="Span Extractor", layout="wide")
st.title("Span Extractor Inference")

checkpoint_names = list(CHECKPOINTS.keys())
# Preselect the last discovered entry.
checkpoint = st.selectbox("Checkpoint", checkpoint_names, index=len(checkpoint_names) - 1)
tokenizer, model = load_model(CHECKPOINTS[checkpoint])

threshold = st.slider("Span confidence threshold", 0.0, 1.0, 0.5, 0.05)

title = st.text_input("Title", placeholder="Enter article title...")
text = st.text_area("Text", height=300, placeholder="Enter article text...")

if st.button("Extract Spans") and title and text:
    from html import escape as escape_html

    segments, all_char_probs = predict_spans(tokenizer, model, title, text, threshold)
    span_rows = [{"span": seg.strip(), "confidence": conf} for seg, is_span, conf in segments if is_span]

    if not span_rows:
        st.warning("No spans predicted.")
    else:
        st.caption(f"{len(span_rows)} span(s) detected")

    # Render the text word by word; each piece carries its mean span
    # probability as a hover tooltip, and span pieces are highlighted.
    # NOTE(review): the source's indentation was lost, so whether this
    # rendering originally ran in the no-span case is unknown; it is shown
    # always here so the tooltips remain available.
    html_parts = []
    pos = 0  # offset of the current piece within `text`
    for seg, is_span, _conf in segments:
        # The capturing group keeps whitespace runs as their own pieces.
        for word in re.split(r'(\s+)', seg):
            if not word:
                continue
            # The text is user input rendered with unsafe_allow_html=True,
            # so it must be HTML-escaped before being embedded.
            escaped = escape_html(word).replace("\n", "<br>")

            word_probs = all_char_probs[pos:pos + len(word)]
            avg_prob = sum(word_probs) / max(1, len(word_probs))
            tooltip = f"{avg_prob:.2f}"
            if is_span:
                html_parts.append(f'<span title="{tooltip}" style="background-color: #22c55e; color: white; padding: 1px 3px; border-radius: 3px; cursor: help;">{escaped}</span>')
            else:
                html_parts.append(f'<span title="{tooltip}" style="cursor: help;">{escaped}</span>')
            pos += len(word)

    rendered = f'<div style="font-size: 16px; line-height: 1.8; font-family: Georgia, serif;">{"".join(html_parts)}</div>'
    st.markdown(rendered, unsafe_allow_html=True)

    if span_rows:
        import pandas as pd

        st.divider()
        st.subheader("Extracted Spans")
        df = pd.DataFrame(span_rows)
        st.dataframe(
            df,
            use_container_width=True,
            hide_index=True,
            column_config={"confidence": st.column_config.ProgressColumn(min_value=0, max_value=1, format="%.2f")},
        )
|
|