MoEprometheus committed on
Commit
7bf134f
·
verified ·
1 Parent(s): 6923a28

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import flax.linen as nn
5
+ import pickle
6
+ import numpy as np
7
+ import gradio as gr
8
+ from huggingface_hub import hf_hub_download
9
+
10
# Hub repository that hosts the pickled checkpoint.
HF_REPO = "MoEprometheus/Prometheus-base"

print("📥 Загружаем Prometheus...")
path = hf_hub_download(HF_REPO, "expert1.pkl")
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only point HF_REPO at a repository you control or fully trust; consider
# migrating the checkpoint to a non-executable format.
with open(path, "rb") as f:
    ckpt = pickle.load(f)

# The checkpoint is read as a dict with "vocab" (id -> string mapping),
# "config", "step" and, further down the file, "params".
itos = ckpt["vocab"]
stoi = {v: k for k, v in itos.items()}
CONFIG = ckpt["config"]


def encode(s):
    """Map each character of ``s`` to its token id; unknown characters map to 0."""
    return [stoi.get(c, 0) for c in s]


def decode(l):
    """Join the vocabulary strings for the ids in ``l``; unknown ids are skipped."""
    return "".join(itos.get(i, "") for i in l)


print(f"✅ Загружено — шаг {ckpt['step']}")
23
+
24
class PrometheusAttention(nn.Module):
    """Causal multi-head self-attention with a fused Q/K/V projection.

    Takes ``x`` of shape (batch, time, channels) and returns a tensor whose
    last dimension is ``n_embed``. ``block_size`` is declared as a field but
    is not read inside this module.
    """
    n_heads: int
    n_embed: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        batch, seq_len, channels = x.shape
        head_dim = self.n_embed // self.n_heads

        def split_heads(t):
            # (batch, time, n_embed) -> (batch, heads, time, head_dim)
            t = t.reshape(batch, seq_len, self.n_heads, head_dim)
            return t.transpose(0, 2, 1, 3)

        # One bias-free projection produces queries, keys and values at once.
        fused = nn.Dense(3 * self.n_embed, use_bias=False)(x)
        q, k, v = (split_heads(part) for part in jnp.split(fused, 3, axis=-1))

        # Scaled dot-product scores; positions above the diagonal are pushed
        # to a large negative value so softmax gives them ~zero weight.
        scores = (q @ k.transpose(0, 1, 3, 2)) * (head_dim ** -0.5)
        causal = jnp.tril(jnp.ones((seq_len, seq_len)))
        scores = jnp.where(causal == 0, -1e9, scores)
        weights = jax.nn.softmax(scores, axis=-1)
        weights = nn.Dropout(self.dropout)(weights, deterministic=deterministic)

        # Merge the heads back into a single channel dimension.
        merged = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, channels)
        projected = nn.Dense(self.n_embed)(merged)
        return nn.Dropout(self.dropout)(projected, deterministic=deterministic)
46
+
47
class PrometheusMLP(nn.Module):
    """Feed-forward sub-layer: expand to 4 * n_embed, GELU, project back, dropout."""
    n_embed: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        hidden = nn.Dense(4 * self.n_embed)(x)
        activated = nn.gelu(hidden)
        projected = nn.Dense(self.n_embed)(activated)
        drop = nn.Dropout(self.dropout)
        return drop(projected, deterministic=deterministic)
