yujiepan commited on
Commit
ffc03fb
ยท
1 Parent(s): f04497e

Add real-time log updates and multi-line tensor preview

Browse files
Files changed (1) hide show
  1. app.py +29 -4
app.py CHANGED
@@ -6,6 +6,8 @@ import shutil
6
  import logging
7
  import io
8
  import sys
 
 
9
  from contextlib import redirect_stdout, redirect_stderr
10
 
11
  import gradio as gr
@@ -109,30 +111,53 @@ def format_tensor_info(tensor: torch.Tensor) -> str:
109
  def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
110
  """Fetch parameter and return formatted info and tensor preview."""
111
  log_buffer = io.StringIO()
 
112
 
113
  if not model_id or not param_key:
114
- return "Please provide both model ID and parameter key.", "", None, "โŒ Missing required inputs"
 
115
 
116
  try:
117
  log_buffer.write(f"๐Ÿš€ Starting download for {model_id}\n")
118
  log_buffer.write(f"๐ŸŽฏ Target parameter: {param_key}\n\n")
119
  progress(0, desc="Initializing...")
 
 
120
 
 
121
  tensor = get_param(model_id, param_key, log_buffer, progress)
 
 
 
 
 
 
 
122
  info = format_tensor_info(tensor)
123
 
124
  # Create tensor preview (first few elements)
125
  log_buffer.write(f"\n๐Ÿ“Š Creating preview...\n")
 
 
126
  flat = tensor.flatten()
127
  preview_size = min(100, flat.numel())
128
  preview = flat[:preview_size].tolist()
129
- preview_str = f"**First {preview_size} values:**\n```\n{preview}\n```"
 
 
 
 
 
 
 
130
 
131
  if flat.numel() > preview_size:
132
  preview_str += f"\n\n... and {flat.numel() - preview_size:,} more values"
133
 
134
  # Save tensor for download
135
  log_buffer.write(f"๐Ÿ’พ Saving tensor for download...\n")
 
 
136
  temp_dir = tempfile.gettempdir()
137
  safe_param_key = param_key.replace("/", "_").replace(".", "_")
138
  download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
@@ -141,10 +166,10 @@ def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
141
 
142
  progress(1.0, desc="Complete!")
143
  log_buffer.write(f"\nโœ… All operations completed successfully!\n")
144
- return info, preview_str, download_path, log_buffer.getvalue()
145
  except Exception as e:
146
  log_buffer.write(f"\nโŒ Error: {str(e)}\n")
147
- return f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
148
 
149
 
150
  def list_keys(model_id: str):
 
6
  import logging
7
  import io
8
  import sys
9
+ import threading
10
+ import time
11
  from contextlib import redirect_stdout, redirect_stderr
12
 
13
  import gradio as gr
 
111
  def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
112
  """Fetch parameter and return formatted info and tensor preview."""
113
  log_buffer = io.StringIO()
114
+ last_log_value = ""
115
 
116
  if not model_id or not param_key:
117
+ yield "Please provide both model ID and parameter key.", "", None, "โŒ Missing required inputs"
118
+ return
119
 
120
  try:
121
  log_buffer.write(f"๐Ÿš€ Starting download for {model_id}\n")
122
  log_buffer.write(f"๐ŸŽฏ Target parameter: {param_key}\n\n")
123
  progress(0, desc="Initializing...")
124
+ yield "", "", None, log_buffer.getvalue()
125
+ time.sleep(0.5)
126
 
127
+ # Start download in background and monitor logs
128
  tensor = get_param(model_id, param_key, log_buffer, progress)
129
+
130
+ # Yield log updates periodically during download
131
+ current_log = log_buffer.getvalue()
132
+ if current_log != last_log_value:
133
+ yield "", "", None, current_log
134
+ last_log_value = current_log
135
+
136
  info = format_tensor_info(tensor)
137
 
138
  # Create tensor preview (first few elements)
139
  log_buffer.write(f"\n๐Ÿ“Š Creating preview...\n")
140
+ yield "", "", None, log_buffer.getvalue()
141
+
142
  flat = tensor.flatten()
143
  preview_size = min(100, flat.numel())
144
  preview = flat[:preview_size].tolist()
145
+
146
+ # Format preview in multiple lines (10 values per line)
147
+ preview_lines = []
148
+ for i in range(0, len(preview), 10):
149
+ line_values = preview[i:i+10]
150
+ preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
151
+
152
+ preview_str = f"**First {preview_size} values:**\n```\n" + "\n".join(preview_lines) + "\n```"
153
 
154
  if flat.numel() > preview_size:
155
  preview_str += f"\n\n... and {flat.numel() - preview_size:,} more values"
156
 
157
  # Save tensor for download
158
  log_buffer.write(f"๐Ÿ’พ Saving tensor for download...\n")
159
+ yield info, preview_str, None, log_buffer.getvalue()
160
+
161
  temp_dir = tempfile.gettempdir()
162
  safe_param_key = param_key.replace("/", "_").replace(".", "_")
163
  download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
 
166
 
167
  progress(1.0, desc="Complete!")
168
  log_buffer.write(f"\nโœ… All operations completed successfully!\n")
169
+ yield info, preview_str, download_path, log_buffer.getvalue()
170
  except Exception as e:
171
  log_buffer.write(f"\nโŒ Error: {str(e)}\n")
172
+ yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
173
 
174
 
175
  def list_keys(model_id: str):