# --- Cell: model definition (ByteHybrid — vendored from PleIAs/CommonLingua, Apache-2.0) ---
import torch
import torch.nn as nn
import torch.nn.functional as F


class ByteNgramEmbed(nn.Module):
    """Hashed byte n-gram embedding.

    Each sequence position receives an embedding for the n-gram starting at
    that position; n-grams are bucketed with a rolling base-257 hash modulo
    ``num_buckets`` (hash collisions are accepted by design).
    """

    def __init__(self, num_buckets=4096, embed_dim=64, n=3):
        super().__init__()
        self.n, self.num_buckets = n, num_buckets
        self.embed = nn.Embedding(num_buckets, embed_dim)

    def forward(self, byte_ids):
        # byte_ids: (B, T) ints in [0, 256]; 256 is the pad id and is clamped
        # down to 255 so the hash only ever sees real byte values.
        B, T = byte_ids.shape
        clamped = byte_ids.clamp(max=255)
        # Right-pad with byte 0 so the last positions still have a full
        # n-byte window to hash over.
        padded = F.pad(clamped, (0, self.n - 1), value=0)
        h = torch.zeros(B, T, dtype=torch.long, device=byte_ids.device)
        for i in range(self.n):
            # Rolling hash, base 257 (> any byte value) so distinct n-grams
            # map to distinct integers before the modulo bucketing below.
            h = h * 257 + padded[:, i:i + T]
        return self.embed(h % self.num_buckets)


class ByteConvBlock(nn.Module):
    """Pre-norm depthwise (grouped) causal Conv1d block + SiLU-gated FFN.

    Both sub-layers are residual; causality comes from left-only padding in
    ``forward`` (``pad = kernel_size - 1``), not from the conv itself.
    """

    def __init__(self, d_model, kernel_size=15, expand=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        # Amount of left padding applied in forward() to keep the conv causal.
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
        self.norm2 = nn.LayerNorm(d_model)
        ffn = d_model * expand
        self.ffn_gate = nn.Linear(d_model, ffn, bias=False)
        self.ffn_up = nn.Linear(d_model, ffn, bias=False)
        self.ffn_down = nn.Linear(ffn, d_model, bias=False)

    def forward(self, x):
        # x: (B, T, d_model)
        residual = x
        x = self.norm1(x).transpose(1, 2)  # Conv1d expects (B, C, T)
        x = F.pad(x, (self.pad, 0))        # pad left only -> causal conv
        x = F.silu(self.conv(x)).transpose(1, 2)
        x = residual + x
        residual = x
        x = self.norm2(x)
        # SwiGLU-style gated feed-forward.
        return residual + self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))


def _rope(q, k):
    """Apply rotary position embeddings (rotate-half form) to q and k.

    q, k: (..., seq_len, head_dim). Frequencies follow the standard
    10000^(-2i/d) schedule and are recomputed on every call; cos/sin are
    (seq_len, head_dim/2) and broadcast over the leading batch/head dims.
    """
    head_dim, seq_len = q.shape[-1], q.shape[-2]
    freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim))
    a = torch.outer(torch.arange(seq_len, device=q.device), freqs)
    cos, sin = a.cos().to(q.dtype), a.sin().to(q.dtype)
    def rot(x):
        # Split into halves and rotate: [x1, x2] -> [x1*cos - x2*sin, x2*cos + x1*sin].
        x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return rot(q), rot(k)


class ByteAttnBlock(nn.Module):
    """Pre-norm bidirectional multi-head self-attention block + gated FFN.

    Attention is non-causal (``is_causal=False``) and unmasked: padding
    positions attend/are attended to, which is tolerated because the model
    pools with an explicit pad mask at the end (see ByteHybrid.forward).
    """

    def __init__(self, d_model, n_heads=4, expand=2):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model)
        ffn = d_model * expand
        self.ffn_gate = nn.Linear(d_model, ffn, bias=False)
        self.ffn_up = nn.Linear(d_model, ffn, bias=False)
        self.ffn_down = nn.Linear(ffn, d_model, bias=False)

    def forward(self, x):
        # x: (B, T, d_model)
        B, T, D = x.shape
        residual = x
        h = self.norm1(x)
        # Fused QKV projection -> (B, T, 3, heads, head_dim), then move heads
        # before the sequence dim for scaled_dot_product_attention.
        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        q, k = _rope(q, k)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        x = residual + self.out_proj(out)
        residual = x
        h = self.norm2(x)
        return residual + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))


