import argparse
import json
from pathlib import Path

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

from src.models.augment import augment, MAXLEN_TO_WINDOW
from src.models.dataset import deduplicate_positions, flatten_to_examples
from src.models.distillbert import reconstruct_triplets
from src.schemas.labels import MARKER_MODE, SENTIMENT_LABELS

BASE_TOKENIZER = "distilbert-base-uncased"


def build_tokenizer(mode: str):
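    """Load the base tokenizer; in "marker" mode, register the entity marker special tokens."""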
    tokenizer = AutoTokenizer.from_pretrained(BASE_TOKENIZER)
    if mode == "marker":
        tokenizer.add_special_tokens(
            {"additional_special_tokens": [MARKER_MODE.entity_start, MARKER_MODE.entity_end]}
        )
    return tokenizer


def _softmax(logits: np.ndarray) -> np.ndarray:
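    """Numerically stable softmax over the last axis."""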
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)


def _tokenize_examples(
    examples: list[dict], tokenizer, max_len: int,
) -> dict[str, np.ndarray]:
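    """Tokenize examples as single segments or segment pairs, padded and
    truncated to max_len, returning int64 arrays for ONNX Runtime."""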
    input_ids, attention_masks = [], []
    for ex in examples:
        seg_a = ex["seg_a"]
        seg_b = ex["seg_b"]

        if seg_b is None:
            # Single-segment example.
            enc = tokenizer(
                seg_a,
                max_length=max_len,
                truncation=True,
                padding="max_length",
                return_tensors="np",
            )
        else:
            # Segment pair: "only_first" truncates seg_a only, keeping seg_b intact.
            enc = tokenizer(
                seg_a, seg_b,
                max_length=max_len,
                truncation="only_first",
                padding="max_length",
                return_tensors="np",
            )
        input_ids.append(enc["input_ids"][0])
        attention_masks.append(enc["attention_mask"][0])

    return {
        "input_ids": np.array(input_ids, dtype=np.int64),
        "attention_mask": np.array(attention_masks, dtype=np.int64),
    }


def _run_batched(
    session: ort.InferenceSession,
    inputs: dict[str, np.ndarray],
    batch_size: int,
) -> np.ndarray:
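    """Run the ONNX session over the inputs in mini-batches and concatenate the logits."""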
    n = inputs["input_ids"].shape[0]
    all_logits = []
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        batch = {k: v[start:end] for k, v in inputs.items()}
        logits = session.run(None, batch)[0]
        all_logits.append(logits)
    return np.concatenate(all_logits, axis=0)


def predict(
    samples: list[dict],
    session: ort.InferenceSession,
    tokenizer,
    mode: str,
    max_len: int = 256,
    batch_size: int = 32,
    deduplicate: bool = False,
) -> list[dict]:
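    """Predict a sentiment label for every entity in samples.

    Pipeline: augment -> optionally deduplicate positions -> flatten to
    per-entity examples -> tokenize -> batched ONNX inference -> aggregate
    the best-scoring prediction back onto each entity.
    """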
    window_words = MAXLEN_TO_WINDOW[max_len]
    augmented = augment(samples, window_words)
    if deduplicate:
        augmented = deduplicate_positions(augmented)
    examples = flatten_to_examples(augmented, mode=mode)

    if not examples:
        return [{"id": s["id"], "entities": []} for s in samples]

    inputs = _tokenize_examples(examples, tokenizer, max_len)
    logits = _run_batched(session, inputs, batch_size)
    sentiments = list(SENTIMENT_LABELS.classes)

    probs = _softmax(logits)

    # marker / qa_m classify the sentiment directly; qa_b scores a yes/no example
    # per (position, sentiment) and is decoded from triplets below.
    if mode in ("marker", "qa_m"):
        preds = np.argmax(probs, axis=-1)
        max_probs = probs.max(axis=-1)
        for ex, pred_id, conf in zip(examples, preds, max_probs):
            ex["predicted_label"] = sentiments[int(pred_id)]
            ex["confidence"] = float(conf)

    else:
        # qa_b: each entity position yields three consecutive yes/no examples,
        # one per sentiment label, so regroup the "yes" probabilities in threes.
        yes_probs = probs[:, 1]
        preds3, _ = reconstruct_triplets(yes_probs, np.zeros_like(yes_probs))

        triplet_idx = 0
        i = 0
        while i < len(examples) - 2:
            pred_label = sentiments[preds3[triplet_idx]]
            triplet_conf = float(yes_probs[i:i + 3].max())
            for j in range(3):
                examples[i + j]["predicted_label"] = pred_label
                examples[i + j]["confidence"] = triplet_conf
            triplet_idx += 1
            i += 3

    # Keep the highest-confidence prediction for each (sample, entity) pair.
    entity_preds: dict[tuple, tuple[str, float]] = {}
    for ex in examples:
        if "predicted_label" not in ex:
            continue  # defensive: skip examples left unlabeled by a partial triplet
        key = (ex["sample_id"], ex["entity_id"])
        conf = ex.get("confidence", 0.0)
        if key not in entity_preds or conf > entity_preds[key][1]:
            entity_preds[key] = (ex["predicted_label"], conf)

    # Entities with no prediction default to "neutral".
    results = []
    for s in samples:
        entities_out = []
        for e in s["entities"]:
            key = (s["id"], e["entity_id"])
            label, _ = entity_preds.get(key, ("neutral", 0.0))
            entities_out.append({
                "entity_id": e["entity_id"],
                "entity_text": e["entity_text"],
                "classification": label,
            })
        results.append({"id": s["id"], "entities": entities_out})

    return results


def main():
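    """Parse CLI arguments, run inference on the input JSON, and write the predictions.

    Example invocation (script name and paths are illustrative):
        python predict_onnx.py --onnx-path model.onnx --mode marker \
            --data input.json --output predictions.json
    """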
    parser = argparse.ArgumentParser(description="Run ONNX inference on raw input JSON")
    parser.add_argument("--onnx-path", required=True, help="Path to model.onnx")
    parser.add_argument("--mode", required=True, choices=("marker", "qa_m", "qa_b"))
    parser.add_argument("--data", required=True, help="Path to input JSON (assignment format)")
    parser.add_argument("--output", default=None, help="Output JSON path (default: stdout)")
    parser.add_argument("--max-len", type=int, default=256)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--deduplicate", action="store_true", help="Use one position per entity")
    args = parser.parse_args()

    tokenizer = build_tokenizer(args.mode)
    session = ort.InferenceSession(args.onnx_path)

    with open(args.data, "r", encoding="utf-8") as f:
        samples = json.load(f)

    results = predict(
        samples,
        session,
        tokenizer,
        args.mode,
        max_len=args.max_len,
        batch_size=args.batch_size,
        deduplicate=args.deduplicate,
    )

    output_json = json.dumps(results, ensure_ascii=False, indent=2)
    if args.output:
        Path(args.output).write_text(output_json, encoding="utf-8")
        print(f"Saved {len(results)} predictions to {args.output}")
    else:
        print(output_json)


if __name__ == "__main__":
    main()