# ritammehta's picture
# Upload app.py with huggingface_hub
# 2855bad verified
# v1.7 - Substring-level marker highlighting (token classifier only)
"""
Havelock.AI - Token Span API
Runs the trained MultiLabelTokenClassifier (HavelockAI/bert-token-classifier)
and returns character-level span predictions for Tier 1 markers (F1 >= 0.50).
Sentence-level scoring comes from the production Space (thestalwart/havelock-demo).
"""
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
import json
# Device: use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model repository on the HuggingFace Hub (loaded with trust_remote_code).
TOKEN_MODEL_REPO = "HavelockAI/bert-token-classifier"

# Tier 1 markers: the marker types whose F1 >= 0.50 per the model's manifest.
# predict_spans() filters its output to this set by default; names prefixed
# "oral_" vs "literate_" determine the span's "category" field.
TIER1_MARKERS = {
    "oral_vocative",
    "literate_technical_abbreviation",
    "oral_phatic_check",
    "oral_imperative",
    "oral_specific_place",
    "literate_citation",
    "literate_agentless_passive",
    "oral_rhetorical_question",
    "oral_inclusive_we",
    "oral_second_person",
    "oral_named_individual",
    "literate_nominalization",
    "literate_probability",
}
def load_token_classifier():
    """Download and initialize the token-level classifier from the Hub.

    Loads the tokenizer, the model (custom architecture via
    trust_remote_code), and the marker-type index map, logging progress
    at each step so Space startup is observable.

    Returns:
        Tuple of (tokenizer, model, idx_to_type) where idx_to_type maps
        a class index (int) back to its marker-type name (str).
    """
    import sys

    print("Loading token classifier...", flush=True)

    print("Step 1: Loading tokenizer...", flush=True)
    tok = AutoTokenizer.from_pretrained(TOKEN_MODEL_REPO)
    print(f" Tokenizer loaded: {type(tok).__name__}", flush=True)

    print("Step 2: Loading model...", flush=True)
    sys.stdout.flush()
    clf = AutoModel.from_pretrained(
        TOKEN_MODEL_REPO,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    )
    print(f" Model loaded: {type(clf).__name__}", flush=True)

    print("Step 3: Moving to device...", flush=True)
    clf.to(DEVICE)
    clf.eval()  # inference only — disable dropout etc.
    print(f" Model on {DEVICE}", flush=True)

    print("Step 4: Loading type map...", flush=True)
    mapping_path = hf_hub_download(TOKEN_MODEL_REPO, "type_to_idx.json")
    with open(mapping_path) as fh:
        name_to_idx = json.load(fh)
    # Invert the mapping: class index -> marker-type name.
    idx_to_name = dict(zip(name_to_idx.values(), name_to_idx.keys()))

    print(f"Token classifier loaded! ({len(name_to_idx)} marker types)", flush=True)
    return tok, clf, idx_to_name
