import os
import onnxruntime as ort
import torch
import torch.nn as nn
import numpy as np
from mentioned.model import ModelRegistry, LitMentionDetector
from mentioned.data import DataRegistry
class InferenceMentionDetector(nn.Module):
    """Tensor-in / tensor-out wrapper around an encoder and a detector head.

    The wrapped modules are stored as-is; ``forward`` only chains them and
    squashes the detector logits into probabilities.
    """

    def __init__(self, encoder, mention_detector):
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector

    def forward(self, input_ids, attention_mask, word_ids):
        """Score mention starts and spans for a batch.

        Inputs (Tensors):
            input_ids: (B, Seq_Len)
            attention_mask: (B, Seq_Len)
            word_ids: (B, Seq_Len) -> word index per token, -1 for padding.

        Returns (Tensors):
            start_probs: (B, Num_Words)
            end_probs: (B, Num_Words, Num_Words)
        """
        # Word-level representations: B x N x D.
        embeddings = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        # Detector head yields start logits (B x N) and span logits (B x N x N).
        logits_start, logits_end = self.mention_detector(embeddings)
        return torch.sigmoid(logits_start), torch.sigmoid(logits_end)
class InferenceMentionLabeler(nn.Module):
    """Tensor-only wrapper chaining encoder, detector head and labeler head.

    ``id2label`` is only stored on the module; ``forward`` does not use it.
    """

    def __init__(self, encoder, mention_detector, mention_labeler, id2label):
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector
        self.mention_labeler = mention_labeler
        self.id2label = id2label

    def forward(self, input_ids, attention_mask, word_ids):
        """Pure tensor forward pass for ONNX export.

        Returns (Tensors):
            start_probs: (B, N), sigmoid of the detector start logits
            end_probs: (B, N, N), sigmoid of the detector span logits
            label_probs: (B, N, N, C), softmax of the labeler logits over
                the last axis
        """
        embeddings = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        # Both heads consume the same word-level embeddings.
        logits_start, logits_end = self.mention_detector(embeddings)
        probs_start = torch.sigmoid(logits_start)
        probs_end = torch.sigmoid(logits_end)
        probs_label = torch.softmax(self.mention_labeler(embeddings), dim=-1)
        return probs_start, probs_end, probs_label
class MentionProcessor:
    """Tokenize pre-split documents and build the token-to-word index tensor."""

    def __init__(self, tokenizer, max_length: int = 512):
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, docs: list[list[str]]):
        """Convert raw word lists into model input tensors.

        Args:
            docs: List of documents, where each doc is a list of words.
                Example: [["Hello", "world"], ["Testing", "this"]]

        Returns:
            Dict with "input_ids" and "attention_mask" as produced by the
            tokenizer, plus "word_ids": one row per document holding the
            word index of every token.
        """
        encoded = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_attention_mask=True,
        )
        # The tokenizer reports None for tokens that belong to no word
        # (e.g. [None, 0, 1, 1, 2, None]); those positions become -1.
        rows = []
        for doc_idx, _ in enumerate(docs):
            token_words = encoded.word_ids(batch_index=doc_idx)
            rows.append(
                torch.tensor([-1 if w is None else w for w in token_words])
            )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "word_ids": torch.stack(rows),
        }
