yujiepan committed on
Commit
5e22042
·
1 Parent(s): 6dc6bb2

Fix FP8 dtype support and real-time stderr logging

Browse files
Files changed (1) hide show
  1. app.py +63 -52
app.py CHANGED
@@ -6,7 +6,7 @@ import shutil
6
  import io
7
  import time
8
  import threading
9
- from contextlib import redirect_stderr
10
 
11
  import gradio as gr
12
  import torch
@@ -18,61 +18,55 @@ def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress:
18
  """
19
  Download and return a specific parameter tensor from a Hugging Face model.
20
  """
21
- # Try to download the index file (for sharded models)
 
 
 
22
  try:
23
- log_buffer.write(f"📥 Downloading index file for {model_id}...\n")
24
- progress(0.1, desc="Downloading index...")
 
 
25
 
26
- # Capture tqdm output from stderr
27
- stderr_capture = io.StringIO()
28
- with redirect_stderr(stderr_capture):
29
  index_path = hf_hub_download(
30
  model_id, "model.safetensors.index.json")
31
 
32
- stderr_output = stderr_capture.getvalue()
33
- if stderr_output:
34
- log_buffer.write(stderr_output + "\n")
35
-
36
- log_buffer.write(f"✓ Index file found: {index_path}\n")
37
-
38
- with open(index_path, "r", encoding="utf-8") as f:
39
- index = json.load(f)
40
- weight_map = index["weight_map"]
41
- if param_key not in weight_map:
42
- raise KeyError(
43
- f"Parameter '{param_key}' not found in model. Available keys: {list(weight_map.keys())[:10]}..."
44
- )
45
- shard_file = weight_map[param_key]
46
- log_buffer.write(f"✓ Parameter found in shard: {shard_file}\n")
47
- except Exception as e:
48
- if "404" in str(e) or "not found" in str(e).lower():
49
- log_buffer.write("ℹ️ No index file, trying single model file...\n")
50
- shard_file = "model.safetensors"
51
- else:
52
- raise
53
 
54
- log_buffer.write(f"📥 Downloading shard: {shard_file}...\n")
55
- progress(0.3, desc=f"Downloading {shard_file}...")
56
 
57
- # Capture download progress
58
- stderr_capture = io.StringIO()
59
- with redirect_stderr(stderr_capture):
60
  shard_path = hf_hub_download(model_id, shard_file)
61
 
62
- stderr_output = stderr_capture.getvalue()
63
- if stderr_output:
64
- log_buffer.write(stderr_output + "\n")
65
 
66
- log_buffer.write(f"✓ Shard downloaded: {shard_path}\n")
67
- progress(0.7, desc="Loading tensor...")
 
 
 
68
 
69
- log_buffer.write(f"🔍 Loading tensor '{param_key}'...\n")
70
- with safe_open(shard_path, framework="pt") as f:
71
- tensor = f.get_tensor(param_key)
72
- log_buffer.write(f"✓ Tensor loaded successfully\n")
73
- progress(0.9, desc="Finalizing...")
74
-
75
- return tensor
76
 
77
 
78
  def get_available_keys(model_id: str):
@@ -99,10 +93,22 @@ def format_tensor_info(tensor: torch.Tensor) -> str:
99
  info.append(f"**Dtype:** {tensor.dtype}")
100
  info.append(f"**Device:** {tensor.device}")
101
  info.append(f"**Numel:** {tensor.numel():,}")
102
- info.append(f"**Min:** {tensor.min().item():.6f}")
103
- info.append(f"**Max:** {tensor.max().item():.6f}")
104
- info.append(f"**Mean:** {tensor.float().mean().item():.6f}")
105
- info.append(f"**Std:** {tensor.float().std().item():.6f}")
 
 
 
 
 
 
 
 
 
 
 
 
106
  return "<br>".join(info)
107
 
108
 
@@ -165,14 +171,19 @@ def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
165
 
166
  flat = tensor.flatten()
167
  preview_size = min(100, flat.numel())
168
- preview = flat[:preview_size].tolist()
 
 
 
 
 
169
 
170
  # Format preview in multiple lines (10 values per line)
171
  # Adapt to different data types
172
  preview_lines = []
173
  for i in range(0, len(preview), 10):
174
  line_values = preview[i:i+10]
175
- if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
176
  preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
177
  elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
178
  preview_lines.append(", ".join(f"{v}" for v in line_values))
@@ -358,7 +369,7 @@ with gr.Blocks(title="Hugging Face Model Weight Inspector") as demo:
358
  preview_output = gr.Markdown(label="Tensor Preview")
359
  download_output = gr.File(label="Download Tensor (.pt file)")
360
  log_output = gr.Textbox(
361
- label="📋 Download Log", lines=6, interactive=False)
362
 
