TimesNet-Gen / app.py
Barisylmz's picture
Upload app.py
f181668 verified
#!/usr/bin/env python3
"""
Gradio app for TimesNet-Gen: Generate seismic samples from latent bank.
Based on generate_samples_git.py (working GitHub version).
Generated samples are saved as NPZ files; matplotlib is used only to render
preview images of individual samples inside the Gradio interface.
Model files are automatically downloaded from Hugging Face Hub.
"""
import os
import gradio as gr
import torch
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Hugging Face Hub artifacts
# ---------------------------------------------------------------------------
# Model repository hosting both the checkpoint and the latent bank.
HF_REPO_ID: str = "Barisylmz/TimesNet-Gen-Models"
# Pre-trained TimesNet-PointCloud weights.
CHECKPOINT_FILENAME: str = "timesnet_pointcloud_phase1_final.pth"
# Per-station latent bank, named exactly as published on the Hub.
LATENT_BANK_FILENAME: str = "latent_bank_station_cond.npz"
def download_model_files():
    """
    Locate the checkpoint and latent bank, fetching them from the Hub when absent.

    Files already present in the current working directory are used as-is.
    Each missing file is fetched from ``HF_REPO_ID`` with ``hf_hub_download``
    into the ``./hf_cache`` directory.

    Returns:
        checkpoint_path: Local path of the model checkpoint.
        latent_bank_path: Local path of the latent bank NPZ.

    Raises:
        Whatever ``hf_hub_download`` raises; the error is logged and then
        re-raised unchanged.
    """
    print("[INFO] Checking for model files...")

    # Fast path: both artifacts already live in the working directory.
    if all(os.path.exists(name) for name in (CHECKPOINT_FILENAME, LATENT_BANK_FILENAME)):
        print("[INFO] Model files found locally")
        return CHECKPOINT_FILENAME, LATENT_BANK_FILENAME

    print(f"[INFO] Downloading model files from Hugging Face Hub: {HF_REPO_ID}")
    try:
        if os.path.exists(CHECKPOINT_FILENAME):
            checkpoint_path = CHECKPOINT_FILENAME
        else:
            print(f"[INFO] Downloading {CHECKPOINT_FILENAME}...")
            checkpoint_path = hf_hub_download(
                repo_id=HF_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir="./hf_cache"
            )
            print(f"[INFO] βœ… Checkpoint downloaded to {checkpoint_path}")

        if os.path.exists(LATENT_BANK_FILENAME):
            latent_bank_path = LATENT_BANK_FILENAME
        else:
            print(f"[INFO] Downloading {LATENT_BANK_FILENAME}...")
            latent_bank_path = hf_hub_download(
                repo_id=HF_REPO_ID, filename=LATENT_BANK_FILENAME, cache_dir="./hf_cache"
            )
            print(f"[INFO] βœ… Latent bank downloaded to {latent_bank_path}")
    except Exception as exc:
        print(f"[ERROR] Failed to download model files: {exc}")
        print(f"[INFO] Please ensure files exist at: https://huggingface.co/{HF_REPO_ID}")
        raise

    return checkpoint_path, latent_bank_path
class SimpleArgs:
    """Hyper-parameters used for generation (matching GitHub version).

    The values are consumed by ``load_model`` (architecture fields) and by the
    latent-bank sampler (``use_gpu``, ``pcgen_k``).
    """

    def __init__(self):
        # --- TimesNet-PointCloud architecture ---
        self.seq_len = 6000
        self.d_model, self.d_ff = 128, 256
        self.e_layers = self.d_layers = 2
        self.num_kernels, self.top_k = 6, 2
        self.dropout = 0.1
        self.latent_dim = 256

        # --- Runtime ---
        self.use_gpu = torch.cuda.is_available()  # use CUDA whenever a device is visible
        self.seed = 0

        # --- Bootstrap point-cloud sampling ---
        self.pcgen_k = 5  # upper bound on bank entries averaged per generated sample
        self.pcgen_jitter_std = 0.0  # not read by the code visible in this file