56
+
57
class PrometheusBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual add."""
    n_embed: int
    n_heads: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, x, deterministic=True):
        # Attention sub-layer: normalise, attend, add back onto the input.
        attention = PrometheusAttention(
            self.n_heads, self.n_embed, self.block_size, self.dropout
        )
        x = x + attention(nn.LayerNorm()(x), deterministic)

        # Feed-forward sub-layer with its own LayerNorm and residual.
        feed_forward = PrometheusMLP(self.n_embed, self.dropout)
        x = x + feed_forward(nn.LayerNorm()(x), deterministic)
        return x
72
+
73
class Prometheus(nn.Module):
    """Decoder-only language model: embeddings, ``n_layers`` blocks, LM head.

    ``__call__`` takes integer token ids of shape (batch, time) and returns
    logits whose last dimension is ``vocab_size``. The ``training`` argument
    is accepted but not consulted: every dropout layer and every block is
    run in deterministic mode.
    """
    vocab_size: int
    n_embed: int
    n_heads: int
    n_layers: int
    block_size: int
    dropout: float

    @nn.compact
    def __call__(self, idx, training=False):
        _, seq_len = idx.shape

        # Token embeddings plus learned position embeddings (0 .. seq_len - 1).
        token_emb = nn.Embed(self.vocab_size, self.n_embed)(idx)
        positions = jnp.arange(seq_len)
        position_emb = nn.Embed(self.block_size, self.n_embed)(positions)
        hidden = nn.Dropout(self.dropout)(token_emb + position_emb, deterministic=True)

        # Blocks are wrapped in nn.remat; argument index 2 (the deterministic
        # flag) is marked static for the wrapped call.
        remat_block = nn.remat(PrometheusBlock, static_argnums=(2,))
        for _layer in range(self.n_layers):
            layer = remat_block(
                self.n_embed, self.n_heads, self.block_size, self.dropout
            )
            hidden = layer(hidden, True)

        normed = nn.LayerNorm()(hidden)
        return nn.Dense(self.vocab_size)(normed)
93
+
94
# Build the network from the hyperparameters stored with the checkpoint.
model = Prometheus(
    **{
        field: CONFIG[field]
        for field in (
            "vocab_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "block_size",
            "dropout",
        )
    }
)

# Weights as saved in the checkpoint; passed directly to model.apply.
params = ckpt["params"]
104
+
105
def generate(prompt, max_new_tokens=80, temperature=1.1, top_k=40):
    """Extend ``prompt`` token by token using top-k sampling.

    Args:
        prompt: Text to continue; encoded one token per character by ``encode``.
        max_new_tokens: Number of tokens to append.
        temperature: Divisor applied to the logits before sampling. Values
            ``<= 0`` pick the highest-scoring token instead of sampling.
        top_k: Number of highest-scoring candidates to sample from; clamped
            to the range ``1 .. vocab size``.

    Returns:
        The decoded text of the prompt (cropped to its last
        ``block_size - 1`` tokens) followed by the generated tokens. An empty
        prompt returns the decoded empty sequence.
    """
    block_size = CONFIG["block_size"]
    tokens = encode(prompt)
    tokens = tokens[-(block_size - 1):]
    if not tokens:
        # A zero-length sequence has no last position to read logits from.
        return decode(tokens)

    # Seed once and split per step instead of re-seeding from a small integer
    # range on every iteration.
    key = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))
    for _ in range(max_new_tokens):
        # The model's positional embedding table has block_size rows, so
        # never feed it more than the most recent block_size tokens.
        context = tokens[-block_size:]
        logits = model.apply(params, jnp.array([context]), training=False)
        logits = logits[0, -1, :]
        if temperature <= 0:
            next_token = int(jnp.argmax(logits))
        else:
            k = max(1, min(top_k, logits.shape[-1]))
            top_logits, top_indices = jax.lax.top_k(logits / temperature, k)
            key, subkey = jax.random.split(key)
            # categorical() takes unnormalised logits directly.
            chosen = int(jax.random.categorical(subkey, top_logits))
            next_token = int(top_indices[chosen])
        tokens.append(next_token)
    return decode(tokens)
122
+
123
def chat(message, history):
    """Gradio chat callback: return the model's continuation of ``message``.

    Args:
        message: The user's latest message, used as the generation prompt.
        history: Earlier turns passed in by the chat UI; not used.

    Returns:
        The generated text with the echoed prompt removed and surrounding
        whitespace stripped.
    """
    result = generate(message, max_new_tokens=80, temperature=1.1)
    # Remove the prompt from the answer. generate() keeps only the last
    # block_size - 1 prompt tokens, so the echoed prompt can be shorter than
    # `message`; measure what was actually kept rather than len(message).
    # NOTE(review): assumes every token decodes to exactly one character.
    prompt_len = len(encode(message)[-(CONFIG["block_size"] - 1):])
    if len(result) > prompt_len:
        answer = result[prompt_len:]
    else:
        answer = result
    return answer.strip()
131
+
132
# Starter prompts shown in the chat UI.
_EXAMPLE_PROMPTS = [
    "Москва —",
    "Россия — это",
    "Нейронная сеть — это",
    "Python — язык",
]

# Chat front-end wired to chat(); launched as soon as the module runs.
demo = gr.ChatInterface(
    fn=chat,
    title="🔥 Prometheus AI",
    description="Языковая модель 1.2B параметров. Создана с нуля одним человеком.",
    examples=_EXAMPLE_PROMPTS,
    theme="soft",
)

demo.launch()