class ByteHybrid(nn.Module):
    """Byte-level classifier: byte + hashed-n-gram embeddings, a stack of
    causal conv blocks, then attention blocks, masked mean pooling, and an
    MLP head over ``num_classes`` labels.

    Vocabulary is 257: byte values 0-255 plus pad id 256 (zeroed via
    ``padding_idx`` and excluded from the pooling mask).
    """

    def __init__(self, num_classes, d_model=256, n_conv=3, n_attn=1, n_heads=4,
                 ffn_expand=2, max_len=512, conv_kernel=15, ngram_buckets=4096, ngram_dim=64):
        super().__init__()
        self.max_len = max_len
        self.embed = nn.Embedding(257, d_model, padding_idx=256)
        # ngram_buckets=0/None disables the n-gram branch entirely.
        self.ngram_embed = ByteNgramEmbed(ngram_buckets, ngram_dim, n=3) if ngram_buckets else None
        if self.ngram_embed is not None:
            self.ngram_proj = nn.Linear(ngram_dim, d_model, bias=False)
        self.conv_layers = nn.ModuleList([ByteConvBlock(d_model, conv_kernel, ffn_expand) for _ in range(n_conv)])
        self.attn_layers = nn.ModuleList([ByteAttnBlock(d_model, n_heads, ffn_expand) for _ in range(n_attn)])
        self.final_norm = nn.LayerNorm(d_model)
        self.head = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(0.1), nn.Linear(d_model, num_classes))

    def forward(self, byte_ids):
        # byte_ids: (B, T) ints in [0, 256]; returns (B, num_classes) logits.
        pad_mask = byte_ids != 256
        x = self.embed(byte_ids)
        if self.ngram_embed is not None:
            x = x + self.ngram_proj(self.ngram_embed(byte_ids))
        for layer in self.conv_layers:
            x = layer(x)
        for layer in self.attn_layers:
            x = layer(x)
        x = self.final_norm(x)
        # Mean-pool over real (non-pad) positions; clamp guards against an
        # all-pad row dividing by zero.
        mask = pad_mask.unsqueeze(-1).to(x.dtype)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return self.head(x)


# --- Cell: load checkpoint from the Hub ---
from huggingface_hub import hf_hub_download
import numpy as np

REPO = 'FrameByFrame/programming-language-identification-100plus-lite'
ckpt_path = hf_hub_download(REPO, 'model.pt')
# NOTE(review): weights_only=False unpickles arbitrary objects from the
# downloaded file (arbitrary code execution risk) — acceptable only because
# the repo is trusted; the ckpt dict holds plain keys besides the state dict.
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)

# Architecture hyperparameters must match the checkpoint exactly.
BASE_NGRAM = dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,
                  ngram_buckets=4096, ngram_dim=64)
model = ByteHybrid(num_classes=ckpt['num_classes'], max_len=ckpt['max_len'], **BASE_NGRAM).eval()
model.load_state_dict(ckpt['model_state_dict'])
# Invert the label mapping for decoding predictions.
idx2lang = {v: k for k, v in ckpt['lang2idx'].items()}
MAX_LEN = ckpt['max_len']
print(f'{ckpt["num_classes"]} labels | max_len={MAX_LEN} | params={sum(p.numel() for p in model.parameters())/1e6:.2f}M')
# --- Cell: helpers ---
def encode(texts, max_len=MAX_LEN):
    """Encode strings as a (len(texts), max_len) int64 tensor of raw UTF-8
    byte values, right-padded with id 256 (the model's pad token).

    Inputs longer than max_len are truncated; un-encodable characters are
    replaced via errors='replace'.
    """
    out = np.full((len(texts), max_len), 256, dtype=np.int64)
    for i, t in enumerate(texts):
        b = t.encode('utf-8', errors='replace')[:max_len]
        out[i, :len(b)] = np.frombuffer(b, dtype=np.uint8)
    return torch.from_numpy(out)


@torch.no_grad()
def predict(texts, top_k=3):
    """Classify each text; returns one list of (language, probability) pairs
    per input, sorted by descending probability, length top_k."""
    probs = torch.softmax(model(encode(texts)).float(), dim=-1)
    top_p, top_i = probs.topk(top_k, dim=-1)
    return [[(idx2lang[int(j)], float(p)) for p, j in zip(pr, ix)]
            for pr, ix in zip(top_p.tolist(), top_i.tolist())]


# --- Cell: predict ---
samples = [
    "def fib(n):\n return n if n < 2 else fib(n-1) + fib(n-2)",
    "fn main() {\n println!(\"hello, world\");\n}",
    "package main\nimport \"fmt\"\nfunc main() { fmt.Println(\"hi\") }",
    # Fixed: the <stdio.h> header had been stripped (HTML-style sanitization),
    # leaving an invalid C snippet.
    "#include <stdio.h>\nint main() { printf(\"hi\\n\"); return 0; }",
    "SELECT name FROM users WHERE id = 42;",
]
# Format assumes top_k >= 3 (indexes top[0..2] from predict's default).
for text, top in zip(samples, predict(samples)):
    print(f'{top[0][0]:<14s} {top[0][1]:.3f} ({top[1][0]} {top[1][1]:.2f}, {top[2][0]} {top[2][1]:.2f}) | {text[:60]!r}')