def load_model(checkpoint_path, args):
    """Build TimesNet-PointCloud and restore pre-trained weights (matching GitHub version).

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` holding either
            a bare state dict or a dict with a ``'model_state_dict'`` entry.
        args: ``SimpleArgs``-like object supplying the architecture fields.

    Returns:
        The model in eval mode, moved to CUDA when ``args.use_gpu`` is true.
    """
    from TimesNet_PointCloud import TimesNetPointCloud

    # Constructor config. Deliberately has no num_stations field - the GitHub
    # version does not use it.
    class ModelConfig:
        def __init__(self, args):
            # Fields forwarded unchanged from the caller's arguments.
            for name in ('seq_len', 'd_model', 'd_ff', 'num_kernels', 'top_k',
                         'e_layers', 'd_layers', 'dropout', 'latent_dim'):
                setattr(self, name, getattr(args, name))
            # Fixed settings: 3 input / 3 output channels, pred_len of 0.
            self.pred_len = 0
            self.enc_in = 3
            self.c_out = 3
            self.embed = 'timeF'
            self.freq = 'h'

    model = TimesNetPointCloud(ModelConfig(args))

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Accept both a training dict wrapping the weights and a bare state dict.
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    if args.use_gpu:
        model = model.cuda()

    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args, rng=None):
    """
    Generate samples by decoding bootstrap-averaged entries of a pre-computed latent bank.

    For every requested sample, ``min(args.pcgen_k, N)`` entries are drawn with
    replacement from the station's N stored entries. Their latents, means and
    stdevs are averaged over the drawn entries and decoded with
    ``model.project_features_for_reconstruction``.

    Args:
        model: TimesNet model exposing ``project_features_for_reconstruction``.
        latent_bank_path: Path to the latent bank NPZ (e.g. ``latent_bank_station_cond.npz``).
            It must contain ``latents_<id>``, ``means_<id>`` and ``stdev_<id>``.
        station_id: Station ID (e.g., '0205').
        num_samples: Number of samples to generate.
        args: Model arguments; ``pcgen_k`` and ``use_gpu`` are read.
        rng: Optional random source with a ``choice`` method, e.g.
            ``np.random.default_rng(seed)``, for reproducible sampling.
            Defaults to NumPy's global random state (original behavior).

    Returns:
        generated_signals: Array stacking one decoded output per sample. Each
            output is transposed after dropping the batch axis, presumably
            giving (num_samples, 3, seq_len) - confirm against the model.
        real_names_used: For each sample, the ``"latent_<idx>"`` labels drawn.
        ``(None, None)`` is returned when the bank cannot be read, the station
        or one of its arrays is missing, or the station has no entries.
    """
    print(f"[INFO] Loading latent bank from {latent_bank_path}...")
    try:
        latent_data = np.load(latent_bank_path)
    except Exception as e:
        print(f"[ERROR] Could not load latent bank: {e}")
        return None, None

    latents_key = f'latents_{station_id}'
    means_key = f'means_{station_id}'
    stdev_key = f'stdev_{station_id}'

    # Context manager closes the NPZ file handle once the arrays are read.
    with latent_data:
        if latents_key not in latent_data:
            print(f"[ERROR] Station {station_id} not found in latent bank!")
            available = [k.replace('latents_', '') for k in latent_data.keys() if k.startswith('latents_')]
            print(f"Available stations: {available}")
            return None, None
        missing = [key for key in (means_key, stdev_key) if key not in latent_data]
        if missing:
            print(f"[ERROR] Latent bank is missing {missing} for station {station_id}!")
            return None, None
        # Indexing an NpzFile returns a fully loaded ndarray, valid after close.
        latents = latent_data[latents_key]
        means = latent_data[means_key]
        stdevs = latent_data[stdev_key]

    if len(latents) == 0:
        print(f"[ERROR] Latent bank has no vectors for station {station_id}!")
        return None, None

    print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
    print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")

    if rng is None:
        rng = np.random
    # Bootstrap size is the same for every sample; cap it at the bank size.
    k = min(args.pcgen_k, len(latents))

    def _mean_as_batch(stack):
        """Average over the bootstrap axis and add a leading batch axis of size 1."""
        tensor = torch.from_numpy(np.mean(stack, axis=0)).float().unsqueeze(0)
        return tensor.cuda() if args.use_gpu else tensor

    generated_signals = []
    real_names_used = []
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Bootstrap: draw k bank entries with replacement.
            selected_indices = rng.choice(len(latents), size=k, replace=True)

            features_b = _mean_as_batch(latents[selected_indices])
            means_b = _mean_as_batch(means[selected_indices])
            stdev_b = _mean_as_batch(stdevs[selected_indices])

            xg = model.project_features_for_reconstruction(features_b, means_b, stdev_b)

            # Drop the batch axis and transpose, e.g. (6000, 3) -> (3, 6000).
            generated_signals.append(xg.squeeze(0).cpu().numpy().T)
            # Track which bank entries were mixed into this sample.
            real_names_used.append([f"latent_{idx}" for idx in selected_indices])

            if (i + 1) % 10 == 0:
                print(f" Generated {i + 1}/{num_samples} samples...")

    return np.array(generated_signals), real_names_used
