File size: 8,499 Bytes
58b8e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `programming-language-identification-100plus-lite` — PyTorch\n",
    "\n",
    "2.35M-param byte-level classifier across 107 programming languages. No tokenizer; raw UTF-8 bytes padded to 1023.\n",
    "\n",
    "Self-contained: this notebook inlines the model definition (vendored from PleIAs/CommonLingua, Apache-2.0) and downloads the checkpoint from the Hub. Run end-to-end in Colab or Jupyter."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "%%capture\n!pip install -q -U torch huggingface_hub\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model definition (ByteHybrid — vendored from PleIAs/CommonLingua, Apache-2.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ByteNgramEmbed(nn.Module):\n    def __init__(self, num_buckets=4096, embed_dim=64, n=3):\n        super().__init__()\n        self.n, self.num_buckets = n, num_buckets\n        self.embed = nn.Embedding(num_buckets, embed_dim)\n\n    def forward(self, byte_ids):\n        B, T = byte_ids.shape\n        clamped = byte_ids.clamp(max=255)\n        padded = F.pad(clamped, (0, self.n - 1), value=0)\n        h = torch.zeros(B, T, dtype=torch.long, device=byte_ids.device)\n        for i in range(self.n):\n            h = h * 257 + padded[:, i:i + T]\n        return self.embed(h % self.num_buckets)\n\n\nclass ByteConvBlock(nn.Module):\n    def __init__(self, d_model, kernel_size=15, expand=2):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(d_model)\n        self.pad = kernel_size - 1\n        self.conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n        ffn = d_model * expand\n        self.ffn_gate = nn.Linear(d_model, ffn, bias=False)\n        self.ffn_up = nn.Linear(d_model, ffn, bias=False)\n        self.ffn_down = nn.Linear(ffn, d_model, bias=False)\n\n    def forward(self, x):\n        residual = x\n        x = self.norm1(x).transpose(1, 2)\n        x = F.pad(x, (self.pad, 0))\n        x = F.silu(self.conv(x)).transpose(1, 2)\n        x = residual + x\n        residual = x\n        x = self.norm2(x)\n        return residual + self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x))\n\n\ndef _rope(q, k):\n    head_dim, seq_len = q.shape[-1], q.shape[-2]\n    freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim))\n    a = torch.outer(torch.arange(seq_len, device=q.device), freqs)\n    cos, sin = a.cos().to(q.dtype), a.sin().to(q.dtype)\n    def rot(x):\n        x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]\n        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)\n    return rot(q), rot(k)\n\n\nclass ByteAttnBlock(nn.Module):\n    def __init__(self, d_model, n_heads=4, expand=2):\n        super().__init__()\n        self.n_heads, self.head_dim = n_heads, d_model // n_heads\n        self.norm1 = nn.LayerNorm(d_model)\n        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)\n        self.out_proj = nn.Linear(d_model, d_model, bias=False)\n        self.norm2 = nn.LayerNorm(d_model)\n        ffn = d_model * expand\n        self.ffn_gate = nn.Linear(d_model, ffn, bias=False)\n        self.ffn_up = nn.Linear(d_model, ffn, bias=False)\n        self.ffn_down = nn.Linear(ffn, d_model, bias=False)\n\n    def forward(self, x):\n        B, T, D = x.shape\n        residual = x\n        h = self.norm1(x)\n        qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim)\n        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))\n        q, k = _rope(q, k)\n        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)\n        out = out.transpose(1, 2).contiguous().view(B, T, D)\n        x = residual + self.out_proj(out)\n        residual = x\n        h = self.norm2(x)\n        return residual + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))\n\n\nclass ByteHybrid(nn.Module):\n    def __init__(self, num_classes, d_model=256, n_conv=3, n_attn=1, n_heads=4,\n                 ffn_expand=2, max_len=512, conv_kernel=15, ngram_buckets=4096, ngram_dim=64):\n        super().__init__()\n        self.max_len = max_len\n        self.embed = nn.Embedding(257, d_model, padding_idx=256)\n        self.ngram_embed = ByteNgramEmbed(ngram_buckets, ngram_dim, n=3) if ngram_buckets else None\n        if self.ngram_embed is not None:\n            self.ngram_proj = nn.Linear(ngram_dim, d_model, bias=False)\n        self.conv_layers = nn.ModuleList([ByteConvBlock(d_model, conv_kernel, ffn_expand) for _ in range(n_conv)])\n        self.attn_layers = nn.ModuleList([ByteAttnBlock(d_model, n_heads, ffn_expand) for _ in range(n_attn)])\n        self.final_norm = nn.LayerNorm(d_model)\n        self.head = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(0.1), nn.Linear(d_model, num_classes))\n\n    def forward(self, byte_ids):\n        pad_mask = byte_ids != 256\n        x = self.embed(byte_ids)\n        if self.ngram_embed is not None:\n            x = x + self.ngram_proj(self.ngram_embed(byte_ids))\n        for layer in self.conv_layers:\n            x = layer(x)\n        for layer in self.attn_layers:\n            x = layer(x)\n        x = self.final_norm(x)\n        mask = pad_mask.unsqueeze(-1).to(x.dtype)\n        x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)\n        return self.head(x)\n"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load checkpoint from the Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from huggingface_hub import hf_hub_download\nimport numpy as np\n\nREPO = 'FrameByFrame/programming-language-identification-100plus-lite'\nckpt_path = hf_hub_download(REPO, 'model.pt')\nckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)\n\nBASE_NGRAM = dict(d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15,\n                  ngram_buckets=4096, ngram_dim=64)\nmodel = ByteHybrid(num_classes=ckpt['num_classes'], max_len=ckpt['max_len'], **BASE_NGRAM).eval()\nmodel.load_state_dict(ckpt['model_state_dict'])\nidx2lang = {v: k for k, v in ckpt['lang2idx'].items()}\nMAX_LEN = ckpt['max_len']\nprint(f'{ckpt[\"num_classes\"]} labels | max_len={MAX_LEN} | params={sum(p.numel() for p in model.parameters())/1e6:.2f}M')"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "def encode(texts, max_len=MAX_LEN):\n    out = np.full((len(texts), max_len), 256, dtype=np.int64)\n    for i, t in enumerate(texts):\n        b = t.encode('utf-8', errors='replace')[:max_len]\n        out[i, :len(b)] = np.frombuffer(b, dtype=np.uint8)\n    return torch.from_numpy(out)\n\n\n@torch.no_grad()\ndef predict(texts, top_k=3):\n    probs = torch.softmax(model(encode(texts)).float(), dim=-1)\n    top_p, top_i = probs.topk(top_k, dim=-1)\n    return [[(idx2lang[int(j)], float(p)) for p, j in zip(pr, ix)]\n            for pr, ix in zip(top_p.tolist(), top_i.tolist())]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "samples = [\n    \"def fib(n):\\n    return n if n < 2 else fib(n-1) + fib(n-2)\",\n    \"fn main() {\\n    println!(\\\"hello, world\\\");\\n}\",\n    \"package main\\nimport \\\"fmt\\\"\\nfunc main() { fmt.Println(\\\"hi\\\") }\",\n    \"#include <stdio.h>\\nint main() { printf(\\\"hi\\\\n\\\"); return 0; }\",\n    \"SELECT name FROM users WHERE id = 42;\",\n]\nfor text, top in zip(samples, predict(samples)):\n    print(f'{top[0][0]:<14s}  {top[0][1]:.3f}   ({top[1][0]} {top[1][1]:.2f}, {top[2][0]} {top[2][1]:.2f})  | {text[:60]!r}')"
  }
 ],
 "metadata": {
  "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
  "language_info": {"name": "python", "version": "3.11"}
 },
 "nbformat": 4,
 "nbformat_minor": 5
}