class ONNXMentionDetectorPipeline:
    """Run an exported ONNX mention detector on pre-split documents."""

    def __init__(self, model_path: str, tokenizer, threshold: float = 0.5):
        """Open the ONNX session and set up tokenization.

        Args:
            model_path: Path to the exported ``.onnx`` file.
            tokenizer: Tokenizer handed to ``MentionProcessor``.
            threshold: Probability a start / span must exceed to be kept.
        """
        # The session is pinned to the CPU execution provider.
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
        )
        self.tokenizer = tokenizer
        self.processor = MentionProcessor(tokenizer)
        self.threshold = threshold

    def predict(self, docs: list[list[str]]):
        """Detect mentions in a batch of documents.

        Args:
            docs: List of documents, each a list of words.

        Returns:
            One list per document containing dicts with the inclusive word
            indices "start" and "end", the span "score" (rounded to 4
            decimals) and the space-joined "text" of the span.
        """
        if not docs:
            return []

        batch = self.processor(docs)
        onnx_inputs = {
            key: batch[key].numpy()
            for key in ("input_ids", "attention_mask", "word_ids")
        }
        start_probs, end_probs = self.session.run(None, onnx_inputs)

        results = []
        for i, words in enumerate(docs):
            # The processor truncates long inputs, so the model may have
            # scored fewer words than the document holds. Clamp so the masks
            # below always share one shape.
            n_words = min(len(words), start_probs.shape[1])
            is_start = start_probs[i, :n_words] > self.threshold
            is_span = end_probs[i, :n_words, :n_words] > self.threshold
            # Only spans whose end is at or after their start are valid.
            upper_tri = np.triu(np.ones((n_words, n_words), dtype=bool))
            keep = is_span & is_start[:, None] & upper_tri

            doc_mentions = []
            for s_idx, e_idx in np.argwhere(keep):
                # XXX Considering end-prob as mention score!
                score = end_probs[i, s_idx, e_idx]
                doc_mentions.append({
                    "start": int(s_idx),
                    "end": int(e_idx),
                    "score": round(float(score), 4),
                    "text": " ".join(words[s_idx:e_idx + 1]),
                })
            results.append(doc_mentions)
        return results
class ONNXMentionLabelerPipeline:
    """Run an exported ONNX mention detector + labeler on pre-split documents."""

    def __init__(self, model_path: str, tokenizer, id2label: dict = None, threshold: float = 0.5):
        """Open the ONNX session and set up tokenization.

        Args:
            model_path: Path to the exported ``.onnx`` file.
            tokenizer: Tokenizer handed to ``MentionProcessor``.
            id2label: Optional mapping from class index to a readable label.
                When omitted, labels are reported as the stringified index.
            threshold: Probability a start / span must exceed to be kept.
        """
        # The session is pinned to the CPU execution provider.
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
        )
        self.tokenizer = tokenizer
        self.processor = MentionProcessor(tokenizer)
        self.threshold = threshold
        # Mapping for human-readable labels (may be None).
        self.id2label = id2label

    def predict(self, docs: list[list[str]]):
        """Detect and label mentions in a batch of documents.

        Args:
            docs: List of documents, each a list of words.

        Returns:
            One list per document containing dicts with the inclusive word
            indices "start" and "end", the span "text", the detection
            "score", the predicted "label" and its "label_score" (scores
            rounded to 4 decimals).
        """
        if not docs:
            return []

        batch = self.processor(docs)
        onnx_inputs = {
            key: batch[key].numpy()
            for key in ("input_ids", "attention_mask", "word_ids")
        }
        # ONNX returns a list of outputs in order: [start_probs, end_probs, label_probs]
        start_probs, end_probs, label_probs = self.session.run(None, onnx_inputs)

        # id2label defaults to None; fall back to the raw class index then.
        id2label = self.id2label or {}

        results = []
        for i, words in enumerate(docs):
            # The processor truncates long inputs, so the model may have
            # scored fewer words than the document holds. Clamp so the masks
            # below always share one shape.
            n_words = min(len(words), start_probs.shape[1])
            is_start = start_probs[i, :n_words] > self.threshold
            is_span = end_probs[i, :n_words, :n_words] > self.threshold
            # Only spans whose end is at or after their start are valid.
            upper_tri = np.triu(np.ones((n_words, n_words), dtype=bool))
            keep = is_span & is_start[:, None] & upper_tri

            doc_mentions = []
            for s_idx, e_idx in np.argwhere(keep):
                # XXX Considering end-prob as score!
                det_score = float(end_probs[i, s_idx, e_idx])
                class_probs = label_probs[i, s_idx, e_idx]
                label_id = int(np.argmax(class_probs))
                label_score = float(class_probs[label_id])
                doc_mentions.append({
                    "start": int(s_idx),
                    "end": int(e_idx),
                    "text": " ".join(words[s_idx:e_idx + 1]),
                    "score": round(det_score, 4),
                    "label": id2label.get(label_id, str(label_id)),
                    "label_score": round(label_score, 4),
                })
            results.append(doc_mentions)
        return results