def save_generated_samples(generated_signals, real_names, station_id, output_dir):
    """Write generated samples to ``<output_dir>/station_<id>_generated_timeseries.npz``.

    The primary keys are stored together with the legacy aliases
    ``signals_generated`` and ``station`` so older readers keep working.
    No plots are produced here.

    Returns:
        Path of the written (compressed) NPZ file.
    """
    os.makedirs(output_dir, exist_ok=True)
    filename = f'station_{station_id}_generated_timeseries.npz'
    output_path = os.path.join(output_dir, filename)

    payload = dict(
        generated_signals=generated_signals,
        signals_generated=generated_signals,  # alias for compatibility
        real_names=real_names,
        station_id=station_id,
        station=station_id,  # alias for compatibility
    )
    np.savez_compressed(output_path, **payload)

    print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
    return output_path
def plot_signal_for_display(signal, sampling_rate=100.0):
    """
    Plot a 3-component seismic signal and return it as an image for Gradio.

    Args:
        signal: Array of shape (n_components, n_samples). The first three rows
            are drawn and labelled East, North and Vertical, in that order.
        sampling_rate: Sampling frequency in Hz used to build the time axis.
            The default of 100 Hz makes 6000 samples span 60 s.

    Returns:
        PIL Image of the rendered figure.
    """
    component_names = ['East', 'North', 'Vertical']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
    n_samples = signal.shape[1]
    t = np.arange(n_samples) / sampling_rate
    duration = n_samples / sampling_rate

    fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
    try:
        for idx, (ax, name, color) in enumerate(zip(axes, component_names, colors)):
            ax.plot(t, signal[idx], color=color, linewidth=0.5, alpha=0.8)
            ax.set_ylabel(f'{name}\n(cm/s²)', fontsize=10)
            ax.grid(True, alpha=0.3)
            # Span the whole record instead of a fixed 60 s window.
            ax.set_xlim(0, duration)
        axes[-1].set_xlabel('Time (s)', fontsize=11)
        fig.suptitle('Generated Seismic Signal (3 Components)', fontsize=13, fontweight='bold')
        # Act on this figure explicitly rather than pyplot's global "current
        # figure", which concurrent Gradio requests could otherwise swap out.
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting fails, so figures don't accumulate.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# ===========================
# Gradio Interface Functions
# ===========================
# Loaded models keyed by checkpoint path. Reusing them avoids rebuilding the
# network and re-reading the checkpoint on every button click.
_MODEL_CACHE = {}


