Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Gradio app for TimesNet-Gen: Generate seismic samples from latent bank. | |
| Based on generate_samples_git.py (working GitHub version). | |
| NO PLOTTING - only NPZ generation and display in Gradio interface. | |
| Model files are automatically downloaded from Hugging Face Hub. | |
| """ | |
| import os | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from datetime import datetime | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face model repository hosting the checkpoint and latent bank.
HF_REPO_ID = "Barisylmz/TimesNet-Gen-Models"  # Your HF model repo
# Trained model weights pulled from the repo above when not present locally.
CHECKPOINT_FILENAME = "timesnet_pointcloud_phase1_final.pth"
LATENT_BANK_FILENAME = "latent_bank_station_cond.npz"  # Actual filename on HF Hub
def download_model_files():
    """
    Download model files from Hugging Face Hub if not present locally.

    Looks for the checkpoint and latent bank in the working directory first;
    anything missing is fetched from HF_REPO_ID into the ./hf_cache directory.

    Returns:
        checkpoint_path: Path to downloaded checkpoint
        latent_bank_path: Path to downloaded latent bank

    Raises:
        Exception: re-raises any hf_hub_download failure after logging it.
    """
    print("[INFO] Checking for model files...")
    # Check if files exist locally first (fast path: skip network entirely).
    if os.path.exists(CHECKPOINT_FILENAME) and os.path.exists(LATENT_BANK_FILENAME):
        print("[INFO] Model files found locally")
        return CHECKPOINT_FILENAME, LATENT_BANK_FILENAME
    print(f"[INFO] Downloading model files from Hugging Face Hub: {HF_REPO_ID}")
    try:
        # Download checkpoint (hf_hub_download returns the cached file path).
        if not os.path.exists(CHECKPOINT_FILENAME):
            print(f"[INFO] Downloading {CHECKPOINT_FILENAME}...")
            checkpoint_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=CHECKPOINT_FILENAME,
                cache_dir="./hf_cache"
            )
            print(f"[INFO] β Checkpoint downloaded to {checkpoint_path}")
        else:
            checkpoint_path = CHECKPOINT_FILENAME
        # Download latent bank
        if not os.path.exists(LATENT_BANK_FILENAME):
            print(f"[INFO] Downloading {LATENT_BANK_FILENAME}...")
            latent_bank_path = hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=LATENT_BANK_FILENAME,
                cache_dir="./hf_cache"
            )
            print(f"[INFO] β Latent bank downloaded to {latent_bank_path}")
        else:
            latent_bank_path = LATENT_BANK_FILENAME
        return checkpoint_path, latent_bank_path
    except Exception as e:
        # Surface the error to the caller; the Gradio layer formats it for users.
        print(f"[ERROR] Failed to download model files: {e}")
        print(f"[INFO] Please ensure files exist at: https://huggingface.co/{HF_REPO_ID}")
        raise
class SimpleArgs:
    """Configuration for generation (matching GitHub version)."""

    def __init__(self):
        # Table-driven setup: every (name, value) pair becomes an instance
        # attribute, keeping the full configuration visible in one place.
        settings = (
            # Model architecture
            ('seq_len', 6000),
            ('d_model', 128),
            ('d_ff', 256),
            ('e_layers', 2),
            ('d_layers', 2),
            ('num_kernels', 6),
            ('top_k', 2),
            ('dropout', 0.1),
            ('latent_dim', 256),
            # System
            ('use_gpu', torch.cuda.is_available()),
            ('seed', 0),
            # Point-cloud generation
            ('pcgen_k', 5),
            ('pcgen_jitter_std', 0.0),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)
def load_model(checkpoint_path, args):
    """Load pre-trained TimesNet-PointCloud model (matching GitHub version).

    Args:
        checkpoint_path: Path to a .pth checkpoint — either a raw state_dict
            or a dict containing a 'model_state_dict' entry.
        args: SimpleArgs-style config supplying the architecture fields below.

    Returns:
        The model in eval mode (moved to CUDA when args.use_gpu is True).
    """
    from TimesNet_PointCloud import TimesNetPointCloud
    # Create model config (NO num_stations - GitHub version doesn't use it)
    class ModelConfig:
        def __init__(self, args):
            self.seq_len = args.seq_len
            self.pred_len = 0    # reconstruction only — no forecasting horizon
            self.enc_in = 3      # 3 input channels (E, N, Z components)
            self.c_out = 3       # 3 output channels
            self.d_model = args.d_model
            self.d_ff = args.d_ff
            self.num_kernels = args.num_kernels
            self.top_k = args.top_k
            self.e_layers = args.e_layers
            self.d_layers = args.d_layers
            self.dropout = args.dropout
            self.embed = 'timeF'
            self.freq = 'h'
            self.latent_dim = args.latent_dim
    config = ModelConfig(args)
    model = TimesNetPointCloud(config)
    # Load checkpoint on CPU first; move to GPU only after weights are in place.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Checkpoint is a bare state_dict (no wrapper dict).
        model.load_state_dict(checkpoint)
    model.eval()
    if args.use_gpu:
        model = model.cuda()
    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args):
    """
    Generate samples directly from a pre-computed latent bank via bootstrap aggregation.

    Args:
        model: TimesNet model exposing ``project_features_for_reconstruction``.
        latent_bank_path: Path to the latent bank NPZ (keys ``latents_<id>``,
            ``means_<id>``, ``stdev_<id>`` per station).
        station_id: Station ID (e.g., '0205').
        num_samples: Number of samples to generate.
        args: Config providing ``pcgen_k``, ``pcgen_jitter_std`` and ``use_gpu``.

    Returns:
        generated_signals: (num_samples, 3, seq_len) array, or None on error.
        real_names_used: List of lists naming which latent indices built each
            sample, or None on error.
    """
    print(f"[INFO] Loading latent bank from {latent_bank_path}...")
    try:
        latent_data = np.load(latent_bank_path)
    except Exception as e:
        print(f"[ERROR] Could not load latent bank: {e}")
        return None, None
    # Close the NPZ file handle once the station arrays are materialized
    # (indexing an NpzFile loads the array fully into memory).
    with latent_data:
        latents_key = f'latents_{station_id}'
        means_key = f'means_{station_id}'
        stdev_key = f'stdev_{station_id}'
        if latents_key not in latent_data:
            print(f"[ERROR] Station {station_id} not found in latent bank!")
            available = [k.replace('latents_', '') for k in latent_data.keys() if k.startswith('latents_')]
            print(f"Available stations: {available}")
            return None, None
        latents = latent_data[latents_key]  # (N_samples, seq_len, d_model)
        means = latent_data[means_key]      # (N_samples, seq_len, d_model)
        stdevs = latent_data[stdev_key]     # (N_samples, seq_len, d_model)
    print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
    print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")
    jitter_std = float(args.pcgen_jitter_std)
    generated_signals = []
    real_names_used = []
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Bootstrap: randomly select k latent vectors with replacement
            k = min(args.pcgen_k, len(latents))
            selected_indices = np.random.choice(len(latents), size=k, replace=True)
            # Mix latent features (average over the k picks)
            mixed_features = np.mean(latents[selected_indices], axis=0)  # (seq_len, d_model)
            mixed_means = np.mean(means[selected_indices], axis=0)       # (seq_len, d_model)
            mixed_stdevs = np.mean(stdevs[selected_indices], axis=0)     # (seq_len, d_model)
            # Optional Gaussian jitter for extra sample diversity. This is a
            # no-op at the default pcgen_jitter_std=0.0, preserving prior behavior.
            if jitter_std > 0.0:
                mixed_features = mixed_features + np.random.normal(0.0, jitter_std, mixed_features.shape)
            # Convert to torch tensors with a leading batch dimension
            mixed_features_torch = torch.from_numpy(mixed_features).float().unsqueeze(0)  # (1, seq_len, d_model)
            means_b = torch.from_numpy(mixed_means).float().unsqueeze(0)                  # (1, seq_len, d_model)
            stdev_b = torch.from_numpy(mixed_stdevs).float().unsqueeze(0)                 # (1, seq_len, d_model)
            if args.use_gpu:
                mixed_features_torch = mixed_features_torch.cuda()
                means_b = means_b.cuda()
                stdev_b = stdev_b.cuda()
            # Decode mixed latent features back into the signal domain
            xg = model.project_features_for_reconstruction(mixed_features_torch, means_b, stdev_b)
            # Store - transpose to (3, seq_len)
            generated_np = xg.squeeze(0).cpu().numpy().T
            generated_signals.append(generated_np)
            # Track which latent indices were used for provenance
            real_names_used.append([f"latent_{idx}" for idx in selected_indices])
            if (i + 1) % 10 == 0:
                print(f"  Generated {i + 1}/{num_samples} samples...")
    return np.array(generated_signals), real_names_used
def save_generated_samples(generated_signals, real_names, station_id, output_dir):
    """Persist generated samples as a compressed NPZ archive (no plotting)."""
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'station_{station_id}_generated_timeseries.npz')
    # Duplicate keys are intentional aliases kept for downstream compatibility.
    payload = {
        'generated_signals': generated_signals,
        'signals_generated': generated_signals,
        'real_names': real_names,
        'station_id': station_id,
        'station': station_id,
    }
    np.savez_compressed(output_path, **payload)
    print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
    return output_path
def plot_signal_for_display(signal):
    """
    Render a single 3-component seismic signal as a PIL image for Gradio.

    Args:
        signal: (3, 6000) array ordered [E, N, Z]

    Returns:
        PIL Image containing the three stacked component traces
    """
    labels = ('East', 'North', 'Vertical')
    palette = ('#1f77b4', '#ff7f0e', '#2ca02c')
    time_axis = np.arange(signal.shape[1]) / 100.0  # 100 Hz sampling

    fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
    # One subplot per component, each clipped to the first 60 seconds.
    for axis, trace, label, color in zip(axes, signal, labels, palette):
        axis.plot(time_axis, trace, color=color, linewidth=0.5, alpha=0.8)
        axis.set_ylabel(f'{label}\n(cm/sΒ²)', fontsize=10)
        axis.grid(True, alpha=0.3)
        axis.set_xlim(0, 60)
    axes[-1].set_xlabel('Time (s)', fontsize=11)
    fig.suptitle('Generated Seismic Signal (3 Components)', fontsize=13, fontweight='bold')
    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it to PIL.
    buffer = BytesIO()
    plt.savefig(buffer, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    buffer.seek(0)
    return Image.open(buffer)
| # =========================== | |
| # Gradio Interface Functions | |
| # =========================== | |
def generate_samples_interface(station_id, num_samples, progress=gr.Progress()):
    """
    Gradio callback: run the full generation pipeline for one station.

    Args:
        station_id: Station ID (e.g., '0205')
        num_samples: Number of samples to generate
        progress: Gradio progress tracker

    Returns:
        status_message: Generation status
        npz_path: Path to saved NPZ file
        sample_plot: Preview plot of first generated sample
    """
    try:
        # Gradio sliders may deliver floats; generation needs an int count.
        num_samples = int(num_samples)
        print(f"[DEBUG] Requested num_samples: {num_samples} (type: {type(num_samples)})")
        progress(0, desc="Checking model files...")
        # Fetch model files from the HF Hub when they are not cached locally.
        try:
            checkpoint_path, latent_bank_path = download_model_files()
        except Exception as e:
            msg = (
                f"β Error downloading model files:\n{str(e)}\n\n"
                "Please ensure model files are uploaded to:\n"
                f"https://huggingface.co/{HF_REPO_ID}"
            )
            return msg, None, None
        output_dir = 'generated_outputs'
        progress(0.1, desc="Loading model...")
        cfg = SimpleArgs()
        model = load_model(checkpoint_path, cfg)
        progress(0.3, desc=f"Generating {num_samples} samples...")
        generated_signals, real_names = generate_samples_from_latent_bank(
            model, latent_bank_path, station_id, num_samples, cfg
        )
        if generated_signals is None:
            return f"β Error: Failed to generate samples for station {station_id}", None, None
        progress(0.8, desc="Saving NPZ file...")
        # Persist as NPZ only — no plots are written to disk.
        npz_path = save_generated_samples(generated_signals, real_names, station_id, output_dir)
        progress(0.95, desc="Creating preview plot...")
        sample_plot = plot_signal_for_display(generated_signals[0])
        progress(1.0, desc="Done!")
        status_msg = "\n".join([
            f"β Successfully generated {len(generated_signals)} samples for station {station_id}!",
            f"π Requested: {num_samples} samples",
            f"π Saved to: {npz_path}",
            f"π Preview of first generated sample shown below.",
        ])
        return status_msg, npz_path, sample_plot
    except Exception as e:
        import traceback
        return f"β Error during generation:\n{str(e)}\n\n{traceback.format_exc()}", None, None
def load_and_display_npz(npz_file, sample_idx):
    """
    Load an NPZ file of generated samples and display one of them.

    Args:
        npz_file: Path to the NPZ file, or a Gradio file object carrying
            the path in its ``.name`` attribute.
        sample_idx: Index of sample to display (0-based; floats from the
            Gradio slider are truncated to int).

    Returns:
        status_message: Load status
        sample_plot: Plot of selected sample, or None on failure
    """
    try:
        if npz_file is None:
            return "β οΈ No NPZ file provided", None
        # Gradio's File component may pass a tempfile wrapper rather than a
        # plain path string — unwrap it if needed.
        npz_path = getattr(npz_file, 'name', npz_file)
        # Coerce slider value to int for indexing (consistent with the
        # int() coercion in generate_samples_interface).
        sample_idx = int(sample_idx)
        data = np.load(npz_path)
        generated_signals = data['generated_signals']
        if sample_idx < 0 or sample_idx >= len(generated_signals):
            return f"β οΈ Sample index {sample_idx} out of range (0-{len(generated_signals)-1})", None
        # Plot selected sample
        sample_plot = plot_signal_for_display(generated_signals[sample_idx])
        status_msg = f"β Loaded NPZ with {len(generated_signals)} samples\n"
        status_msg += f"π Displaying sample #{sample_idx}"
        return status_msg, sample_plot
    except Exception as e:
        import traceback
        error_msg = f"β Error loading NPZ:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None
| # =========================== | |
| # Gradio App | |
| # =========================== | |
def create_demo():
    """Build the Gradio Blocks interface: a generation tab and a viewer tab.

    Returns:
        The assembled gr.Blocks app (not yet launched).
    """
    with gr.Blocks(title="TimesNet-Gen Demo: Station-Specific Seismic Generator") as demo:
        # Header with usage instructions.
        gr.Markdown("""
        # π TimesNet-Gen Demo: Station-Specific Seismic Sample Generator
        For more detailed information: https://arxiv.org/abs/2512.04694
        Generate realistic synthetic seismic signals with station-specific characteristics.
        **Instructions:**
        1. Select a station ID (5 fine-tuned stations available)
        2. Choose number of samples to generate
        3. Click "Generate Samples" and wait
        4. Preview generated samples or download NPZ file
        """)
        # --- Tab 1: generate new samples for a chosen station ---
        with gr.Tab("Generate Samples"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Only the 5 fine-tuned stations are offered.
                    station_dropdown = gr.Dropdown(
                        choices=['0205', '1716', '2020', '3130', '4628'],
                        value='0205',
                        label="Station ID",
                        info="Select target station"
                    )
                    num_samples_slider = gr.Slider(
                        minimum=1,
                        maximum=200,
                        value=50,
                        step=1,
                        label="Number of Samples",
                        info="How many samples to generate"
                    )
                    generate_btn = gr.Button("π Generate Samples", variant="primary")
                with gr.Column(scale=2):
                    status_text = gr.Textbox(
                        label="Status",
                        lines=5,
                        interactive=False
                    )
                    npz_file_output = gr.File(
                        label="Generated NPZ File",
                        interactive=False
                    )
            gr.Markdown("### Preview (First Generated Sample)")
            preview_plot = gr.Image(label="Sample Preview")
            # Wire the button to the generation pipeline callback.
            generate_btn.click(
                fn=generate_samples_interface,
                inputs=[station_dropdown, num_samples_slider],
                outputs=[status_text, npz_file_output, preview_plot]
            )
        # --- Tab 2: inspect samples from a previously saved NPZ ---
        with gr.Tab("View Saved Samples"):
            gr.Markdown("### Load and view samples from saved NPZ file")
            with gr.Row():
                with gr.Column(scale=1):
                    npz_upload = gr.File(
                        label="Upload NPZ File",
                        file_types=['.npz']
                    )
                    # Max index mirrors the 200-sample generation cap above.
                    sample_idx_slider = gr.Slider(
                        minimum=0,
                        maximum=199,
                        value=0,
                        step=1,
                        label="Sample Index",
                        info="Which sample to display (0-based)"
                    )
                    load_btn = gr.Button("π Load and Display", variant="secondary")
                with gr.Column(scale=2):
                    load_status = gr.Textbox(
                        label="Status",
                        lines=3,
                        interactive=False
                    )
                    display_plot = gr.Image(label="Sample Display")
            load_btn.click(
                fn=load_and_display_npz,
                inputs=[npz_upload, sample_idx_slider],
                outputs=[load_status, display_plot]
            )
        # Footer summarizing the model and output format.
        gr.Markdown("""
        ---
        **Model:** TimesNet-PointCloud
        **Method:** Bootstrap aggregation from latent bank
        **Stations:** 5 fine-tuned Turkish strong-motion stations
        **Output:** 3-component acceleration signals (E, N, Z) @ 100 Hz
        """)
    return demo
if __name__ == "__main__":
    demo = create_demo()
    # Bind to all interfaces on port 7860 — the standard setup for
    # Hugging Face Spaces hosting.
    demo.launch(server_name="0.0.0.0", server_port=7860)