import argparse
import json
from pathlib import Path

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

from src.models.augment import augment, MAXLEN_TO_WINDOW
from src.models.dataset import deduplicate_positions, flatten_to_examples
from src.models.distillbert import reconstruct_triplets
from src.schemas.labels import MARKER_MODE, SENTIMENT_LABELS

BASE_TOKENIZER = "distilbert-base-uncased"


def build_tokenizer(mode: str):
    tokenizer = AutoTokenizer.from_pretrained(BASE_TOKENIZER)
    if mode == "marker":
        # Marker mode wraps the target entity in special tokens; the exported
        # model is expected to include embeddings for them already.
        tokenizer.add_special_tokens(
            {"additional_special_tokens": [MARKER_MODE.entity_start, MARKER_MODE.entity_end]}
        )
    return tokenizer


def _softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max for numerical stability.
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)


def _tokenize_examples(
    examples: list[dict],
    tokenizer,
    max_len: int,
) -> dict[str, np.ndarray]:
    input_ids, attention_masks = [], []
    for ex in examples:
        seg_a = ex["seg_a"]
        seg_b = ex["seg_b"]
        if seg_b is None:
            enc = tokenizer(
                seg_a,
                max_length=max_len,
                truncation=True,
                padding="max_length",
                return_tensors="np",
            )
        else:
            # Sentence-pair input: truncate only the first (context) segment
            # so the second segment is never cut off.
            enc = tokenizer(
                seg_a,
                seg_b,
                max_length=max_len,
                truncation="only_first",
                padding="max_length",
                return_tensors="np",
            )
        input_ids.append(enc["input_ids"][0])
        attention_masks.append(enc["attention_mask"][0])
    return {
        "input_ids": np.array(input_ids, dtype=np.int64),
        "attention_mask": np.array(attention_masks, dtype=np.int64),
    }


def _run_batched(
    session: ort.InferenceSession,
    inputs: dict[str, np.ndarray],
    batch_size: int,
) -> np.ndarray:
    n = inputs["input_ids"].shape[0]
    all_logits = []
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        batch = {k: v[start:end] for k, v in inputs.items()}
        logits = session.run(None, batch)[0]
        all_logits.append(logits)
    return np.concatenate(all_logits, axis=0)


def predict(
    samples: list[dict],
    session: ort.InferenceSession,
    tokenizer,
    mode: str,
    max_len: int = 256,
    batch_size: int = 32,
    deduplicate: bool = False,
) -> list[dict]:
    window_words = MAXLEN_TO_WINDOW[max_len]
    augmented = augment(samples, window_words)
    if deduplicate:
        augmented = deduplicate_positions(augmented)
    examples = flatten_to_examples(augmented, mode=mode)
    if not examples:
        return [{"id": s["id"], "entities": []} for s in samples]

    inputs = _tokenize_examples(examples, tokenizer, max_len)
    logits = _run_batched(session, inputs, batch_size)

    sentiments = list(SENTIMENT_LABELS.classes)
    probs = _softmax(logits)
    if mode in ("marker", "qa_m"):
        # Multi-class heads: one example per entity position, argmax over the
        # sentiment classes.
        preds = np.argmax(probs, axis=-1)
        max_probs = probs.max(axis=-1)
        for ex, pred_id, conf in zip(examples, preds, max_probs):
            ex["predicted_label"] = sentiments[int(pred_id)]
            ex["confidence"] = float(conf)
    else:
        # Binary head (qa_b): three consecutive yes/no examples per entity
        # position are reconstructed into a single sentiment prediction.
        yes_probs = probs[:, 1]
        preds3, _ = reconstruct_triplets(yes_probs, np.zeros_like(yes_probs))
        for triplet_idx, i in enumerate(range(0, len(examples) - 2, 3)):
            pred_label = sentiments[preds3[triplet_idx]]
            triplet_conf = float(yes_probs[i:i + 3].max())
            for j in range(3):
                examples[i + j]["predicted_label"] = pred_label
                examples[i + j]["confidence"] = triplet_conf

    # Aggregate across window positions: keep the most confident prediction
    # per (sample, entity) pair.
    entity_preds: dict[tuple, tuple[str, float]] = {}
    for ex in examples:
        key = (ex["sample_id"], ex["entity_id"])
        conf = ex.get("confidence", 0.0)
        if key not in entity_preds or conf > entity_preds[key][1]:
            # .get() guards against a trailing partial triplet that never
            # received a label; such examples fall back to "neutral".
            entity_preds[key] = (ex.get("predicted_label", "neutral"), conf)

    results = []
    for s in samples:
        entities_out = []
        for e in s["entities"]:
            key = (s["id"], e["entity_id"])
            # Entities that produced no examples default to neutral.
            label, _ = entity_preds.get(key, ("neutral", 0.0))
            entities_out.append({
                "entity_id": e["entity_id"],
                "entity_text": e["entity_text"],
                "classification": label,
            })
        results.append({"id": s["id"], "entities": entities_out})
    return results


def main():
    parser = argparse.ArgumentParser(description="Run ONNX inference on raw input JSON")
    parser.add_argument("--onnx-path", required=True, help="Path to model.onnx")
    parser.add_argument("--mode", required=True, choices=("marker", "qa_m", "qa_b"))
    parser.add_argument("--data", required=True, help="Path to input JSON (assignment format)")
    parser.add_argument("--output", default=None, help="Output JSON path (default: stdout)")
    parser.add_argument("--max-len", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--deduplicate", action="store_true", help="Use one position per entity")
    args = parser.parse_args()

    mode = args.mode
    tokenizer = build_tokenizer(mode)
    # Explicit providers are required by onnxruntime-gpu builds (>= 1.9);
    # get_available_providers() preserves the default CPU/GPU selection.
    session = ort.InferenceSession(args.onnx_path, providers=ort.get_available_providers())

    with open(args.data, "r", encoding="utf-8") as f:
        samples = json.load(f)

    results = predict(samples, session, tokenizer, mode, args.max_len, args.batch_size,
                      deduplicate=args.deduplicate)

    output_json = json.dumps(results, ensure_ascii=False, indent=2)
    if args.output:
        Path(args.output).write_text(output_json, encoding="utf-8")
        print(f"Saved {len(results)} predictions to {args.output}")
    else:
        print(output_json)


if __name__ == "__main__":
    main()