def generate_samples_interface(station_id, num_samples, progress=gr.Progress()):
    """
    Generate samples for Gradio interface.

    Args:
        station_id: Station ID (e.g., '0205')
        num_samples: Number of samples to generate; must be >= 1 (floats are
            converted with ``int``)
        progress: Gradio progress tracker

    Returns:
        status_message: Generation status, or error details on failure
        npz_path: Path to saved NPZ file, or None on failure
        sample_plot: Preview plot of first generated sample, or None on failure
    """
    try:
        # Convert num_samples to int (Gradio might pass float)
        num_samples = int(num_samples)
        print(f"[DEBUG] Requested num_samples: {num_samples} (type: {type(num_samples)})")
        # The preview below shows sample 0, so at least one sample is required.
        if num_samples < 1:
            return f"❌ Error: Number of samples must be at least 1 (got {num_samples})", None, None

        progress(0, desc="Checking model files...")
        # Download model files from HF Hub if needed
        try:
            checkpoint_path, latent_bank_path = download_model_files()
        except Exception as e:
            error_msg = f"❌ Error downloading model files:\n{str(e)}\n\n"
            error_msg += "Please ensure model files are uploaded to:\n"
            error_msg += f"https://huggingface.co/{HF_REPO_ID}"
            return error_msg, None, None

        output_dir = 'generated_outputs'

        progress(0.1, desc="Loading model...")
        args = SimpleArgs()
        # Load the model once per checkpoint and reuse it on later requests.
        model = _MODEL_CACHE.get(checkpoint_path)
        if model is None:
            model = load_model(checkpoint_path, args)
            _MODEL_CACHE[checkpoint_path] = model

        progress(0.3, desc=f"Generating {num_samples} samples...")
        generated_signals, real_names = generate_samples_from_latent_bank(
            model, latent_bank_path, station_id, num_samples, args
        )
        if generated_signals is None:
            return f"❌ Error: Failed to generate samples for station {station_id}", None, None

        progress(0.8, desc="Saving NPZ file...")
        npz_path = save_generated_samples(generated_signals, real_names, station_id, output_dir)

        progress(0.95, desc="Creating preview plot...")
        sample_plot = plot_signal_for_display(generated_signals[0])
        progress(1.0, desc="Done!")

        # Report the number actually generated alongside the requested count.
        actual_count = len(generated_signals)
        status_msg = f"βœ… Successfully generated {actual_count} samples for station {station_id}!\n"
        status_msg += f"πŸ“Š Requested: {num_samples} samples\n"
        status_msg += f"πŸ“ Saved to: {npz_path}\n"
        status_msg += f"πŸ“ˆ Preview of first generated sample shown below."
        return status_msg, npz_path, sample_plot
    except Exception as e:
        # Top-level UI boundary: surface the failure in the status box.
        import traceback
        error_msg = f"❌ Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None, None
def load_and_display_npz(npz_file, sample_idx):
    """
    Load an NPZ file of generated samples and display one of them.

    Args:
        npz_file: Uploaded NPZ, either a filesystem path or a file object
            exposing ``.name`` (the form depends on the Gradio version).
        sample_idx: Index of sample to display (0-based). Float values from
            the slider are converted with ``int``.

    Returns:
        status_message: Load status; warnings and errors are reported here
        sample_plot: Plot of selected sample, or None on failure
    """
    try:
        if npz_file is None:
            return "⚠️ No NPZ file provided", None
        # Gradio sliders may deliver floats; numpy indexing needs an int.
        sample_idx = int(sample_idx)
        path = getattr(npz_file, 'name', npz_file)

        # Uploaded files are untrusted: never unpickle object arrays. The
        # context manager closes the file handle after the array is read.
        with np.load(path, allow_pickle=False) as data:
            if 'generated_signals' in data:
                generated_signals = data['generated_signals']
            elif 'signals_generated' in data:
                # Compatibility alias written by save_generated_samples.
                generated_signals = data['signals_generated']
            else:
                return "⚠️ NPZ file has no 'generated_signals' array", None

        n_samples = len(generated_signals)
        if sample_idx < 0 or sample_idx >= n_samples:
            return f"⚠️ Sample index {sample_idx} out of range (0-{n_samples-1})", None

        sample_plot = plot_signal_for_display(generated_signals[sample_idx])
        status_msg = f"βœ… Loaded NPZ with {n_samples} samples\n"
        status_msg += f"πŸ“Š Displaying sample #{sample_idx}"
        return status_msg, sample_plot
    except Exception as e:
        # Top-level UI boundary: surface the failure in the status box.
        import traceback
        error_msg = f"❌ Error loading NPZ:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None