def create_inference_model(
    repo_id: str,
    encoder_id: str,
    model_factory: str,
    data_factory: str,
    device: str = "cpu",
):
    """Load a trained model from ``repo_id`` and wrap it for ONNX/inference.

    A fresh bundle is built from the registered data and model factories and
    its components are passed to ``LitMentionDetector.from_pretrained``. The
    loaded model is moved to ``device`` and put in eval mode.

    Returns:
        An ``InferenceMentionLabeler`` when the bundle exposes ``label2id``,
        otherwise an ``InferenceMentionDetector``. Either way the wrapper is
        in eval mode and carries ``tokenizer`` and ``max_length`` attributes
        copied from the loaded model.
    """
    dataset = DataRegistry.get(data_factory)()
    bundle = ModelRegistry.get(model_factory)(dataset, encoder_id)
    # Both attributes are optional on the bundle.
    labeler_head = getattr(bundle, "mention_labeler", None)
    label2id = getattr(bundle, "label2id", None)

    lit_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=bundle.tokenizer,
        encoder=bundle.encoder,
        mention_detector=bundle.mention_detector,
        label2id=label2id,
        mention_labeler=labeler_head,
    )
    lit_model.to(device)
    lit_model.eval()

    if label2id is None:
        wrapper = InferenceMentionDetector(
            encoder=lit_model.encoder,
            mention_detector=lit_model.mention_detector,
        )
    else:
        # Invert the mapping so predictions can be turned back into names.
        wrapper = InferenceMentionLabeler(
            encoder=lit_model.encoder,
            mention_detector=lit_model.mention_detector,
            mention_labeler=lit_model.mention_labeler,
            id2label={idx: name for name, idx in label2id.items()},
        )

    wrapper.tokenizer = lit_model.tokenizer
    wrapper.max_length = lit_model.encoder.max_length
    return wrapper.eval()
def compile_detector(model, output_dir="model_v1_onnx"):
    """Export a detector model to ONNX with dynamic batch/sequence/word axes.

    Saves the model's tokenizer and ``model.onnx`` into ``output_dir``
    (created if missing), then reloads the exported file with ONNX Runtime
    and prints the input shapes as a sanity check.

    Args:
        model: Inference wrapper exposing a ``tokenizer`` attribute and a
            ``forward(input_ids, attention_mask, word_ids)`` that returns
            ``(start_probs, end_probs)``.
        output_dir: Directory that receives the tokenizer files and the
            exported model.
    """
    # The dummy inputs below are CPU tensors, so the model has to live on the
    # CPU as well (same as compile_labeler).
    model.cpu().eval()
    os.makedirs(output_dir, exist_ok=True)
    model.tokenizer.save_pretrained(output_dir)
    onnx_path = os.path.join(output_dir, "model.onnx")

    dynamic_axes = {
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "word_ids": {0: "batch", 1: "sequence"},
        "start_probs": {0: "batch", 1: "num_words"},
        "end_probs": {0: "batch", 1: "num_words", 2: "num_words"},
    }

    # One dummy document of 16 tokens, each token being its own word.
    dummy_inputs = (
        torch.randint(0, 100, (1, 16), dtype=torch.long),
        torch.ones((1, 16), dtype=torch.long),
        torch.arange(16, dtype=torch.long).unsqueeze(0),
    )

    print("🚀 Re-exporting with legacy engine (dynamo=False)...")
    torch.onnx.export(
        model,
        dummy_inputs,
        onnx_path,
        export_params=True,
        opset_version=17,  # Use 17 for maximum compatibility with legacy mode
        do_constant_folding=True,
        input_names=["input_ids", "attention_mask", "word_ids"],
        output_names=["start_probs", "end_probs"],
        dynamic_axes=dynamic_axes,
        dynamo=False,  # Force the legacy exporter
    )
    print(f"✅ Exported to {output_dir}! Checking dimensions...")

    # Verification: name the provider explicitly, as the pipelines in this
    # file do. ONNX Runtime builds with several execution providers reject
    # sessions created without one.
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    for input_meta in sess.get_inputs():
        print(f"Input '{input_meta.name}' shape: {input_meta.shape}")
def compile_labeler(model, output_dir="labeler_onnx"):
    """Export a labeler model to ``<output_dir>/model.onnx``.

    The model is moved to the CPU and put in eval mode, its tokenizer is
    saved next to the exported file, and the export runs through the legacy
    (non-dynamo) exporter in eval training mode.

    Args:
        model: Inference wrapper exposing a ``tokenizer`` attribute and a
            ``forward(input_ids, attention_mask, word_ids)`` that returns
            ``(start_probs, end_probs, label_probs)``.
        output_dir: Directory that receives the tokenizer files and the
            exported model; created if missing.
    """
    model.cpu().eval()
    os.makedirs(output_dir, exist_ok=True)
    model.tokenizer.save_pretrained(output_dir)
    onnx_path = os.path.join(output_dir, "model.onnx")
    print(f"🛠️ Exporting {model.__class__.__name__} to {onnx_path}...")

    # Dummy batch: 2 documents of 16 tokens, each token being its own word.
    n_docs, n_tokens = 2, 16
    dummy_ids = torch.randint(0, 50000, (n_docs, n_tokens), dtype=torch.long)
    dummy_mask = torch.ones((n_docs, n_tokens), dtype=torch.long)
    dummy_words = torch.arange(n_tokens, dtype=torch.long).unsqueeze(0).repeat(n_docs, 1)

    input_names = ['input_ids', 'attention_mask', 'word_ids']
    output_names = ['start_probs', 'end_probs', 'label_probs']

    # Each input gets its own sequence-axis name to stop the merging warning.
    dynamic_axes = {
        "input_ids": {0: "batch", 1: "seq_ids"},
        "attention_mask": {0: "batch", 1: "seq_mask"},
        "word_ids": {0: "batch", 1: "seq_words"},
        "start_probs": {0: "batch", 1: "num_words"},
        "end_probs": {0: "batch", 1: "num_words", 2: "num_words"},
        "label_probs": {0: "batch", 1: "num_words", 2: "num_words", 3: "num_classes"},
    }

    torch.onnx.export(
        model,
        (dummy_ids, dummy_mask, dummy_words),
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        training=torch.onnx.TrainingMode.EVAL,
        dynamo=False,
    )
    print("✅ Export finished successfully!")
if __name__ == "__main__":
    # Demo: load a trained model, export it to ONNX and run the matching
    # ONNX pipeline on a few example sentences.
    #
    # Detector-only configuration, kept for reference:
    #   repo_id = "kadarakos/mention-detector-poc-dry-run"
    #   model_factory = "model_v1"
    #   inference_model_path = "model_v1_onnx"
    model_factory = "model_v2"
    data_factory = "litbank_entities"
    inference_model_path = "model_v2_onnx"
    repo_id = "kadarakos/entity-labeler-poc"
    encoder_id = "distilroberta-base"

    inference_model = create_inference_model(
        repo_id,
        encoder_id,
        model_factory,
        data_factory,
    )

    if isinstance(inference_model, InferenceMentionDetector):
        compile_detector(inference_model, inference_model_path)
        pipeline = ONNXMentionDetectorPipeline(
            model_path=os.path.join(inference_model_path, "model.onnx"),
            tokenizer=inference_model.tokenizer,
            # XXX Sweet spot for this examples on this model, found by hand
            threshold=0.3,
        )
    else:
        print(inference_model)
        compile_labeler(inference_model, inference_model_path)
        pipeline = ONNXMentionLabelerPipeline(
            model_path=os.path.join(inference_model_path, "model.onnx"),
            tokenizer=inference_model.tokenizer,
            threshold=0.5,
            id2label=inference_model.id2label,
        )

    docs = [
        "Does this model actually work?".split(),
        "The name of the mage is Bubba.".split(),
        "It was quite a sunny day when the model finally started working.".split(),
        "Albert Einstein was a theoretical physicist who developed the theory of relativity".split(),
        "Apple Inc. and Microsoft are competing in the cloud computing market".split(),
        "New York City is often called the Big Apple".split(),
        "The Great Barrier Reef is the world's largest coral reef system".split(),
        "Marie Curie was the first woman to win a Nobel Prize".split(),
    ]
    batch_mentions = pipeline.predict(docs)
    for doc, mentions in zip(docs, batch_mentions):
        print(" ".join(doc))
        # Detector-only mentions carry no "label" key, so read it with .get()
        # to keep this loop working for both pipeline types.
        preds = [(mention["text"], mention.get("label")) for mention in mentions]
        print(preds)