Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import jax
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
import flax.linen as nn
|
| 5 |
+
import pickle
|
| 6 |
+
import numpy as np
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
|
| 10 |
+
HF_REPO = "MoEprometheus/Prometheus-base"

print("📥 Загружаем Prometheus...")
path = hf_hub_download(HF_REPO, "expert1.pkl")
# SECURITY NOTE: pickle.load can execute arbitrary code contained in the file.
# The checkpoint comes from a remote repository, so this is only safe as long
# as HF_REPO is fully trusted. Consider a non-executable format for weights.
with open(path, "rb") as f:
    ckpt = pickle.load(f)

# Keys read from the checkpoint in this file: "vocab", "config", "step",
# and (further below) "params".
itos = ckpt["vocab"]  # token id -> string
stoi = {v: k for k, v in itos.items()}  # string -> token id (inverse of itos)
CONFIG = ckpt["config"]


def encode(s):
    """Map each character of ``s`` to its token id.

    Characters missing from the vocabulary are mapped to id 0.
    """
    return [stoi.get(c, 0) for c in s]


def decode(l):
    """Join the strings of the token ids in ``l``.

    Ids missing from the vocabulary contribute an empty string.
    """
    return "".join([itos.get(i, "") for i in l])


print(f"✅ Загружено — шаг {ckpt['step']}")
|
| 23 |
+
|
| 24 |
+
class PrometheusAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    The input ``x`` is unpacked as ``(B, T, C)``; the output has the same
    shape. ``n_embed`` must be divisible by ``n_heads`` and equal to ``C``
    for the reshapes below to succeed.
    """

    n_heads: int
    n_embed: int
    block_size: int  # declared for the constructor signature; not read here
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        batch, seq_len, channels = x.shape
        head_dim = self.n_embed // self.n_heads

        def to_heads(t):
            # (B, T, n_embed) -> (B, n_heads, T, head_dim)
            return t.reshape(
                batch, seq_len, self.n_heads, head_dim
            ).transpose(0, 2, 1, 3)

        # One bias-free projection produces queries, keys and values at once.
        fused = nn.Dense(3 * self.n_embed, use_bias=False)(x)
        q, k, v = (to_heads(part) for part in jnp.split(fused, 3, axis=-1))

        # Scaled dot-product scores, shape (B, n_heads, T, T).
        scores = (q @ k.transpose(0, 1, 3, 2)) * (head_dim ** -0.5)

        # Positions above the diagonal (the future) get a large negative
        # score so softmax assigns them ~0 weight.
        causal = jnp.tril(jnp.ones((seq_len, seq_len)))
        scores = jnp.where(causal == 0, -1e9, scores)

        weights = jax.nn.softmax(scores, axis=-1)
        weights = nn.Dropout(self.dropout)(weights, deterministic=deterministic)

        # Merge the heads back into (B, T, C) and apply the output projection.
        merged = (weights @ v).transpose(0, 2, 1, 3).reshape(
            batch, seq_len, channels
        )
        projected = nn.Dense(self.n_embed)(merged)
        return nn.Dropout(self.dropout)(projected, deterministic=deterministic)
|
| 46 |
+
|
| 47 |
+
class PrometheusMLP(nn.Module):
    """Position-wise feed-forward network: Dense(4*n_embed) -> GELU -> Dense(n_embed) -> Dropout."""

    n_embed: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        # Expand to four times the embedding width, apply GELU, project back.
        hidden = nn.gelu(nn.Dense(4 * self.n_embed)(x))
        projected = nn.Dense(self.n_embed)(hidden)
        drop = nn.Dropout(self.dropout)
        return drop(projected, deterministic=deterministic)
|
| 56 |
+
|
| 57 |
+
class PrometheusBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    n_embed: int
    n_heads: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        # Attention sub-layer: LayerNorm is applied to the input of the
        # sub-layer, and the result is added back onto x.
        attention = PrometheusAttention(
            self.n_heads, self.n_embed, self.block_size, self.dropout
        )
        x = x + attention(nn.LayerNorm()(x), deterministic)

        # Feed-forward sub-layer, same normalise-then-add pattern.
        feed_forward = PrometheusMLP(self.n_embed, self.dropout)
        x = x + feed_forward(nn.LayerNorm()(x), deterministic)
        return x
|
| 72 |
+
|
| 73 |
+
class Prometheus(nn.Module):
    """Decoder-only transformer language model.

    Calling the module with ``idx`` (unpacked as ``(B, T)`` integer token
    ids) returns logits with a trailing dimension of ``vocab_size``.

    Positions are looked up in an embedding table with ``block_size`` rows
    using ``jnp.arange(T)``, so callers should keep ``T <= block_size``.

    Args to ``__call__``:
        idx: token ids, shape ``(B, T)``.
        training: when ``False`` (default) every dropout layer is
            deterministic, which is the inference behaviour. When ``True``
            dropout is active and a ``"dropout"`` RNG must be supplied to
            ``apply`` (standard Flax requirement for non-deterministic
            dropout).
    """

    vocab_size: int
    n_embed: int
    n_heads: int
    n_layers: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, idx, training=False):
        _, T = idx.shape
        # Previously `training` was accepted but ignored (dropout was
        # hard-wired to deterministic). Honour the flag; the default path
        # (training=False) is unchanged.
        deterministic = not training

        tok = nn.Embed(self.vocab_size, self.n_embed)(idx)
        pos = nn.Embed(self.block_size, self.n_embed)(jnp.arange(T))
        x = nn.Dropout(self.dropout)(tok + pos, deterministic=deterministic)

        # Rematerialised block. Argument index 2 of __call__ (counting self
        # as 0) is `deterministic`; it is marked static, so it must be passed
        # positionally as a plain Python bool.
        BlockRemat = nn.remat(PrometheusBlock, static_argnums=(2,))
        for _ in range(self.n_layers):
            x = BlockRemat(
                self.n_embed, self.n_heads,
                self.block_size, self.dropout
            )(x, deterministic)

        # Final normalisation followed by the projection to vocabulary logits.
        return nn.Dense(self.vocab_size)(nn.LayerNorm()(x))
|
| 93 |
+
|
| 94 |
+
# Every model hyperparameter is taken from the checkpoint's config under the
# key of the same name as the Prometheus field.
model = Prometheus(
    **{
        field: CONFIG[field]
        for field in (
            "vocab_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "block_size",
            "dropout",
        )
    }
)

# Variables passed as the first argument of model.apply() during generation.
params = ckpt["params"]
|
| 104 |
+
|
| 105 |
+
def generate(prompt, max_new_tokens=80, temperature=1.1, top_k=40):
    """Continue ``prompt`` by sampling tokens from the model one at a time.

    Args:
        prompt: text to continue; encoded with ``encode``.
        max_new_tokens: number of tokens to sample.
        temperature: divisor applied to the logits before sampling.
        top_k: sample only among this many highest-scoring tokens
            (capped at the number of logits).

    Returns:
        The decoded text of the (possibly truncated) prompt tokens followed
        by all sampled tokens. An empty prompt yields an empty string.
    """
    block_size = CONFIG["block_size"]

    tokens = encode(prompt)
    # Keep only the tail of a long prompt; this is also the part of the
    # prompt that is echoed in the returned text.
    tokens = tokens[-(block_size - 1):]
    if not tokens:
        # Nothing to condition on: a zero-length sequence has no last
        # position to read logits from.
        return decode(tokens)

    # One seed per call, then split per step, instead of re-seeding a new
    # key from a small integer range on every step.
    key = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))

    for _ in range(max_new_tokens):
        # The position-embedding table has block_size rows, so the model
        # input must never be longer than that even as `tokens` grows.
        context = tokens[-block_size:]
        x = jnp.array([context])
        logits = model.apply(params, x, training=False)
        logits = logits[0, -1, :] / temperature

        # jax.lax.top_k requires k to be no larger than the axis size.
        k = min(top_k, logits.shape[-1])
        top_k_logits, top_k_indices = jax.lax.top_k(logits, k)

        # categorical() takes unnormalised logits, so no softmax/log
        # round-trip is needed; the sampled distribution is the same.
        key, subkey = jax.random.split(key)
        chosen = int(jax.random.categorical(subkey, top_k_logits))
        tokens.append(int(top_k_indices[chosen]))

    return decode(tokens)
|
| 122 |
+
|
| 123 |
+
def chat(message, history):
    """Gradio chat callback: return the model's continuation of ``message``.

    Args:
        message: the user's latest message, used as the generation prompt.
        history: previous turns supplied by Gradio; not used.

    Returns:
        The generated text with the echoed prompt removed and surrounding
        whitespace stripped.
    """
    result = generate(message, max_new_tokens=80, temperature=1.1)

    # Remove the prompt from the answer. generate() echoes the prompt only
    # after encoding, truncating to the last block_size - 1 tokens and
    # decoding, so the echoed text can be shorter than (or differ from)
    # `message`. Rebuild that exact prefix instead of slicing by len(message).
    echoed = decode(encode(message)[-(CONFIG["block_size"] - 1):])
    if result.startswith(echoed):
        answer = result[len(echoed):]
    elif len(result) > len(message):
        # Fallback: the original length-based slicing.
        answer = result[len(message):]
    else:
        answer = result
    return answer.strip()
|
| 131 |
+
|
| 132 |
+
# Chat UI: each user message is passed to chat(); the examples are offered
# as ready-made prompts.
demo = gr.ChatInterface(
    theme="soft",
    title="🔥 Prometheus AI",
    description="Языковая модель 1.2B параметров. Создана с нуля одним человеком.",
    examples=["Москва —", "Россия — это", "Нейронная сеть — это", "Python — язык"],
    fn=chat,
)

# Start the web server as soon as the module is executed.
demo.launch()
|