363
  with gr.Tab("Cache Management"):
364
  with gr.Row():
 
6
  import io
7
  import time
8
  import threading
9
+ import sys
10
 
11
  import gradio as gr
12
  import torch
 
18
  """
19
  Download and return a specific parameter tensor from a Hugging Face model.
20
  """
21
+ # Redirect stderr to log buffer for real-time tqdm updates
22
+ original_stderr = sys.stderr
23
+ sys.stderr = log_buffer
24
+
25
  try:
26
+ # Try to download the index file (for sharded models)
27
+ try:
28
+ log_buffer.write(f"📥 Downloading index file for {model_id}...\n")
29
+ progress(0.1, desc="Downloading index...")
30
 
 
 
 
31
  index_path = hf_hub_download(
32
  model_id, "model.safetensors.index.json")
33
 
34
+ log_buffer.write(f"✓ Index file found: {index_path}\n")
35
+
36
+ with open(index_path, "r", encoding="utf-8") as f:
37
+ index = json.load(f)
38
+ weight_map = index["weight_map"]
39
+ if param_key not in weight_map:
40
+ raise KeyError(
41
+ f"Parameter '{param_key}' not found in model. Available keys: {list(weight_map.keys())[:10]}..."
42
+ )
43
+ shard_file = weight_map[param_key]
44
+ log_buffer.write(f"✓ Parameter found in shard: {shard_file}\n")
45
+ except Exception as e:
46
+ if "404" in str(e) or "not found" in str(e).lower():
47
+ log_buffer.write("ℹ️ No index file, trying single model file...\n")
48
+ shard_file = "model.safetensors"
49
+ else:
50
+ raise
 
 
 
 
51
 
52
+ log_buffer.write(f"📥 Downloading shard: {shard_file}...\n")
53
+ progress(0.3, desc=f"Downloading {shard_file}...")
54
 
 
 
 
55
  shard_path = hf_hub_download(model_id, shard_file)
56
 
57
+ log_buffer.write(f"\n✓ Shard downloaded: {shard_path}\n")
58
+ progress(0.7, desc="Loading tensor...")
 
59
 
60
+ log_buffer.write(f"🔍 Loading tensor '{param_key}'...\n")
61
+ with safe_open(shard_path, framework="pt") as f:
62
+ tensor = f.get_tensor(param_key)
63
+ log_buffer.write(f"✓ Tensor loaded successfully\n")
64
+ progress(0.9, desc="Finalizing...")
65
 
66
+ return tensor
67
+ finally:
68
+ # Restore original stderr
69
+ sys.stderr = original_stderr
 
 
 
70
 
71
 
72
  def get_available_keys(model_id: str):
 
93
  info.append(f"**Dtype:** {tensor.dtype}")
94
  info.append(f"**Device:** {tensor.device}")
95
  info.append(f"**Numel:** {tensor.numel():,}")
96
+
97
+ # Handle special dtypes that don't support statistical operations
98
+ try:
99
+ # Convert FP8 and other special dtypes to float32 for stats
100
+ if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
101
+ stats_tensor = tensor.to(torch.float32)
102
+ else:
103
+ stats_tensor = tensor
104
+
105
+ info.append(f"**Min:** {stats_tensor.min().item():.6f}")
106
+ info.append(f"**Max:** {stats_tensor.max().item():.6f}")
107
+ info.append(f"**Mean:** {stats_tensor.float().mean().item():.6f}")
108
+ info.append(f"**Std:** {stats_tensor.float().std().item():.6f}")
109
+ except Exception as e:
110
+ info.append(f"**Stats:** Unable to compute (dtype not supported)")
111
+
112
  return "<br>".join(info)
113
 
114
 
 
171
 
172
  flat = tensor.flatten()
173
  preview_size = min(100, flat.numel())
174
+
175
+ # Convert to float32 for FP8 types for display
176
+ if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
177
+ preview = flat[:preview_size].to(torch.float32).tolist()
178
+ else:
179
+ preview = flat[:preview_size].tolist()
180
 
181
  # Format preview in multiple lines (10 values per line)
182
  # Adapt to different data types
183
  preview_lines = []
184
  for i in range(0, len(preview), 10):
185
  line_values = preview[i:i+10]
186
+ if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16] or str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
187
  preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
188
  elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
189
  preview_lines.append(", ".join(f"{v}" for v in line_values))
 
369
  preview_output = gr.Markdown(label="Tensor Preview")
370
  download_output = gr.File(label="Download Tensor (.pt file)")
371
  log_output = gr.Textbox(
372
+ label="📋 Download Log", lines=1, interactive=False)
373
 
374
  with gr.Tab("Cache Management"):
375
  with gr.Row():