# Load model at startup (module import time) so the Gradio handlers can use
# these globals without per-request loading. NOTE(review): this downloads
# from the Hub during import — startup will fail without network access.
tokenizer, model, idx_to_type = load_token_classifier()
def _emit_span(spans, text, offset_mapping, start_tok, end_tok, marker_name):
"""Convert token indices to a character-level span dict."""
char_start = int(offset_mapping[start_tok][0])
char_end = int(offset_mapping[end_tok - 1][1])
if char_end > char_start:
category = "oral" if marker_name.startswith("oral_") else "literate"
spans.append({
"text": text[char_start:char_end],
"marker": marker_name,
"category": category,
"start": char_start,
"end": char_end,
})
def predict_spans(text, tier1_only=True):
    """Run token classifier and return character-level spans.

    Decodes per-type BIO tags from the model output into character spans
    via the tokenizer's offset mapping.

    Args:
        text: Input text (single sentence or short passage)
        tier1_only: If True, only return Tier 1 markers (F1 >= 0.50)

    Returns:
        List of span dicts sorted by start position.
    """
    eligible = TIER1_MARKERS if tier1_only else None
    # Offsets are needed to map token predictions back to character positions.
    # Input is truncated to 128 tokens, so long passages lose their tail.
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_offsets_mapping=True,
    )
    # Pop offsets before moving tensors — the model does not accept them.
    offset_mapping = encoding.pop("offset_mapping")[0]
    input_ids = encoding["input_ids"].to(DEVICE)
    attention_mask = encoding["attention_mask"].to(DEVICE)
    with torch.no_grad():
        # Prefer the custom model's own decode() when present (remote code);
        # otherwise take the argmax of raw logits.
        # NOTE(review): assumes the fallback forward() returns a logits
        # tensor directly, not a ModelOutput — confirm against the repo.
        if hasattr(model, "decode"):
            preds = model.decode(input_ids, attention_mask)
        else:
            logits = model(input_ids, attention_mask)
            preds = logits.argmax(dim=-1)
    # preds shape: (1, seq_len, num_types) where values are 0=O, 1=B, 2=I
    preds = preds[0]
    # Only walk real tokens; padding beyond attention_mask is ignored.
    seq_len = attention_mask.sum().item()
    spans = []
    # Decode BIO tags independently per marker type.
    for type_idx in range(preds.shape[1]):
        marker_name = idx_to_type.get(type_idx)
        if marker_name is None:
            continue
        if eligible is not None and marker_name not in eligible:
            continue
        span_start_tok = None  # token index where the current open span began
        for tok_pos in range(seq_len):
            tag = preds[tok_pos, type_idx].item()
            offsets = offset_mapping[tok_pos].tolist()
            # Skip special tokens (offset 0,0); they also terminate any open
            # span, since a span cannot straddle [CLS]/[SEP].
            if offsets[0] == 0 and offsets[1] == 0 and tok_pos > 0:
                if span_start_tok is not None:
                    _emit_span(spans, text, offset_mapping, span_start_tok, tok_pos, marker_name)
                    span_start_tok = None
                continue
            if tag == 1:  # B: close any open span, then start a new one here
                if span_start_tok is not None:
                    _emit_span(spans, text, offset_mapping, span_start_tok, tok_pos, marker_name)
                span_start_tok = tok_pos
            elif tag == 2:  # I: continue, or (lenient) open a span on bare I
                if span_start_tok is None:
                    span_start_tok = tok_pos
            else:  # O: close any open span
                if span_start_tok is not None:
                    _emit_span(spans, text, offset_mapping, span_start_tok, tok_pos, marker_name)
                    span_start_tok = None
        # Flush a span still open at the end of the sequence.
        if span_start_tok is not None:
            _emit_span(spans, text, offset_mapping, span_start_tok, seq_len, marker_name)
    spans.sort(key=lambda s: (s["start"], s["end"]))
    return spans
def analyze_spans_api(text):
    """JSON API: sentence-split the input and return token-level spans.

    Splits on whitespace following sentence-ending punctuation, drops
    sentences of fewer than two words, and runs predict_spans() on each.

    Returns:
        {"sentences": [{"text": ..., "spans": [...]}, ...]} on success,
        or {"error": ...} for inputs under 3 characters after stripping.
    """
    import re

    if not text or len(text.strip()) < 3:
        return {"error": "Please enter at least 3 characters of text."}

    pieces = (p.strip() for p in re.split(r'(?<=[.!?])\s+', text))
    results = []
    for piece in pieces:
        # Skip empties and one-word fragments — too short to score.
        if not piece or len(piece.split()) < 2:
            continue
        results.append({"text": piece, "spans": predict_spans(piece)})
    return {"sentences": results}
# Build interface: a visible demo form plus a hidden programmatic endpoint.
with gr.Blocks(title="Havelock.AI - Token Span API", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Havelock.AI - Token Span API
Returns substring-level marker predictions from the trained token classifier.
Only Tier 1 markers (F1 >= 0.50) are included.
Use this alongside the production Space for full analysis.
""")
    text_input = gr.Textbox(
        label="Enter text",
        placeholder="Paste text here...",
        lines=4
    )
    output = gr.JSON(label="Span Predictions")
    analyze_btn = gr.Button("Predict Spans", variant="primary")
    # Canned examples spanning oral vs literate registers.
    gr.Examples([
        ["Tell me, O Muse, of that ingenious hero who travelled far and wide."],
        ["We will fight on the beaches, we will fight on the landing grounds."],
        ["The analysis of variance revealed a statistically significant effect."],
        ["So like, I was just thinking about this the other day, right?"],
    ], inputs=text_input)
    analyze_btn.click(fn=analyze_spans_api, inputs=text_input, outputs=output)

    # Hidden API endpoint: invisible components exist only so the click
    # handler can be registered under api_name="analyze" for API clients.
    api_input = gr.Textbox(visible=False)
    api_output = gr.JSON(visible=False)
    api_btn = gr.Button(visible=False)
    api_btn.click(fn=analyze_spans_api, inputs=api_input, outputs=api_output, api_name="analyze")

demo.launch()