Wilsonwin committed
Commit 999f17d · 1 parent: 1b7e8af

Add Mini-GPT Gradio app

Files changed (3)
  1. README.md +20 -9
  2. app.py +261 -53
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,16 +1,27 @@
 ---
-title: Mini Gpt Demo
-emoji: 💬
-colorFrom: yellow
+title: Mini-GPT Text Generation
+emoji: 🤖
+colorFrom: blue
 colorTo: purple
 sdk: gradio
-sdk_version: 5.42.0
+sdk_version: 4.44.0
 app_file: app.py
 pinned: false
-hf_oauth: true
-hf_oauth_scopes:
-- inference-api
-license: mit
+license: apache-2.0
 ---

-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+# Mini-GPT Text Generation
+
+A small GPT model trained on Kaggle TPUs with JAX/Flax.
+
+## Features
+
+- Supports text generation in Chinese and English
+- Adjustable generation length and temperature
+
+## Model Info
+
+- **Architecture**: GPT-2-style transformer
+- **Parameters**: ~25M
+- **Training framework**: JAX/Flax
+- **Training hardware**: Kaggle TPU v3-8
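
The Model Info above can be cross-checked against the configuration committed in `app.py` below. A minimal sketch, assuming the `MiniGPT` class and `CONFIG` dict from `app.py` are importable (note that importing `app.py` as committed also runs its module-level tokenizer and checkpoint loading, so it needs network access):

```python
# Hypothetical shape and parameter-count check; not part of the commit.
import jax
import jax.numpy as jnp

from app import MiniGPT, CONFIG  # CONFIG["vocab_size"] is filled in at import time

model = MiniGPT(
    vocab_size=CONFIG["vocab_size"],
    max_len=CONFIG["max_len"],
    embed_dim=CONFIG["embed_dim"],
    num_heads=CONFIG["num_heads"],
    num_layers=CONFIG["num_layers"],
    ff_dim=CONFIG["ff_dim"],
)

dummy = jnp.ones((1, 16), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), dummy)["params"]
print(model.apply({"params": params}, dummy).shape)  # (1, 16, vocab_size)

n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"{n_params / 1e6:.1f}M parameters")  # compare with the ~25M figure above
```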
app.py CHANGED
@@ -1,70 +1,278 @@
+"""
+HuggingFace Spaces Gradio app for Mini-GPT.
+Deploy by uploading this file to HuggingFace Spaces.
+"""
+
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
-    """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-    """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-    messages = [{"role": "system", "content": system_message}]
-
-    messages.extend(history)
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+from huggingface_hub import hf_hub_download
+import orbax.checkpoint as ocp
+from typing import List, Optional, Union
+import os
+
+# ============================================================================
+# Model definition (must match training)
+# ============================================================================
+
+class TokenAndPositionEmbedding(nn.Module):
+    vocab_size: int
+    max_len: int
+    embed_dim: int
+
+    @nn.compact
+    def __call__(self, x):
+        seq_len = x.shape[1]
+        positions = jnp.arange(seq_len)
+        tok_emb = nn.Embed(self.vocab_size, self.embed_dim, name='token_emb')(x)
+        pos_emb = nn.Embed(self.max_len, self.embed_dim, name='pos_emb')(positions)
+        return tok_emb + pos_emb
+
+
+class TransformerBlock(nn.Module):
+    embed_dim: int
+    num_heads: int
+    ff_dim: int
+    dropout_rate: float = 0.1
+
+    @nn.compact
+    def __call__(self, x, training: bool = False):
+        attn_output = nn.SelfAttention(
+            num_heads=self.num_heads,
+            qkv_features=self.embed_dim,
+            dropout_rate=self.dropout_rate,
+            deterministic=True,  # no dropout at inference time
+            decode=False,
+        )(x, mask=nn.make_causal_mask(jnp.ones((x.shape[0], x.shape[1]))))
+
+        x = nn.LayerNorm()(x + attn_output)
+
+        ffn_output = nn.Dense(self.ff_dim)(x)
+        ffn_output = nn.gelu(ffn_output)
+        ffn_output = nn.Dense(self.embed_dim)(ffn_output)
+
+        x = nn.LayerNorm()(x + ffn_output)
+        return x
+
+
+class MiniGPT(nn.Module):
+    vocab_size: int
+    max_len: int
+    embed_dim: int
+    num_heads: int
+    num_layers: int
+    ff_dim: int
+    dropout_rate: float = 0.1
+
+    @nn.compact
+    def __call__(self, x, training: bool = False):
+        x = TokenAndPositionEmbedding(
+            vocab_size=self.vocab_size,
+            max_len=self.max_len,
+            embed_dim=self.embed_dim
+        )(x)
+
+        for i in range(self.num_layers):
+            x = TransformerBlock(
+                embed_dim=self.embed_dim,
+                num_heads=self.num_heads,
+                ff_dim=self.ff_dim,
+                dropout_rate=self.dropout_rate,
+                name=f'transformer_block_{i}'
+            )(x, training=training)
+
+        logits = nn.Dense(self.vocab_size, name='lm_head')(x)
+        return logits
+
+
+# ============================================================================
+# Tokenizer (Yi-1.5)
+# ============================================================================
+
+class MultilingualTokenizer:
+    def __init__(self, model_name: str = "01-ai/Yi-1.5-6B"):
+        from transformers import AutoTokenizer
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            use_fast=True
+        )
+
+        self._eot_token = self._tokenizer.eos_token_id
+        self._pad_token = self._tokenizer.pad_token_id if self._tokenizer.pad_token_id is not None else 0
+
+        raw_vocab = len(self._tokenizer)
+        self._padded_vocab = ((raw_vocab // 128) + 1) * 128 if raw_vocab % 128 != 0 else raw_vocab
+
+    @property
+    def padded_vocab_size(self) -> int:
+        return self._padded_vocab
+
+    @property
+    def eot_token(self) -> int:
+        return self._eot_token
+
+    def encode(self, text: str) -> List[int]:
+        return self._tokenizer.encode(text, add_special_tokens=False)
+
+    def decode(self, tokens) -> str:
+        if isinstance(tokens, int):
+            tokens = [tokens]
+        return self._tokenizer.decode(tokens, skip_special_tokens=True)
+
+
+# ============================================================================
+# Model configuration (must match training!)
+# ============================================================================
+
+CONFIG = {
+    "max_len": 256,
+    "embed_dim": 512,
+    "num_heads": 8,
+    "num_layers": 6,
+    "ff_dim": 2048,
+    "dropout_rate": 0.1,
+}
+
+REPO_ID = "Wilsonwin/handsongpt2"  # your HuggingFace repo
+
+
+# ============================================================================
+# Load the model
+# ============================================================================
+
+print("Loading tokenizer...")
+tokenizer = MultilingualTokenizer()
+CONFIG["vocab_size"] = tokenizer.padded_vocab_size
+
+print("Creating model...")
+model = MiniGPT(
+    vocab_size=CONFIG["vocab_size"],
+    max_len=CONFIG["max_len"],
+    embed_dim=CONFIG["embed_dim"],
+    num_heads=CONFIG["num_heads"],
+    num_layers=CONFIG["num_layers"],
+    ff_dim=CONFIG["ff_dim"],
+    dropout_rate=CONFIG["dropout_rate"]
+)
+
+print("Downloading checkpoint from HuggingFace...")
+checkpoint_path = hf_hub_download(
+    repo_id=REPO_ID,
+    filename="checkpoint",
+    repo_type="model",
+    local_dir="./checkpoint_dir"
 )

-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()
-
-
+print(f"Loading checkpoint from {checkpoint_path}...")
+checkpointer = ocp.PyTreeCheckpointer()
+state = checkpointer.restore(checkpoint_path)
+params = state['params']
+
+print("✓ Model loaded successfully!")
+
+
+# ============================================================================
+# Text generation
+# ============================================================================
+
+def generate_text(prompt: str, max_new_tokens: int = 50, temperature: float = 1.0) -> str:
+    """Generate a continuation of the prompt."""
+    input_ids = jnp.array([tokenizer.encode(prompt)], dtype=jnp.int32)
+
+    for _ in range(max_new_tokens):
+        if input_ids.shape[1] >= CONFIG["max_len"]:
+            input_ids = input_ids[:, -CONFIG["max_len"]:]
+
+        logits = model.apply({'params': params}, input_ids, training=False)
+        next_token_logits = logits[0, -1, :] / max(temperature, 0.1)
+
+        # greedy decoding
+        next_token = jnp.argmax(next_token_logits)
+
+        input_ids = jnp.concatenate([input_ids, next_token[None, None]], axis=1)
+
+        if next_token == tokenizer.eot_token:
+            break
+
+    return tokenizer.decode(input_ids[0].tolist())
+
+
+# ============================================================================
+# Gradio UI
+# ============================================================================
+
+def gradio_generate(prompt, max_tokens, temperature):
+    """Gradio callback."""
+    if not prompt.strip():
+        return "Please enter a prompt..."
+
+    result = generate_text(prompt, int(max_tokens), float(temperature))
+    return result
+
+
+# Build the UI
+with gr.Blocks(title="Mini-GPT Text Generation", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("""
+    # 🤖 Mini-GPT Text Generation
+
+    A small GPT model trained on Kaggle TPUs with JAX/Flax.
+
+    Supports Chinese and English input.
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                placeholder="e.g. 从前有一个...",
+                lines=3
+            )
+
+            with gr.Row():
+                max_tokens = gr.Slider(
+                    minimum=10,
+                    maximum=100,
+                    value=50,
+                    step=10,
+                    label="Max generation length"
+                )
+                temperature = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=1.0,
+                    step=0.1,
+                    label="Temperature (higher = more random)"
+                )
+
+            generate_btn = gr.Button("🚀 Generate", variant="primary")
+
+        with gr.Column(scale=2):
+            output = gr.Textbox(
+                label="Output",
+                lines=8,
+                interactive=False
+            )
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["这是", 50, 1.0],
+            ["Hello", 50, 1.0],
+            ["从前有一个", 80, 0.8],
+            ["The quick brown", 50, 1.0],
+        ],
+        inputs=[prompt_input, max_tokens, temperature],
+    )
+
+    generate_btn.click(
+        fn=gradio_generate,
+        inputs=[prompt_input, max_tokens, temperature],
+        outputs=output
+    )

+# Launch
 if __name__ == "__main__":
     demo.launch()
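
Two reviewer notes on `generate_text` as committed: the logits are divided by `temperature`, but `jnp.argmax` then selects the same token regardless (dividing by a positive scalar does not change the argmax), so the UI's temperature slider has no effect on the output; and because `input_ids` grows by one token per step, wrapping `model.apply` in `jax.jit` naively would retrace at every step, for which right-padding to a fixed `max_len` is one common fix. A minimal sketch of real temperature sampling, offered as a hypothetical drop-in for the greedy line rather than part of this commit:

```python
# Hypothetical replacement for the argmax step in generate_text.
import jax
import jax.numpy as jnp

def sample_next_token(key, next_token_logits, temperature):
    # Low temperature approaches greedy argmax; high temperature flattens the distribution.
    scaled = next_token_logits / jnp.maximum(temperature, 1e-3)
    return jax.random.categorical(key, scaled)

# Inside the generation loop, thread a PRNG key instead of calling jnp.argmax:
#   key, subkey = jax.random.split(key)
#   next_token = sample_next_token(subkey, next_token_logits, temperature)
```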
requirements.txt ADDED
@@ -0,0 +1,6 @@
+gradio>=4.0.0
+jax[cpu]
+flax
+orbax-checkpoint
+transformers
+huggingface_hub
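
For a quick local check before pushing to Spaces (a hypothetical smoke test, not part of the commit; importing `app` triggers the module-level download of the `Wilsonwin/handsongpt2` checkpoint and the Yi-1.5 tokenizer, so it needs network access):

```python
# smoke_test.py (hypothetical): exercises the committed generate_text directly.
from app import generate_text

print(generate_text("Hello", max_new_tokens=20, temperature=1.0))
print(generate_text("从前有一个", max_new_tokens=40, temperature=0.8))
```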