Wilsonwin committed on
Commit
d74f76b
·
1 Parent(s): 6d8f48f

Simplified demo with random params - checkpoint loading fix pending

Browse files
Files changed (2) hide show
  1. app.py +14 -74
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,21 +1,16 @@
1
  """
2
  HuggingFace Spaces Gradio App for Mini-GPT
3
- 上传到 HuggingFace Spaces 即可部署
4
  """
5
 
6
  import gradio as gr
7
  import jax
8
  import jax.numpy as jnp
9
  import flax.linen as nn
10
- from huggingface_hub import snapshot_download
11
- import orbax.checkpoint as ocp
12
  from typing import List
13
- import os
14
- import json
15
- import shutil
16
 
17
  # ============================================================================
18
- # 模型定义 (与训练时保持一致)
19
  # ============================================================================
20
 
21
  class TokenAndPositionEmbedding(nn.Module):
@@ -128,11 +123,9 @@ CONFIG = {
128
  "dropout_rate": 0.1,
129
  }
130
 
131
- REPO_ID = "Wilsonwin/handsongpt2"
132
-
133
 
134
  # ============================================================================
135
- # 加载模型 (带 TPU->CPU sharding 修补)
136
  # ============================================================================
137
 
138
  print("Loading tokenizer...")
@@ -140,70 +133,15 @@ tokenizer = MultilingualTokenizer()
140
  CONFIG["vocab_size"] = tokenizer.padded_vocab_size
141
 
142
  print("Creating model...")
143
- model = MiniGPT(
144
- vocab_size=CONFIG["vocab_size"],
145
- max_len=CONFIG["max_len"],
146
- embed_dim=CONFIG["embed_dim"],
147
- num_heads=CONFIG["num_heads"],
148
- num_layers=CONFIG["num_layers"],
149
- ff_dim=CONFIG["ff_dim"],
150
- dropout_rate=CONFIG["dropout_rate"]
151
- )
152
-
153
- print("Downloading checkpoint from HuggingFace...")
154
- checkpoint_dir = snapshot_download(
155
- repo_id=REPO_ID,
156
- repo_type="model",
157
- allow_patterns=["checkpoint/*"]
158
- )
159
- checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
160
- print(f"Downloaded to: {checkpoint_path}")
161
-
162
- print("Patching sharding metadata (TPU -> CPU)...")
163
- # 创建临时副本来修改 sharding 信息
164
- patched_checkpoint_path = "/tmp/mini_gpt_checkpoint_patched"
165
- if os.path.exists(patched_checkpoint_path):
166
- shutil.rmtree(patched_checkpoint_path)
167
- shutil.copytree(checkpoint_path, patched_checkpoint_path, dirs_exist_ok=True)
168
-
169
- # 修改 _sharding 文件
170
- sharding_path = os.path.join(patched_checkpoint_path, "_sharding")
171
- if os.path.exists(sharding_path):
172
- with open(sharding_path, 'r') as f:
173
- sharding_data = json.load(f)
174
-
175
- cpu_device = jax.devices('cpu')[0]
176
- cpu_device_str = str(cpu_device)
177
-
178
- new_sharding = {}
179
- for key, value in sharding_data.items():
180
- value_dict = json.loads(value)
181
- value_dict['device_str'] = cpu_device_str
182
- new_sharding[key] = json.dumps(value_dict)
183
-
184
- with open(sharding_path, 'w') as f:
185
- json.dump(new_sharding, f)
186
-
187
- print(f"✓ Patched {len(new_sharding)} sharding entries to use CPU")
188
 
189
- print("Loading checkpoint...")
190
- checkpointer = ocp.PyTreeCheckpointer()
 
 
191
 
192
- try:
193
- state = checkpointer.restore(patched_checkpoint_path)
194
- if 'params' in state:
195
- params = state['params']
196
- print("✓ Model loaded successfully!")
197
- else:
198
- raise ValueError("params not found in checkpoint")
199
- except Exception as e:
200
- print(f"Error: {e}")
201
- print("Using randomly initialized parameters")
202
- rng = jax.random.PRNGKey(0)
203
- dummy_input = jnp.ones((1, CONFIG["max_len"]), dtype=jnp.int32)
204
- params = model.init(rng, dummy_input, training=False)['params']
205
-
206
- print(f"Params: {sum(x.size for x in jax.tree.leaves(params)):,} parameters")
207
 
208
 
209
  # ============================================================================
@@ -240,9 +178,11 @@ def gradio_generate(prompt, max_tokens, temperature):
240
 
241
  with gr.Blocks(title="Mini-GPT 文本生成", theme=gr.themes.Soft()) as demo:
242
  gr.Markdown("""
243
- # 🤖 Mini-GPT 文本生成
 
 
244
 
245
- 使用 JAX/Flax 在 Kaggle TPU 上训练的GPT 模型。支持英文输入
246
  """)
247
 
248
  with gr.Row():
 
1
  """
2
  HuggingFace Spaces Gradio App for Mini-GPT
3
+ 使用随机初始化参数的 Demo 版本
4
  """
5
 
6
  import gradio as gr
7
  import jax
8
  import jax.numpy as jnp
9
  import flax.linen as nn
 
 
10
  from typing import List
 
 
 
11
 
12
  # ============================================================================
13
+ # 模型定义
14
  # ============================================================================
15
 
16
  class TokenAndPositionEmbedding(nn.Module):
 
123
  "dropout_rate": 0.1,
124
  }
125
 
 
 
126
 
127
  # ============================================================================
128
+ # 加载模型 (随机初始化)
129
  # ============================================================================
130
 
131
  print("Loading tokenizer...")
 
133
  CONFIG["vocab_size"] = tokenizer.padded_vocab_size
134
 
135
  print("Creating model...")
136
+ model = MiniGPT(**CONFIG)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ print("Initializing random parameters...")
139
+ rng = jax.random.PRNGKey(42)
140
+ dummy_input = jnp.ones((1, CONFIG["max_len"]), dtype=jnp.int32)
141
+ params = model.init(rng, dummy_input, training=False)['params']
142
 
143
+ print(f"✓ Model ready with {sum(x.size for x in jax.tree.leaves(params)):,} parameters")
144
+ print("⚠️ Note: Using random parameters (trained weights pending checkpoint fix)")
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
 
147
  # ============================================================================
 
178
 
179
  with gr.Blocks(title="Mini-GPT 文本生成", theme=gr.themes.Soft()) as demo:
180
  gr.Markdown("""
181
+ # 🤖 Mini-GPT 文本生成 (Demo)
182
+
183
+ 使用 JAX/Flax 构建的小型 GPT 模型。
184
 
185
+ ⚠️ **当前使用随机初始化参数** - 训练 checkpoint 正在修复中。
186
  """)
187
 
188
  with gr.Row():
requirements.txt CHANGED
@@ -2,7 +2,6 @@ gradio==4.44.0
2
  jax==0.4.35
3
  jaxlib==0.4.35
4
  flax==0.10.2
5
- orbax-checkpoint==0.10.2
6
  transformers==4.47.0
7
  huggingface_hub>=0.23.0
8
  numpy==1.26.4
 
2
  jax==0.4.35
3
  jaxlib==0.4.35
4
  flax==0.10.2
 
5
  transformers==4.47.0
6
  huggingface_hub>=0.23.0
7
  numpy==1.26.4