File size: 4,072 Bytes
58b8e27 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | {
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `programming-language-identification-100plus-lite` — ONNX Runtime\n",
"\n",
"Same model as the PyTorch demo, exported to ONNX (opset 20). No torch needed at inference time. CPU-friendly: ~57 texts/sec single-thread on commodity hardware (2.37× philomath-1209 on the same box).\n",
"\n",
"Run end-to-end in Colab or Jupyter."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "%%capture\n!pip install -q -U onnxruntime huggingface_hub numpy\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download ONNX model + label index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import json\nimport numpy as np\nimport onnxruntime as ort\nfrom huggingface_hub import hf_hub_download\n\nREPO = 'FrameByFrame/programming-language-identification-100plus-lite'\nonnx_path = hf_hub_download(REPO, 'onnx/model.onnx')\n# external weight blob lives next to the .onnx file\nhf_hub_download(REPO, 'onnx/model.onnx.data')\nlang2idx = json.loads(open(hf_hub_download(REPO, 'onnx/lang2idx.json')).read())\nmeta = json.loads(open(hf_hub_download(REPO, 'onnx/onnx_metadata.json')).read())\n\nsess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])\nidx2lang = {v: k for k, v in lang2idx.items()}\nMAX_LEN = meta['max_len']\nprint(f'{len(idx2lang)} labels | max_len={MAX_LEN} | providers={sess.get_providers()}')"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "def encode(texts, max_len=MAX_LEN):\n out = np.full((len(texts), max_len), 256, dtype=np.int64)\n for i, t in enumerate(texts):\n b = t.encode('utf-8', errors='replace')[:max_len]\n out[i, :len(b)] = np.frombuffer(b, dtype=np.uint8)\n return out\n\n\ndef softmax(logits, axis=-1):\n e = np.exp(logits - logits.max(axis=axis, keepdims=True))\n return e / e.sum(axis=axis, keepdims=True)\n\n\ndef predict(texts, top_k=3):\n logits = sess.run(None, {'byte_ids': encode(texts)})[0]\n probs = softmax(logits)\n top_i = np.argsort(-probs, axis=-1)[:, :top_k]\n return [[(idx2lang[int(j)], float(probs[r, j])) for j in row]\n for r, row in enumerate(top_i)]"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "samples = [\n \"def fib(n):\\n return n if n < 2 else fib(n-1) + fib(n-2)\",\n \"fn main() {\\n println!(\\\"hello, world\\\");\\n}\",\n \"package main\\nimport \\\"fmt\\\"\\nfunc main() { fmt.Println(\\\"hi\\\") }\",\n \"#include <stdio.h>\\nint main() { printf(\\\"hi\\\\n\\\"); return 0; }\",\n \"SELECT name FROM users WHERE id = 42;\",\n]\nfor text, top in zip(samples, predict(samples)):\n print(f'{top[0][0]:<14s} {top[0][1]:.3f} ({top[1][0]} {top[1][1]:.2f}, {top[2][0]} {top[2][1]:.2f}) | {text[:60]!r}')"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Throughput sanity check"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import time\nwarm = encode(samples * 13)[:64]\nfor _ in range(3):\n sess.run(None, {'byte_ids': warm})\nt0 = time.time()\nfor _ in range(40):\n sess.run(None, {'byte_ids': warm})\nelapsed = time.time() - t0\nprint(f'{40*64/elapsed:.0f} texts/sec ({elapsed:.2f}s for 40 batches of 64)')"
}
],
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python", "version": "3.11"}
},
"nbformat": 4,
"nbformat_minor": 5
}
|