# ===========================
# Gradio App
# ===========================
def create_demo():
    """Create Gradio interface.

    Builds a ``gr.Blocks`` app with two tabs: "Generate Samples" (wired to
    ``generate_samples_interface``) and "View Saved Samples" (wired to
    ``load_and_display_npz``).

    Returns:
        The ``gr.Blocks`` app; the caller launches it.
    """
    with gr.Blocks(title="TimesNet-Gen Demo: Station-Specific Seismic Generator") as demo:
        # Header. The Markdown body is a runtime string and is kept flush-left on
        # purpose: indenting it would change the text passed to gr.Markdown.
        gr.Markdown("""
# 🌍 TimesNet-Gen Demo: Station-Specific Seismic Sample Generator
For more detailed information: https://arxiv.org/abs/2512.04694
Generate realistic synthetic seismic signals with station-specific characteristics.
**Instructions:**
1. Select a station ID (5 fine-tuned stations available)
2. Choose number of samples to generate
3. Click "Generate Samples" and wait
4. Preview generated samples or download NPZ file
""")
        # --- Tab 1: generate new samples from the latent bank ---
        with gr.Tab("Generate Samples"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Station IDs offered in the UI; presumably each has a
                    # latents_<id> entry in the latent bank - verify when adding stations.
                    station_dropdown = gr.Dropdown(
                        choices=['0205', '1716', '2020', '3130', '4628'],
                        value='0205',
                        label="Station ID",
                        info="Select target station"
                    )
                    num_samples_slider = gr.Slider(
                        minimum=1,
                        maximum=200,
                        value=50,
                        step=1,
                        label="Number of Samples",
                        info="How many samples to generate"
                    )
                    generate_btn = gr.Button("πŸš€ Generate Samples", variant="primary")
                with gr.Column(scale=2):
                    status_text = gr.Textbox(
                        label="Status",
                        lines=5,
                        interactive=False
                    )
                    npz_file_output = gr.File(
                        label="Generated NPZ File",
                        interactive=False
                    )
                    gr.Markdown("### Preview (First Generated Sample)")
                    preview_plot = gr.Image(label="Sample Preview")
            # Outputs map positionally onto generate_samples_interface's
            # (status, npz path, preview image) return tuple.
            generate_btn.click(
                fn=generate_samples_interface,
                inputs=[station_dropdown, num_samples_slider],
                outputs=[status_text, npz_file_output, preview_plot]
            )
        # --- Tab 2: inspect a previously saved NPZ file ---
        with gr.Tab("View Saved Samples"):
            gr.Markdown("### Load and view samples from saved NPZ file")
            with gr.Row():
                with gr.Column(scale=1):
                    npz_upload = gr.File(
                        label="Upload NPZ File",
                        file_types=['.npz']
                    )
                    # Upper bound 199 matches the 200-sample maximum of the
                    # generation slider (0-based index).
                    sample_idx_slider = gr.Slider(
                        minimum=0,
                        maximum=199,
                        value=0,
                        step=1,
                        label="Sample Index",
                        info="Which sample to display (0-based)"
                    )
                    load_btn = gr.Button("πŸ“Š Load and Display", variant="secondary")
                with gr.Column(scale=2):
                    load_status = gr.Textbox(
                        label="Status",
                        lines=3,
                        interactive=False
                    )
                    display_plot = gr.Image(label="Sample Display")
            # Outputs map positionally onto load_and_display_npz's (status, image) tuple.
            load_btn.click(
                fn=load_and_display_npz,
                inputs=[npz_upload, sample_idx_slider],
                outputs=[load_status, display_plot]
            )
        # Footer summarising the model and output format (runtime string, kept flush-left).
        gr.Markdown("""
---
**Model:** TimesNet-PointCloud
**Method:** Bootstrap aggregation from latent bank
**Stations:** 5 fine-tuned Turkish strong-motion stations
**Output:** 3-component acceleration signals (E, N, Z) @ 100 Hz
""")
    return demo
if __name__ == "__main__":
    # Binding to 0.0.0.0 accepts connections on every network interface,
    # not only localhost.
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )