| | import os |
| | import onnxruntime as ort |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| |
|
| | from mentioned.model import ModelRegistry, LitMentionDetector |
| | from mentioned.data import DataRegistry |
| |
|
| |
|
class InferenceMentionDetector(nn.Module):
    """Export-friendly wrapper combining an encoder with a mention-detector
    head, emitting probabilities (sigmoid of logits) instead of raw logits."""

    def __init__(self, encoder, mention_detector):
        """
        Args:
            encoder: Module mapping (input_ids, attention_mask, word_ids)
                to per-word embeddings.
            mention_detector: Head mapping word embeddings to a pair of
                (start_logits, end_logits).
        """
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector

    def forward(self, input_ids, attention_mask, word_ids):
        """
        Inputs (Tensors):
            input_ids: (B, Seq_Len)
            attention_mask: (B, Seq_Len)
            word_ids: (B, Seq_Len) -> Word index per token, -1 padding.

        Returns (Tensors):
            start_probs: (B, Num_Words)
            end_probs: (B, Num_Words, Num_Words)
        """
        embeddings = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        start_logits, end_logits = self.mention_detector(embeddings)
        # Independent per-position / per-span probabilities.
        return torch.sigmoid(start_logits), torch.sigmoid(end_logits)
| |
|
| |
|
class InferenceMentionLabeler(nn.Module):
    """Export-friendly wrapper that adds a mention-labeling head on top of
    the encoder + detector, producing span and class probabilities."""

    def __init__(self, encoder, mention_detector, mention_labeler, id2label):
        """
        Args:
            encoder: Module mapping token inputs to per-word embeddings.
            mention_detector: Head producing (start_logits, end_logits).
            mention_labeler: Head producing per-span class logits.
            id2label: Mapping of class id -> label string (stored for
                consumers; not used inside forward).
        """
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector
        self.mention_labeler = mention_labeler
        self.id2label = id2label

    def forward(self, input_ids, attention_mask, word_ids):
        """
        Pure tensor forward pass for ONNX export.

        Returns (Tensors):
            start_probs: (B, N)
            end_probs: (B, N, N)
            label_probs: (B, N, N, C) or dummy empty tensor
        """
        embeddings = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        start_logits, end_logits = self.mention_detector(embeddings)
        class_logits = self.mention_labeler(embeddings)
        # Detection is multi-label (sigmoid); labeling is single-label (softmax).
        return (
            torch.sigmoid(start_logits),
            torch.sigmoid(end_logits),
            torch.softmax(class_logits, dim=-1),
        )
