File size: 4,072 Bytes
58b8e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `programming-language-identification-100plus-lite` — ONNX Runtime\n",
    "\n",
    "Same model as the PyTorch demo, exported to ONNX (opset 20). No torch needed at inference time. CPU-friendly: ~57 texts/sec single-thread on commodity hardware (2.37× philomath-1209 on the same box).\n",
    "\n",
    "Run end-to-end in Colab or Jupyter."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "%%capture\n!pip install -q -U onnxruntime huggingface_hub numpy\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download ONNX model + label index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import json\nimport numpy as np\nimport onnxruntime as ort\nfrom huggingface_hub import hf_hub_download\n\nREPO = 'FrameByFrame/programming-language-identification-100plus-lite'\nonnx_path = hf_hub_download(REPO, 'onnx/model.onnx')\n# external weight blob lives next to the .onnx file\nhf_hub_download(REPO, 'onnx/model.onnx.data')\nlang2idx = json.loads(open(hf_hub_download(REPO, 'onnx/lang2idx.json')).read())\nmeta = json.loads(open(hf_hub_download(REPO, 'onnx/onnx_metadata.json')).read())\n\nsess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])\nidx2lang = {v: k for k, v in lang2idx.items()}\nMAX_LEN = meta['max_len']\nprint(f'{len(idx2lang)} labels | max_len={MAX_LEN} | providers={sess.get_providers()}')"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "def encode(texts, max_len=MAX_LEN):\n    out = np.full((len(texts), max_len), 256, dtype=np.int64)\n    for i, t in enumerate(texts):\n        b = t.encode('utf-8', errors='replace')[:max_len]\n        out[i, :len(b)] = np.frombuffer(b, dtype=np.uint8)\n    return out\n\n\ndef softmax(logits, axis=-1):\n    e = np.exp(logits - logits.max(axis=axis, keepdims=True))\n    return e / e.sum(axis=axis, keepdims=True)\n\n\ndef predict(texts, top_k=3):\n    logits = sess.run(None, {'byte_ids': encode(texts)})[0]\n    probs = softmax(logits)\n    top_i = np.argsort(-probs, axis=-1)[:, :top_k]\n    return [[(idx2lang[int(j)], float(probs[r, j])) for j in row]\n            for r, row in enumerate(top_i)]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "samples = [\n    \"def fib(n):\\n    return n if n < 2 else fib(n-1) + fib(n-2)\",\n    \"fn main() {\\n    println!(\\\"hello, world\\\");\\n}\",\n    \"package main\\nimport \\\"fmt\\\"\\nfunc main() { fmt.Println(\\\"hi\\\") }\",\n    \"#include <stdio.h>\\nint main() { printf(\\\"hi\\\\n\\\"); return 0; }\",\n    \"SELECT name FROM users WHERE id = 42;\",\n]\nfor text, top in zip(samples, predict(samples)):\n    print(f'{top[0][0]:<14s}  {top[0][1]:.3f}   ({top[1][0]} {top[1][1]:.2f}, {top[2][0]} {top[2][1]:.2f})  | {text[:60]!r}')"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Throughput sanity check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import time\nwarm = encode(samples * 13)[:64]\nfor _ in range(3):\n    sess.run(None, {'byte_ids': warm})\nt0 = time.time()\nfor _ in range(40):\n    sess.run(None, {'byte_ids': warm})\nelapsed = time.time() - t0\nprint(f'{40*64/elapsed:.0f} texts/sec  ({elapsed:.2f}s for 40 batches of 64)')"
  }
 ],
 "metadata": {
  "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
  "language_info": {"name": "python", "version": "3.11"}
 },
 "nbformat": 4,
 "nbformat_minor": 5
}