Spaces:
Runtime error
Runtime error
Add real-time log updates and multi-line tensor preview
Browse files
app.py
CHANGED
|
@@ -6,6 +6,8 @@ import shutil
|
|
| 6 |
import logging
|
| 7 |
import io
|
| 8 |
import sys
|
|
|
|
|
|
|
| 9 |
from contextlib import redirect_stdout, redirect_stderr
|
| 10 |
|
| 11 |
import gradio as gr
|
|
@@ -109,30 +111,53 @@ def format_tensor_info(tensor: torch.Tensor) -> str:
|
|
| 109 |
def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
|
| 110 |
"""Fetch parameter and return formatted info and tensor preview."""
|
| 111 |
log_buffer = io.StringIO()
|
|
|
|
| 112 |
|
| 113 |
if not model_id or not param_key:
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
try:
|
| 117 |
log_buffer.write(f"๐ Starting download for {model_id}\n")
|
| 118 |
log_buffer.write(f"๐ฏ Target parameter: {param_key}\n\n")
|
| 119 |
progress(0, desc="Initializing...")
|
|
|
|
|
|
|
| 120 |
|
|
|
|
| 121 |
tensor = get_param(model_id, param_key, log_buffer, progress)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
info = format_tensor_info(tensor)
|
| 123 |
|
| 124 |
# Create tensor preview (first few elements)
|
| 125 |
log_buffer.write(f"\n๐ Creating preview...\n")
|
|
|
|
|
|
|
| 126 |
flat = tensor.flatten()
|
| 127 |
preview_size = min(100, flat.numel())
|
| 128 |
preview = flat[:preview_size].tolist()
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
if flat.numel() > preview_size:
|
| 132 |
preview_str += f"\n\n... and {flat.numel() - preview_size:,} more values"
|
| 133 |
|
| 134 |
# Save tensor for download
|
| 135 |
log_buffer.write(f"๐พ Saving tensor for download...\n")
|
|
|
|
|
|
|
| 136 |
temp_dir = tempfile.gettempdir()
|
| 137 |
safe_param_key = param_key.replace("/", "_").replace(".", "_")
|
| 138 |
download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
|
|
@@ -141,10 +166,10 @@ def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
|
|
| 141 |
|
| 142 |
progress(1.0, desc="Complete!")
|
| 143 |
log_buffer.write(f"\nโ
All operations completed successfully!\n")
|
| 144 |
-
|
| 145 |
except Exception as e:
|
| 146 |
log_buffer.write(f"\nโ Error: {str(e)}\n")
|
| 147 |
-
|
| 148 |
|
| 149 |
|
| 150 |
def list_keys(model_id: str):
|
|
|
|
| 6 |
import logging
|
| 7 |
import io
|
| 8 |
import sys
|
| 9 |
+
import threading
|
| 10 |
+
import time
|
| 11 |
from contextlib import redirect_stdout, redirect_stderr
|
| 12 |
|
| 13 |
import gradio as gr
|
|
|
|
| 111 |
def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
|
| 112 |
"""Fetch parameter and return formatted info and tensor preview."""
|
| 113 |
log_buffer = io.StringIO()
|
| 114 |
+
last_log_value = ""
|
| 115 |
|
| 116 |
if not model_id or not param_key:
|
| 117 |
+
yield "Please provide both model ID and parameter key.", "", None, "โ Missing required inputs"
|
| 118 |
+
return
|
| 119 |
|
| 120 |
try:
|
| 121 |
log_buffer.write(f"๐ Starting download for {model_id}\n")
|
| 122 |
log_buffer.write(f"๐ฏ Target parameter: {param_key}\n\n")
|
| 123 |
progress(0, desc="Initializing...")
|
| 124 |
+
yield "", "", None, log_buffer.getvalue()
|
| 125 |
+
time.sleep(0.5)
|
| 126 |
|
| 127 |
+
# Start download in background and monitor logs
|
| 128 |
tensor = get_param(model_id, param_key, log_buffer, progress)
|
| 129 |
+
|
| 130 |
+
# Yield log updates periodically during download
|
| 131 |
+
current_log = log_buffer.getvalue()
|
| 132 |
+
if current_log != last_log_value:
|
| 133 |
+
yield "", "", None, current_log
|
| 134 |
+
last_log_value = current_log
|
| 135 |
+
|
| 136 |
info = format_tensor_info(tensor)
|
| 137 |
|
| 138 |
# Create tensor preview (first few elements)
|
| 139 |
log_buffer.write(f"\n๐ Creating preview...\n")
|
| 140 |
+
yield "", "", None, log_buffer.getvalue()
|
| 141 |
+
|
| 142 |
flat = tensor.flatten()
|
| 143 |
preview_size = min(100, flat.numel())
|
| 144 |
preview = flat[:preview_size].tolist()
|
| 145 |
+
|
| 146 |
+
# Format preview in multiple lines (10 values per line)
|
| 147 |
+
preview_lines = []
|
| 148 |
+
for i in range(0, len(preview), 10):
|
| 149 |
+
line_values = preview[i:i+10]
|
| 150 |
+
preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
|
| 151 |
+
|
| 152 |
+
preview_str = f"**First {preview_size} values:**\n```\n" + "\n".join(preview_lines) + "\n```"
|
| 153 |
|
| 154 |
if flat.numel() > preview_size:
|
| 155 |
preview_str += f"\n\n... and {flat.numel() - preview_size:,} more values"
|
| 156 |
|
| 157 |
# Save tensor for download
|
| 158 |
log_buffer.write(f"๐พ Saving tensor for download...\n")
|
| 159 |
+
yield info, preview_str, None, log_buffer.getvalue()
|
| 160 |
+
|
| 161 |
temp_dir = tempfile.gettempdir()
|
| 162 |
safe_param_key = param_key.replace("/", "_").replace(".", "_")
|
| 163 |
download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
|
|
|
|
| 166 |
|
| 167 |
progress(1.0, desc="Complete!")
|
| 168 |
log_buffer.write(f"\nโ
All operations completed successfully!\n")
|
| 169 |
+
yield info, preview_str, download_path, log_buffer.getvalue()
|
| 170 |
except Exception as e:
|
| 171 |
log_buffer.write(f"\nโ Error: {str(e)}\n")
|
| 172 |
+
yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
|
| 173 |
|
| 174 |
|
| 175 |
def list_keys(model_id: str):
|