| |
|
| |
|
class MentionProcessor:
    """Turns batches of pre-split documents into the tensor dict the
    exported model expects (input_ids, attention_mask, word_ids)."""

    def __init__(self, tokenizer, max_length: int = 512):
        """
        Args:
            tokenizer: HF-style tokenizer supporting ``is_split_into_words``
                and ``word_ids(batch_index=...)``.
            max_length: Truncation limit passed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, docs: list[list[str]]):
        """
        Converts raw word lists into tensors.
        Args:
            docs: List of documents, where each doc is a list of words.
                Example: [["Hello", "world"], ["Testing", "this"]]
        """
        encoded = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_attention_mask=True,
        )
        # word_ids() yields None for special/padding tokens; encode as -1
        # so everything fits into a single integer tensor.
        word_ids_tensor = torch.stack([
            torch.tensor(
                [-1 if w is None else w for w in encoded.word_ids(batch_index=i)]
            )
            for i in range(len(docs))
        ])
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "word_ids": word_ids_tensor,
        }
| |
|
| |
|
class ONNXMentionDetectorPipeline:
    """End-to-end inference over an exported mention-detector ONNX model:
    tokenization, session run, and decoding of probabilities into spans."""

    def __init__(self, model_path: str, tokenizer, threshold: float = 0.5):
        """
        Args:
            model_path: Path to the exported ``model.onnx`` file.
            tokenizer: HF tokenizer matching the exported encoder.
            threshold: Probability cutoff applied to start and span scores.
        """
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )
        self.tokenizer = tokenizer
        self.processor = MentionProcessor(tokenizer)
        self.threshold = threshold

    def predict(self, docs: list[list[str]]):
        """Return one list of mention dicts (start/end/score/text) per doc."""
        tensors = self.processor(docs)
        feeds = {
            name: tensors[name].numpy()
            for name in ("input_ids", "attention_mask", "word_ids")
        }
        start_probs, end_probs = self.session.run(None, feeds)

        results = []
        for doc_idx, words in enumerate(docs):
            n = len(words)
            starts_ok = start_probs[doc_idx, :n] > self.threshold
            spans_ok = end_probs[doc_idx, :n, :n] > self.threshold
            # A span is kept only if its start word fires AND end >= start
            # (upper triangle of the span matrix).
            keep = spans_ok & starts_ok[:, None] & np.triu(np.ones((n, n), dtype=bool))
            mentions = []
            for s_idx, e_idx in np.argwhere(keep):
                span_score = end_probs[doc_idx, s_idx, e_idx]
                mentions.append({
                    "start": int(s_idx),
                    "end": int(e_idx),
                    "score": round(float(span_score), 4),
                    "text": " ".join(words[s_idx:e_idx + 1]),
                })
            results.append(mentions)
        return results
| |
|
| |
|
class ONNXMentionLabelerPipeline:
    """Inference pipeline for an exported mention detector + labeler.

    Runs the ONNX session on tokenized input and decodes start/end/label
    probabilities into mention dicts carrying a label and two scores.
    """

    def __init__(self, model_path: str, tokenizer, id2label: dict = None, threshold: float = 0.5):
        """
        Args:
            model_path: Path to the exported ``model.onnx`` file.
            tokenizer: HF tokenizer matching the exported encoder.
            id2label: Maps class id -> label string. Optional; when omitted
                the numeric class id string is reported as the label.
            threshold: Probability cutoff for start/span detection.
        """
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )
        self.tokenizer = tokenizer
        self.processor = MentionProcessor(tokenizer)
        self.threshold = threshold
        # Bug fix: storing None directly crashed predict() at
        # ``self.id2label.get(...)`` whenever the default was used.
        self.id2label = id2label if id2label is not None else {}

    def predict(self, docs: list[list[str]]):
        """Detect and label mentions in a batch of pre-split documents.

        Args:
            docs: List of documents, each a list of words.

        Returns:
            One list per document of dicts with keys ``start``, ``end``,
            ``text``, ``score`` (detection), ``label``, ``label_score``.
        """
        batch = self.processor(docs)
        onnx_inputs = {
            "input_ids": batch["input_ids"].numpy(),
            "attention_mask": batch["attention_mask"].numpy(),
            "word_ids": batch["word_ids"].numpy(),
        }
        start_probs, end_probs, label_probs = self.session.run(None, onnx_inputs)

        results = []
        for i in range(len(docs)):
            doc_mentions = []
            doc_len = len(docs[i])
            is_start = start_probs[i, :doc_len] > self.threshold
            is_span = end_probs[i, :doc_len, :doc_len] > self.threshold
            # Only spans with end >= start are valid mentions.
            upper_tri = np.triu(np.ones((doc_len, doc_len), dtype=bool))
            combined_mask = is_span & is_start[:, None] & upper_tri
            for s_idx, e_idx in np.argwhere(combined_mask):
                det_score = float(end_probs[i, s_idx, e_idx])
                class_probs = label_probs[i, s_idx, e_idx]
                label_id = int(np.argmax(class_probs))
                doc_mentions.append({
                    "start": int(s_idx),
                    "end": int(e_idx),
                    "text": " ".join(docs[i][s_idx : e_idx + 1]),
                    "score": round(det_score, 4),
                    "label": self.id2label.get(label_id, str(label_id)),
                    "label_score": round(float(class_probs[label_id]), 4),
                })
            results.append(doc_mentions)
        return results
| |
|
| |
|
def create_inference_model(
    repo_id: str,
    encoder_id: str,
    model_factory: str,
    data_factory: str,
    device: str = "cpu",
):
    """
    Factory to load a trained model from HF Hub and wrap it for ONNX/Inference.

    Args:
        repo_id: HF Hub repository holding the trained weights.
        encoder_id: Backbone encoder identifier passed to the model factory.
        model_factory: Key into ModelRegistry for building a fresh bundle.
        data_factory: Key into DataRegistry for the dataset definition.
        device: Target device for the loaded model.

    Returns:
        An eval-mode InferenceMentionLabeler when the bundle carries a
        label2id mapping, otherwise an InferenceMentionDetector; the wrapper
        also carries ``tokenizer`` and ``max_length`` attributes.
    """
    dataset = DataRegistry.get(data_factory)()
    bundle = ModelRegistry.get(model_factory)(dataset, encoder_id)
    labeler_head = getattr(bundle, "mention_labeler", None)
    label2id = getattr(bundle, "label2id", None)

    lit_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=bundle.tokenizer,
        encoder=bundle.encoder,
        mention_detector=bundle.mention_detector,
        label2id=label2id,
        mention_labeler=labeler_head,
    )
    lit_model.to(device)
    lit_model.eval()

    if label2id is None:
        wrapped = InferenceMentionDetector(
            encoder=lit_model.encoder,
            mention_detector=lit_model.mention_detector,
        )
    else:
        wrapped = InferenceMentionLabeler(
            encoder=lit_model.encoder,
            mention_detector=lit_model.mention_detector,
            mention_labeler=lit_model.mention_labeler,
            id2label={idx: name for name, idx in label2id.items()},
        )

    # Expose tokenizer/max_length so export + pipeline code can reuse them.
    wrapped.tokenizer = lit_model.tokenizer
    wrapped.max_length = lit_model.encoder.max_length
    return wrapped.eval()
| |
|
| |
|
def compile_detector(model, output_dir="model_v1_onnx"):
    """Export a detector-only model to ONNX with dynamic batch/sequence axes,
    save its tokenizer alongside, and print the resulting input shapes."""
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    model.tokenizer.save_pretrained(output_dir)
    onnx_path = os.path.join(output_dir, "model.onnx")

    # Small dummy batch just to trace the graph; real shapes are dynamic.
    dummy_inputs = (
        torch.randint(0, 100, (1, 16), dtype=torch.long),
        torch.ones((1, 16), dtype=torch.long),
        torch.arange(16, dtype=torch.long).unsqueeze(0),
    )

    print("🚀 Re-exporting with legacy engine (dynamo=False)...")
    torch.onnx.export(
        model,
        dummy_inputs,
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["input_ids", "attention_mask", "word_ids"],
        output_names=["start_probs", "end_probs"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "word_ids": {0: "batch", 1: "sequence"},
            "start_probs": {0: "batch", 1: "num_words"},
            "end_probs": {0: "batch", 1: "num_words", 2: "num_words"},
        },
        dynamo=False,
    )
    print(f"✅ Exported to {output_dir}! Checking dimensions...")

    # Sanity check: confirm the exported graph kept symbolic dimensions.
    sess = ort.InferenceSession(onnx_path)
    for input_meta in sess.get_inputs():
        print(f"Input '{input_meta.name}' shape: {input_meta.shape}")
|
| |
|
def compile_labeler(model, output_dir="labeler_onnx"):
    """Export a detector+labeler model to ONNX (CPU, dynamic axes) and save
    the tokenizer next to it."""
    model.cpu().eval()
    os.makedirs(output_dir, exist_ok=True)
    model.tokenizer.save_pretrained(output_dir)
    onnx_path = os.path.join(output_dir, "model.onnx")

    print(f"🛠️ Exporting {model.__class__.__name__} to {onnx_path}...")

    # Batch of 2 to trace batched behavior; all shapes remain dynamic.
    dummy_inputs = (
        torch.randint(0, 50000, (2, 16), dtype=torch.long),
        torch.ones((2, 16), dtype=torch.long),
        torch.arange(16, dtype=torch.long).unsqueeze(0).repeat(2, 1),
    )

    torch.onnx.export(
        model,
        dummy_inputs,
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask', 'word_ids'],
        output_names=['start_probs', 'end_probs', 'label_probs'],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq_ids"},
            "attention_mask": {0: "batch", 1: "seq_mask"},
            "word_ids": {0: "batch", 1: "seq_words"},
            "start_probs": {0: "batch", 1: "num_words"},
            "end_probs": {0: "batch", 1: "num_words", 2: "num_words"},
            "label_probs": {0: "batch", 1: "num_words", 2: "num_words", 3: "num_classes"},
        },
        training=torch.onnx.TrainingMode.EVAL,
        dynamo=False,
    )
    print("✅ Export finished successfully!")
| |
|
| |
|
if __name__ == "__main__":
    # Which checkpoint to export and where to write the ONNX artifacts.
    model_factory = "model_v2"
    data_factory = "litbank_entities"
    inference_model_path = "model_v2_onnx"
    repo_id = "kadarakos/entity-labeler-poc"
    encoder_id = "distilroberta-base"

    inference_model = create_inference_model(
        repo_id,
        encoder_id,
        model_factory,
        data_factory,
    )

    # Pick the pipeline matching the wrapper the factory produced.
    if isinstance(inference_model, InferenceMentionDetector):
        compile_detector(inference_model, inference_model_path)
        pipeline = ONNXMentionDetectorPipeline(
            model_path=os.path.join(inference_model_path, "model.onnx"),
            tokenizer=inference_model.tokenizer,
            threshold=0.3,
        )
    else:
        compile_labeler(inference_model, inference_model_path)
        pipeline = ONNXMentionLabelerPipeline(
            model_path=os.path.join(inference_model_path, "model.onnx"),
            tokenizer=inference_model.tokenizer,
            threshold=0.5,
            id2label=inference_model.id2label,
        )

    docs = [
        "Does this model actually work?".split(),
        "The name of the mage is Bubba.".split(),
        "It was quite a sunny day when the model finally started working.".split(),
        "Albert Einstein was a theoretical physicist who developed the theory of relativity".split(),
        "Apple Inc. and Microsoft are competing in the cloud computing market".split(),
        "New York City is often called the Big Apple".split(),
        "The Great Barrier Reef is the world's largest coral reef system".split(),
        "Marie Curie was the first woman to win a Nobel Prize".split(),
    ]

    batch_mentions = pipeline.predict(docs)
    for i, mentions in enumerate(batch_mentions):
        print(" ".join(docs[i]))
        # Detector-only output carries no "label" key; .get avoids the
        # KeyError the previous mention["label"] access caused there.
        preds = [(mention["text"], mention.get("label")) for mention in mentions]
        print(preds)
| |
|