Spaces:
Sleeping
Sleeping
Simplified demo with random params - checkpoint loading fix pending
Browse files
- app.py +14 -74
- requirements.txt +0 -1
app.py
CHANGED
|
@@ -1,21 +1,16 @@
|
|
| 1 |
"""
|
| 2 |
HuggingFace Spaces Gradio App for Mini-GPT
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import jax
|
| 8 |
import jax.numpy as jnp
|
| 9 |
import flax.linen as nn
|
| 10 |
-
from huggingface_hub import snapshot_download
|
| 11 |
-
import orbax.checkpoint as ocp
|
| 12 |
from typing import List
|
| 13 |
-
import os
|
| 14 |
-
import json
|
| 15 |
-
import shutil
|
| 16 |
|
| 17 |
# ============================================================================
|
| 18 |
-
# 模型定义
|
| 19 |
# ============================================================================
|
| 20 |
|
| 21 |
class TokenAndPositionEmbedding(nn.Module):
|
|
@@ -128,11 +123,9 @@ CONFIG = {
|
|
| 128 |
"dropout_rate": 0.1,
|
| 129 |
}
|
| 130 |
|
| 131 |
-
REPO_ID = "Wilsonwin/handsongpt2"
|
| 132 |
-
|
| 133 |
|
| 134 |
# ============================================================================
|
| 135 |
-
# 加载模型 (
|
| 136 |
# ============================================================================
|
| 137 |
|
| 138 |
print("Loading tokenizer...")
|
|
@@ -140,70 +133,15 @@ tokenizer = MultilingualTokenizer()
|
|
| 140 |
CONFIG["vocab_size"] = tokenizer.padded_vocab_size
|
| 141 |
|
| 142 |
print("Creating model...")
|
| 143 |
-
model = MiniGPT(
|
| 144 |
-
vocab_size=CONFIG["vocab_size"],
|
| 145 |
-
max_len=CONFIG["max_len"],
|
| 146 |
-
embed_dim=CONFIG["embed_dim"],
|
| 147 |
-
num_heads=CONFIG["num_heads"],
|
| 148 |
-
num_layers=CONFIG["num_layers"],
|
| 149 |
-
ff_dim=CONFIG["ff_dim"],
|
| 150 |
-
dropout_rate=CONFIG["dropout_rate"]
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
print("Downloading checkpoint from HuggingFace...")
|
| 154 |
-
checkpoint_dir = snapshot_download(
|
| 155 |
-
repo_id=REPO_ID,
|
| 156 |
-
repo_type="model",
|
| 157 |
-
allow_patterns=["checkpoint/*"]
|
| 158 |
-
)
|
| 159 |
-
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
| 160 |
-
print(f"Downloaded to: {checkpoint_path}")
|
| 161 |
-
|
| 162 |
-
print("Patching sharding metadata (TPU -> CPU)...")
|
| 163 |
-
# 创建临时副本来修改 sharding 信息
|
| 164 |
-
patched_checkpoint_path = "/tmp/mini_gpt_checkpoint_patched"
|
| 165 |
-
if os.path.exists(patched_checkpoint_path):
|
| 166 |
-
shutil.rmtree(patched_checkpoint_path)
|
| 167 |
-
shutil.copytree(checkpoint_path, patched_checkpoint_path, dirs_exist_ok=True)
|
| 168 |
-
|
| 169 |
-
# 修改 _sharding 文件
|
| 170 |
-
sharding_path = os.path.join(patched_checkpoint_path, "_sharding")
|
| 171 |
-
if os.path.exists(sharding_path):
|
| 172 |
-
with open(sharding_path, 'r') as f:
|
| 173 |
-
sharding_data = json.load(f)
|
| 174 |
-
|
| 175 |
-
cpu_device = jax.devices('cpu')[0]
|
| 176 |
-
cpu_device_str = str(cpu_device)
|
| 177 |
-
|
| 178 |
-
new_sharding = {}
|
| 179 |
-
for key, value in sharding_data.items():
|
| 180 |
-
value_dict = json.loads(value)
|
| 181 |
-
value_dict['device_str'] = cpu_device_str
|
| 182 |
-
new_sharding[key] = json.dumps(value_dict)
|
| 183 |
-
|
| 184 |
-
with open(sharding_path, 'w') as f:
|
| 185 |
-
json.dump(new_sharding, f)
|
| 186 |
-
|
| 187 |
-
print(f"✓ Patched {len(new_sharding)} sharding entries to use CPU")
|
| 188 |
|
| 189 |
-
print("
|
| 190 |
-
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
if 'params' in state:
|
| 195 |
-
params = state['params']
|
| 196 |
-
print("✓ Model loaded successfully!")
|
| 197 |
-
else:
|
| 198 |
-
raise ValueError("params not found in checkpoint")
|
| 199 |
-
except Exception as e:
|
| 200 |
-
print(f"Error: {e}")
|
| 201 |
-
print("Using randomly initialized parameters")
|
| 202 |
-
rng = jax.random.PRNGKey(0)
|
| 203 |
-
dummy_input = jnp.ones((1, CONFIG["max_len"]), dtype=jnp.int32)
|
| 204 |
-
params = model.init(rng, dummy_input, training=False)['params']
|
| 205 |
-
|
| 206 |
-
print(f"Params: {sum(x.size for x in jax.tree.leaves(params)):,} parameters")
|
| 207 |
|
| 208 |
|
| 209 |
# ============================================================================
|
|
@@ -240,9 +178,11 @@ def gradio_generate(prompt, max_tokens, temperature):
|
|
| 240 |
|
| 241 |
with gr.Blocks(title="Mini-GPT 文本生成", theme=gr.themes.Soft()) as demo:
|
| 242 |
gr.Markdown("""
|
| 243 |
-
# 🤖 Mini-GPT 文本生成
|
|
|
|
|
|
|
| 244 |
|
| 245 |
-
使用
|
| 246 |
""")
|
| 247 |
|
| 248 |
with gr.Row():
|
|
|
|
| 1 |
"""
|
| 2 |
HuggingFace Spaces Gradio App for Mini-GPT
|
| 3 |
+
使用随机初始化参数的 Demo 版本
|
| 4 |
"""
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import jax
|
| 8 |
import jax.numpy as jnp
|
| 9 |
import flax.linen as nn
|
|
|
|
|
|
|
| 10 |
from typing import List
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# ============================================================================
|
| 13 |
+
# 模型定义
|
| 14 |
# ============================================================================
|
| 15 |
|
| 16 |
class TokenAndPositionEmbedding(nn.Module):
|
|
|
|
| 123 |
"dropout_rate": 0.1,
|
| 124 |
}
|
| 125 |
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# ============================================================================
|
| 128 |
+
# 加载模型 (随机初始化)
|
| 129 |
# ============================================================================
|
| 130 |
|
| 131 |
print("Loading tokenizer...")
|
|
|
|
| 133 |
CONFIG["vocab_size"] = tokenizer.padded_vocab_size
|
| 134 |
|
| 135 |
print("Creating model...")
|
| 136 |
+
model = MiniGPT(**CONFIG)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
print("Initializing random parameters...")
|
| 139 |
+
rng = jax.random.PRNGKey(42)
|
| 140 |
+
dummy_input = jnp.ones((1, CONFIG["max_len"]), dtype=jnp.int32)
|
| 141 |
+
params = model.init(rng, dummy_input, training=False)['params']
|
| 142 |
|
| 143 |
+
print(f"✓ Model ready with {sum(x.size for x in jax.tree.leaves(params)):,} parameters")
|
| 144 |
+
print("⚠️ Note: Using random parameters (trained weights pending checkpoint fix)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
|
| 147 |
# ============================================================================
|
|
|
|
| 178 |
|
| 179 |
with gr.Blocks(title="Mini-GPT 文本生成", theme=gr.themes.Soft()) as demo:
|
| 180 |
gr.Markdown("""
|
| 181 |
+
# 🤖 Mini-GPT 文本生成 (Demo)
|
| 182 |
+
|
| 183 |
+
使用 JAX/Flax 构建的小型 GPT 模型。
|
| 184 |
|
| 185 |
+
⚠️ **当前使用随机初始化参数** - 训练好的模型 checkpoint 正在修复中。
|
| 186 |
""")
|
| 187 |
|
| 188 |
with gr.Row():
|
requirements.txt
CHANGED
|
@@ -2,7 +2,6 @@ gradio==4.44.0
|
|
| 2 |
jax==0.4.35
|
| 3 |
jaxlib==0.4.35
|
| 4 |
flax==0.10.2
|
| 5 |
-
orbax-checkpoint==0.10.2
|
| 6 |
transformers==4.47.0
|
| 7 |
huggingface_hub>=0.23.0
|
| 8 |
numpy==1.26.4
|
|
|
|
| 2 |
jax==0.4.35
|
| 3 |
jaxlib==0.4.35
|
| 4 |
flax==0.10.2
|
|
|
|
| 5 |
transformers==4.47.0
|
| 6 |
huggingface_hub>=0.23.0
|
| 7 |
numpy==1.26.4
|