"""
Candlestick Chart Diffusion Model - Hugging Face Spaces App

Generates candlestick chart images from text prompts.
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
import numpy as np
from pathlib import Path
import math
from tqdm import tqdm
import json
import random
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from einops import rearrange


# ============== Model Components ==============


class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D batch of timesteps to sinusoidal features.

    The output concatenates ``sin`` and ``cos`` halves on the last axis, giving
    ``2 * (dim // 2)`` features per timestep.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # `time` is indexed as time[:, None] below, so it must be 1-D (batch,).
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm/SiLU, conditioned on a time embedding.

    The time embedding is projected to a per-channel ``(scale, shift)`` pair that
    modulates the output of the first conv. A 1x1 conv matches channels on the
    skip path when ``in_channels != out_channels``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(groups, in_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels * 2),
        )
        self.residual_conv = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, time_emb):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)

        # (b, 2*out_ch) -> (b, 2*out_ch, 1, 1) so it broadcasts over H and W.
        time_emb = self.time_mlp(time_emb)
        time_emb = rearrange(time_emb, "b c -> b c 1 1")
        scale, shift = time_emb.chunk(2, dim=1)
        h = h * (1 + scale) + shift

        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return h + self.residual_conv(x)


class AttentionBlock(nn.Module):
    """Multi-head self-attention over all spatial positions, with a residual add."""

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        b, c, h, w = x.shape
        x_norm = self.norm(x)
        qkv = self.qkv(x_norm)
        q, k, v = qkv.chunk(3, dim=1)

        # Flatten the H*W grid into a token axis per head.
        q = rearrange(q, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        k = rearrange(k, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        v = rearrange(v, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)

        attn = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = rearrange(out, "b heads (h w) d -> b (heads d) h w", h=h, w=w)
        return x + self.proj(out)


class CrossAttentionBlock(nn.Module):
    """Attend from image positions (queries) to text tokens (keys/values).

    ``context`` is rearranged as ``(b, n, heads*d)``, i.e. a batch of token
    sequences whose last dimension is ``context_dim``.
    """

    def __init__(self, channels, context_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.norm_context = nn.LayerNorm(context_dim)
        self.to_q = nn.Conv2d(channels, channels, 1)
        self.to_k = nn.Linear(context_dim, channels)
        self.to_v = nn.Linear(context_dim, channels)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, context):
        b, c, h, w = x.shape
        x_norm = self.norm(x)
        context = self.norm_context(context)

        q = self.to_q(x_norm)
        k = self.to_k(context)
        v = self.to_v(context)

        q = rearrange(q, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        k = rearrange(k, "b n (heads d) -> b heads n d", heads=self.num_heads)
        v = rearrange(v, "b n (heads d) -> b heads n d", heads=self.num_heads)

        attn = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = rearrange(out, "b heads (h w) d -> b (heads d) h w", h=h, w=w)
        return x + self.proj(out)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks, optional self+cross attention, then
    an optional stride-2 downsample.

    Returns ``(downsampled, skip)`` where ``skip`` is the pre-downsample tensor.
    """

    def __init__(self, in_ch, out_ch, time_dim, context_dim, has_attn=True, downsample=True):
        super().__init__()
        self.res1 = ResidualBlock(in_ch, out_ch, time_dim)
        self.res2 = ResidualBlock(out_ch, out_ch, time_dim)
        self.attn = AttentionBlock(out_ch) if has_attn else nn.Identity()
        self.cross_attn = CrossAttentionBlock(out_ch, context_dim) if has_attn else None
        self.downsample = (
            nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1) if downsample else nn.Identity()
        )

    def forward(self, x, time_emb, context):
        x = self.res1(x, time_emb)
        x = self.res2(x, time_emb)
        # Attention is only built when has_attn=True (cross_attn is None otherwise).
        if not isinstance(self.attn, nn.Identity):
            x = self.attn(x)
            x = self.cross_attn(x, context)
        skip = x
        x = self.downsample(x)
        return x, skip


class UpBlock(nn.Module):
    """Decoder stage: concat the matching skip, two residual blocks, optional
    self+cross attention, then an optional nearest-neighbour 2x upsample + conv.
    """

    def __init__(self, in_ch, out_ch, time_dim, context_dim, has_attn=True, upsample=True):
        super().__init__()
        # First block sees the incoming features concatenated with the skip.
        self.res1 = ResidualBlock(in_ch + out_ch, out_ch, time_dim)
        self.res2 = ResidualBlock(out_ch, out_ch, time_dim)
        self.attn = AttentionBlock(out_ch) if has_attn else nn.Identity()
        self.cross_attn = CrossAttentionBlock(out_ch, context_dim) if has_attn else None
        self.upsample = (
            nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
            )
            if upsample
            else nn.Identity()
        )

    def forward(self, x, skip, time_emb, context):
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, time_emb)
        x = self.res2(x, time_emb)
        if not isinstance(self.attn, nn.Identity):
            x = self.attn(x)
            x = self.cross_attn(x, context)
        x = self.upsample(x)
        return x


class ConditionalUNet(nn.Module):
    """Text-conditioned U-Net that predicts the noise added to ``x``.

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the predicted output.
        base_ch: width of the first stage; stage ``i`` uses ``base_ch * channel_mults[i]``.
        channel_mults: per-stage width multipliers; stages with multiplier >= 2
            get self- and cross-attention.
        context_dim: last dimension of the text context passed to ``forward``.
    """

    def __init__(self, in_ch=3, out_ch=3, base_ch=64, channel_mults=(1, 2, 4), context_dim=256):
        super().__init__()
        time_dim = base_ch * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(base_ch),
            nn.Linear(base_ch, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # Downsampling: every stage except the last halves the resolution.
        self.down_blocks = nn.ModuleList()
        channels = [base_ch]
        in_ch_block = base_ch
        for i, mult in enumerate(channel_mults):
            out_ch_block = base_ch * mult
            is_last = i == len(channel_mults) - 1
            has_attn = mult >= 2
            self.down_blocks.append(
                DownBlock(in_ch_block, out_ch_block, time_dim, context_dim, has_attn, not is_last)
            )
            channels.append(out_ch_block)
            in_ch_block = out_ch_block

        # Middle
        self.mid_res1 = ResidualBlock(in_ch_block, in_ch_block, time_dim)
        self.mid_attn = AttentionBlock(in_ch_block)
        self.mid_cross = CrossAttentionBlock(in_ch_block, context_dim)
        self.mid_res2 = ResidualBlock(in_ch_block, in_ch_block, time_dim)

        # Upsampling mirrors the down path; skips are consumed in reverse order.
        self.up_blocks = nn.ModuleList()
        for i, mult in enumerate(reversed(channel_mults)):
            out_ch_block = base_ch * mult
            is_last = i == len(channel_mults) - 1
            has_attn = mult >= 2
            self.up_blocks.append(
                UpBlock(in_ch_block, out_ch_block, time_dim, context_dim, has_attn, not is_last)
            )
            in_ch_block = out_ch_block

        self.norm_out = nn.GroupNorm(8, base_ch)
        # Honour `out_ch` (previously hard-coded to 3; the default keeps that).
        self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
        self.channels = channels

    def forward(self, x, time, context):
        t = self.time_mlp(time)
        x = self.conv_in(x)

        skips = []
        for block in self.down_blocks:
            x, skip = block(x, t, context)
            skips.append(skip)

        x = self.mid_res1(x, t)
        x = self.mid_attn(x)
        x = self.mid_cross(x, context)
        x = self.mid_res2(x, t)

        for block in self.up_blocks:
            skip = skips.pop()
            x = block(x, skip, t, context)

        x = F.silu(self.norm_out(x))
        return self.conv_out(x)


# ============== Text Encoder ==============


class SimpleTextEncoder(nn.Module):
    """Character-level transformer text encoder.

    Text is lower-cased, truncated/padded to ``max_len`` characters (index 0 is
    used both for padding and for unknown characters), embedded with learned
    token + position embeddings, and passed through a 2-layer transformer.
    ``forward`` returns a ``(batch, max_len, embed_dim)`` tensor.
    """

    def __init__(self, vocab_size=200, embed_dim=256, max_len=64):
        super().__init__()
        self.max_len = max_len
        self.embed_dim = embed_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=4, dim_feedforward=512, batch_first=True
            ),
            num_layers=2,
        )
        self.norm = nn.LayerNorm(embed_dim)

        chars = " abcdefghijklmnopqrstuvwxyz0123456789-_.,;:!?()[]{}'\"/\\@#$%^&*+=<>~`"
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
        self.char_to_idx[""] = 0

    def tokenize(self, texts, device):
        """Return a LongTensor of shape ``(len(texts), max_len)``."""
        batch = []
        for text in texts:
            text = text.lower()[:self.max_len]
            tokens = [self.char_to_idx.get(c, 0) for c in text]
            tokens += [0] * (self.max_len - len(tokens))
            batch.append(tokens)
        return torch.tensor(batch, device=device)

    def forward(self, texts, device):
        tokens = self.tokenize(texts, device)
        pos = torch.arange(self.max_len, device=device).unsqueeze(0)
        x = self.embed(tokens) + self.pos_embed(pos)
        x = self.transformer(x)
        return self.norm(x)

    def get_uncond(self, batch_size, device):
        """Context for the empty prompt, used for classifier-free guidance."""
        return self.forward([""] * batch_size, device)


# ============== Diffusion ==============


class GaussianDiffusion:
    """Noise schedule, training loss and sampler for an epsilon-predicting model.

    Uses a cosine ``alpha_cum`` schedule with betas clamped to ``[1e-4, 0.999]``.
    All schedule tensors live on ``device``.
    """

    def __init__(self, timesteps=1000, device="cuda"):
        self.timesteps = timesteps
        self.device = device

        betas = self._cosine_schedule(timesteps)
        alphas = 1 - betas
        alpha_cum = torch.cumprod(alphas, dim=0)

        self.betas = betas.to(device)
        self.alphas = alphas.to(device)
        self.alpha_cum = alpha_cum.to(device)
        self.sqrt_alpha_cum = torch.sqrt(alpha_cum).to(device)
        self.sqrt_one_minus_alpha_cum = torch.sqrt(1 - alpha_cum).to(device)

    def _cosine_schedule(self, timesteps, s=0.008):
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alpha_cum = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alpha_cum = alpha_cum / alpha_cum[0]
        betas = 1 - (alpha_cum[1:] / alpha_cum[:-1])
        return torch.clamp(betas, 0.0001, 0.999)

    def add_noise(self, x, t, noise=None):
        """Sample x_t ~ q(x_t | x_0); returns ``(x_t, noise)``."""
        if noise is None:
            noise = torch.randn_like(x)
        sqrt_alpha = self.sqrt_alpha_cum[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alpha_cum[t].view(-1, 1, 1, 1)
        return sqrt_alpha * x + sqrt_one_minus * noise, noise

    def loss(self, model, x, context):
        """MSE between the model's noise prediction and the true noise at a random t."""
        batch_size = x.shape[0]
        t = torch.randint(0, self.timesteps, (batch_size,), device=self.device)
        noise = torch.randn_like(x)
        x_noisy, _ = self.add_noise(x, t, noise)
        pred = model(x_noisy, t.float(), context)
        return F.mse_loss(pred, noise)

    @torch.no_grad()
    def sample(self, model, context, context_uncond=None, shape=(1, 3, 128, 128), steps=50,
               guidance_scale=7.5, progress_callback=None):
        """Generate samples by ancestral sampling over a strided timestep subset.

        Args:
            model: callable ``model(x, t_float, context)`` returning predicted noise.
            context: conditioning context for ``model``.
            context_uncond: unconditional context; when given and
                ``guidance_scale > 1`` classifier-free guidance is applied.
            shape: shape of the tensor to generate.
            steps: number of denoising steps (coerced to int, clamped to
                ``[1, timesteps]``).
            guidance_scale: classifier-free guidance weight.
            progress_callback: optional ``f(fraction_done)`` called after each step.

        Returns:
            The final denoised tensor of ``shape``.
        """
        x = torch.randn(shape, device=self.device)

        # UI widgets may pass floats; range() needs an int stride >= 1.
        steps = max(1, min(int(steps), self.timesteps))
        step_size = self.timesteps // steps
        timesteps = list(range(0, self.timesteps, step_size))[::-1]

        for i, t in enumerate(timesteps):
            t_batch = torch.full((shape[0],), t, device=self.device, dtype=torch.long)
            pred = model(x, t_batch.float(), context)

            if guidance_scale > 1.0 and context_uncond is not None:
                pred_uncond = model(x, t_batch.float(), context_uncond)
                pred = pred_uncond + guidance_scale * (pred - pred_uncond)

            # The loop jumps from t to t_prev (step_size schedule steps at once),
            # so the single-step alphas[t]/betas[t] are wrong here. Use the
            # effective alpha/beta of the whole jump instead; with step_size == 1
            # these reduce exactly to alphas[t]/betas[t].
            t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1
            alpha_cum = self.alpha_cum[t]
            alpha_cum_prev = (
                self.alpha_cum[t_prev] if t_prev >= 0 else torch.ones_like(alpha_cum)
            )
            alpha = alpha_cum / alpha_cum_prev
            beta = 1 - alpha

            x = (1 / torch.sqrt(alpha)) * (x - (beta / self.sqrt_one_minus_alpha_cum[t]) * pred)

            if t > 0:
                noise = torch.randn_like(x)
                x = x + torch.sqrt(beta) * noise

            if progress_callback:
                progress_callback((i + 1) / len(timesteps))

        return x


# ============== Dataset ==============


class ChartDataset(Dataset):
    """Images under ``<data_dir>/images`` with captions from ``<data_dir>/labels.json``.

    ``labels.json`` maps image filename -> caption. Files are sorted and split
    90/10 into train/other. Images are resized and normalised to ``[-1, 1]``.
    About 10% of captions are replaced with ``""`` to train the unconditional
    branch used for classifier-free guidance.
    """

    def __init__(self, data_dir, image_size=128, split="train"):
        self.data_dir = Path(data_dir)
        self.image_size = image_size

        with open(self.data_dir / "labels.json") as f:
            self.labels = json.load(f)

        all_files = sorted(self.labels.keys())
        split_idx = int(len(all_files) * 0.9)
        self.files = all_files[:split_idx] if split == "train" else all_files[split_idx:]

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        image = Image.open(self.data_dir / "images" / filename).convert("RGB")
        image = self.transform(image)
        text = self.labels[filename]
        # Caption dropout for classifier-free guidance.
        if random.random() < 0.1:
            text = ""
        return image, text


def collate_fn(batch):
    """Stack images into a tensor and keep captions as a plain list of strings."""
    images = torch.stack([b[0] for b in batch])
    texts = [b[1] for b in batch]
    return images, texts


# ============== Global State ==============

MODEL = None
TEXT_ENCODER = None
DIFFUSION = None
DEVICE = None
CONFIG = None


def load_model(checkpoint_path=None):
    """Build the global model/encoder/diffusion, loading weights if the path exists.

    If ``checkpoint_path`` is falsy or does not exist, freshly initialised
    (untrained) models are installed instead. Config stored in the checkpoint
    under ``"config"`` overrides the defaults.

    Returns:
        True.
    """
    global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # Default config
    CONFIG = {
        "base_channels": 64,
        "channel_mults": (1, 2, 4),
        "context_dim": 256,
        "image_size": 128,
        "timesteps": 1000,
    }

    checkpoint = None
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        # SECURITY NOTE(review): this path can come from a user upload in the
        # "Load Model" tab. torch.load unpickles, so on torch versions where
        # weights_only does not default to True an uploaded file can execute
        # code. Consider torch.load(..., weights_only=True); the checkpoints
        # written by train_model contain only tensors/dicts/tuples/ints/strs.
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        if "config" in checkpoint:
            CONFIG.update(checkpoint["config"])

    # Create models
    TEXT_ENCODER = SimpleTextEncoder(embed_dim=CONFIG["context_dim"]).to(DEVICE)
    MODEL = ConditionalUNet(
        base_ch=CONFIG["base_channels"],
        channel_mults=CONFIG["channel_mults"],
        context_dim=CONFIG["context_dim"],
    ).to(DEVICE)

    if checkpoint is not None:
        MODEL.load_state_dict(checkpoint["model_state_dict"])
        if "text_encoder_state_dict" in checkpoint:
            TEXT_ENCODER.load_state_dict(checkpoint["text_encoder_state_dict"])
        print("Model weights loaded!")

    # Inference mode for both modules: the transformer encoder layers contain
    # dropout, which would otherwise randomise the text context at generation time.
    MODEL.eval()
    TEXT_ENCODER.eval()
    DIFFUSION = GaussianDiffusion(timesteps=CONFIG["timesteps"], device=DEVICE)

    num_params = sum(p.numel() for p in MODEL.parameters())
    print(f"Model parameters: {num_params:,}")

    return True


def generate_dataset_ui(num_samples, image_size):
    """Generate training dataset.

    Writes ``num_samples`` synthetic candlestick PNGs to ``./dataset/images`` and
    their captions (``"<pattern> trend <volatility> volatility"``) to
    ``./dataset/labels.json``. Returns a status string; errors are reported in
    the string rather than raised.
    """
    try:
        import io
        from matplotlib.figure import Figure
        from matplotlib.patches import Rectangle

        # Gradio sliders may deliver floats; range() and PIL resize need ints.
        num_samples = int(num_samples)
        image_size = int(image_size)

        output_dir = "./dataset"
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)

        bg_color = "#1a1a2e"
        bullish_color = "#00ff88"
        bearish_color = "#ff4466"
        num_candles = 20

        def generate_candles(pattern, vol):
            """Random-walk OHLC series; `pattern` biases the close up, down or neither."""
            candles = []
            price = 100 if pattern != "bearish" else 150
            for _ in range(num_candles):
                if pattern == "bullish":
                    trend = random.uniform(0.5, 2.0)
                    o = price + random.gauss(0, vol)
                    c = o + random.uniform(0, vol * 2) + trend
                elif pattern == "bearish":
                    trend = random.uniform(0.5, 2.0)
                    o = price + random.gauss(0, vol)
                    c = o - random.uniform(0, vol * 2) - trend
                else:  # sideways
                    o = price + random.gauss(0, vol)
                    c = o + random.gauss(0, vol)
                high = max(o, c) + random.uniform(0, vol)
                low = min(o, c) - random.uniform(0, vol)
                candles.append({"o": o, "h": high, "l": low, "c": c})
                price = c
            return candles

        def render(candles):
            """Draw one chart and return it as an RGB image of image_size x image_size."""
            # An explicit Figure avoids pyplot's global figure state, which is
            # not thread-safe (this handler may run off the main thread).
            fig = Figure(figsize=(image_size / 100, image_size / 100), dpi=100)
            ax = fig.subplots()
            fig.patch.set_facecolor(bg_color)
            ax.set_facecolor(bg_color)

            lo = min(c["l"] for c in candles)
            hi = max(c["h"] for c in candles)
            # 2% margin on each side. abs() keeps the margin outward even when a
            # strongly bearish series goes negative (lo * 0.98 would clip wicks);
            # for positive prices this equals lo * 0.98 / hi * 1.02.
            price_min = lo - 0.02 * abs(lo)
            price_max = hi + 0.02 * abs(hi)

            for i, c in enumerate(candles):
                color = bullish_color if c["c"] >= c["o"] else bearish_color
                ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
                body_bottom = min(c["o"], c["c"])
                body_height = abs(c["c"] - c["o"]) or 0.1
                rect = Rectangle((i - 0.3, body_bottom), 0.6, body_height, facecolor=color)
                ax.add_patch(rect)

            ax.set_xlim(-1, len(candles))
            ax.set_ylim(price_min, price_max)
            ax.axis("off")

            buf = io.BytesIO()
            fig.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
            buf.seek(0)
            img = Image.open(buf).convert("RGB")
            return img.resize((image_size, image_size), Image.Resampling.LANCZOS)

        patterns = ["bullish", "bearish", "sideways"]
        volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
        labels = {}

        for i in range(num_samples):
            pattern = random.choice(patterns)
            vol_name = random.choice(list(volatilities.keys()))
            vol = volatilities[vol_name]

            candles = generate_candles(pattern, vol)
            img = render(candles)

            filename = f"chart_{i:06d}.png"
            img.save(os.path.join(output_dir, "images", filename))
            labels[filename] = f"{pattern} trend {vol_name} volatility"

            if i % 500 == 0:
                print(f"Generated {i}/{num_samples}")

        with open(os.path.join(output_dir, "labels.json"), "w") as f:
            json.dump(labels, f)

        return f"✅ Generated {num_samples} samples in ./dataset"
    except Exception as e:
        return f"❌ Failed: {str(e)}"


# ============== Gradio Interface ==============


def generate_chart(prompt, num_steps, guidance_scale, seed):
    """Gradio handler: sample one chart for ``prompt``.

    Returns ``(PIL image or None, status message)``. A negative or empty seed
    leaves the RNG unseeded.
    """
    global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG

    if MODEL is None:
        return None, "❌ Model not loaded! Train first or load a checkpoint."

    if not prompt.strip():
        return None, "❌ Please enter a prompt!"

    try:
        # gr.Number yields a float (or None when cleared).
        if seed is not None and seed >= 0:
            seed = int(seed)
            torch.manual_seed(seed)
            if DEVICE.type == "cuda":
                torch.cuda.manual_seed(seed)

        # Older checkpoints may have stored the slider value as a float.
        size = int(CONFIG["image_size"])

        with torch.no_grad():
            context = TEXT_ENCODER([prompt], DEVICE)
            context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)

            samples = DIFFUSION.sample(
                MODEL,
                context,
                context_uncond,
                shape=(1, 3, size, size),
                steps=int(num_steps),
                guidance_scale=guidance_scale,
                progress_callback=None,
            )

        # Convert from [-1, 1] to a uint8 HWC image.
        samples = (samples + 1) / 2
        samples = samples.clamp(0, 1)
        samples = (samples * 255).to(torch.uint8)
        img_array = samples[0].permute(1, 2, 0).cpu().numpy()
        img = Image.fromarray(img_array)

        return img, "✅ Generated successfully!"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"


def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name):
    """Gradio handler: train fresh models on ``data_path`` and save a checkpoint.

    Replaces the global MODEL/TEXT_ENCODER/DIFFUSION/CONFIG, trains the U-Net
    and text encoder jointly with AdamW, and writes
    ``checkpoints/<save_name>.pt``. Returns the training log (or an error
    message) as a string.
    """
    global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG

    try:
        # Gradio sliders may deliver floats; range(), DataLoader and Resize need ints.
        epochs = int(epochs)
        batch_size = int(batch_size)
        image_size = int(image_size)

        # Setup
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        CONFIG = {
            "base_channels": 64,
            "channel_mults": (1, 2, 4),
            "context_dim": 256,
            "image_size": image_size,
            "timesteps": 1000,
        }

        # Create models
        TEXT_ENCODER = SimpleTextEncoder(embed_dim=CONFIG["context_dim"]).to(DEVICE)
        MODEL = ConditionalUNet(
            base_ch=CONFIG["base_channels"],
            channel_mults=CONFIG["channel_mults"],
            context_dim=CONFIG["context_dim"],
        ).to(DEVICE)
        DIFFUSION = GaussianDiffusion(timesteps=CONFIG["timesteps"], device=DEVICE)

        num_params = sum(p.numel() for p in MODEL.parameters())

        # Dataset
        train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=DEVICE.type == "cuda",  # pinning only helps host->GPU copies
            drop_last=True,
            collate_fn=collate_fn,
        )
        if len(train_loader) == 0:
            # drop_last=True yields zero batches when the split is smaller than one batch.
            return (
                f"❌ Training failed: need at least {batch_size} training samples, "
                f"found {len(train_dataset)}"
            )

        # Optimizer over both modules; gradient clipping below covers the same set.
        params = list(MODEL.parameters()) + list(TEXT_ENCODER.parameters())
        optimizer = torch.optim.AdamW(params, lr=learning_rate)

        # Training
        MODEL.train()
        TEXT_ENCODER.train()

        logs = [f"🚀 Training started on {DEVICE}"]
        logs.append(f"📊 Model parameters: {num_params:,}")
        logs.append(f"📁 Training samples: {len(train_dataset)}")
        logs.append("-" * 40)

        for epoch in range(epochs):
            epoch_loss = 0.0

            for images, texts in train_loader:
                images = images.to(DEVICE)
                context = TEXT_ENCODER(texts, DEVICE)

                optimizer.zero_grad()
                loss = DIFFUSION.loss(MODEL, images, context)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()

                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(train_loader)
            logs.append(f"Epoch {epoch + 1}/{epochs}: loss = {avg_loss:.4f}")

        # Switch both modules to inference mode before they are used by generate_chart.
        MODEL.eval()
        TEXT_ENCODER.eval()

        # Save model
        os.makedirs("checkpoints", exist_ok=True)
        save_path = f"checkpoints/{save_name}.pt"
        torch.save({
            "model_state_dict": MODEL.state_dict(),
            "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
            "config": CONFIG,
        }, save_path)

        logs.append("-" * 40)
        logs.append(f"✅ Model saved to {save_path}")

        return "\n".join(logs)
    except Exception as e:
        return f"❌ Training failed: {str(e)}"


def load_checkpoint(checkpoint_file):
    """Gradio handler: load a checkpoint from a ``gr.File`` value.

    Accepts either a plain filepath string or a tempfile-like object exposing
    ``.name``, since the value type depends on the Gradio version/config.
    """
    if checkpoint_file is None:
        return "❌ No file selected"

    try:
        path = getattr(checkpoint_file, "name", checkpoint_file)
        load_model(path)
        return f"✅ Model loaded from {path}"
    except Exception as e:
        return f"❌ Failed to load: {str(e)}"


# ============== Gradio UI ==============


def create_demo():
    """Build the Gradio Blocks app with data, generate, train and load tabs."""
    with gr.Blocks(title="Candlestick Chart Generator") as demo:
        gr.Markdown("""
        # 📈 Candlestick Chart Diffusion Generator

        Generate candlestick chart images from text descriptions using a diffusion model.

        **Steps:**
        1. Upload your dataset (or use the generator script to create one)
        2. Train the model
        3. Generate charts from text prompts!
        """)

        with gr.Tabs():
            # Data Generation Tab
            with gr.TabItem("📊 Generate Data"):
                gr.Markdown("""
                ### Generate Training Dataset
                Create synthetic candlestick chart images for training.
                **Run this first before training!**
                """)

                with gr.Row():
                    with gr.Column():
                        num_samples = gr.Slider(1000, 50000, value=10000, step=1000, label="Number of Samples")
                        data_image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
                        generate_data_btn = gr.Button("📊 Generate Dataset", variant="primary")
                    with gr.Column():
                        data_status = gr.Textbox(label="Status", lines=5, interactive=False)

                generate_data_btn.click(
                    generate_dataset_ui,
                    inputs=[num_samples, data_image_size],
                    outputs=[data_status],
                )

            # Generation Tab
            with gr.TabItem("🎨 Generate"):
                with gr.Row():
                    with gr.Column(scale=1):
                        prompt_input = gr.Textbox(
                            label="Prompt",
                            placeholder="e.g., bullish trend with high volatility",
                            lines=2,
                        )

                        with gr.Row():
                            num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
                            guidance = gr.Slider(1, 20, value=7.5, step=0.5, label="Guidance Scale")

                        seed_input = gr.Number(label="Seed (-1 for random)", value=-1)
                        generate_btn = gr.Button("🎨 Generate", variant="primary")
                        gen_status = gr.Textbox(label="Status", interactive=False)

                        gr.Markdown("### Example Prompts")
                        gr.Examples(
                            examples=[
                                ["bullish trend with high volatility"],
                                ["bearish reversal pattern"],
                                ["double bottom formation low volatility"],
                                ["sideways market consolidation"],
                                ["head and shoulders pattern"],
                                ["strong upward trend green candles"],
                            ],
                            inputs=[prompt_input],
                        )

                    with gr.Column(scale=1):
                        output_image = gr.Image(label="Generated Chart", type="pil")

                generate_btn.click(
                    generate_chart,
                    inputs=[prompt_input, num_steps, guidance, seed_input],
                    outputs=[output_image, gen_status],
                )

            # Training Tab
            with gr.TabItem("🏋️ Train"):
                gr.Markdown("""
                ### Training Configuration
                Upload your dataset folder containing:
                - `images/` folder with chart images
                - `labels.json` with text descriptions
                """)

                with gr.Row():
                    with gr.Column():
                        data_path = gr.Textbox(label="Dataset Path", value="./dataset")
                        epochs = gr.Slider(1, 200, value=50, step=1, label="Epochs")
                        batch_size = gr.Slider(1, 64, value=16, step=1, label="Batch Size")
                        learning_rate = gr.Number(label="Learning Rate", value=1e-4)
                        image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
                        save_name = gr.Textbox(label="Model Name", value="candlestick_model")
                        train_btn = gr.Button("🚀 Start Training", variant="primary")

                    with gr.Column():
                        train_logs = gr.Textbox(label="Training Logs", lines=20, interactive=False)

                train_btn.click(
                    train_model,
                    inputs=[data_path, epochs, batch_size, learning_rate, image_size, save_name],
                    outputs=[train_logs],
                )

            # Load Model Tab
            with gr.TabItem("📂 Load Model"):
                gr.Markdown("### Load a trained checkpoint")
                checkpoint_upload = gr.File(label="Upload Checkpoint (.pt file)")
                load_btn = gr.Button("Load Model")
                load_status = gr.Textbox(label="Status", interactive=False)

                load_btn.click(
                    load_checkpoint,
                    inputs=[checkpoint_upload],
                    outputs=[load_status],
                )

        gr.Markdown("""
        ---
        ### Tips
        - **Training**: Use at least 5000 samples and 50+ epochs for good results
        - **Guidance Scale**: Higher values (7-12) follow prompts more closely
        - **Steps**: 50 steps is a good balance between speed and quality
        """)

    return demo


# ============== Main ==============

if __name__ == "__main__":
    # Try to load existing checkpoint
    if os.path.exists("checkpoints/candlestick_model.pt"):
        load_model("checkpoints/candlestick_model.pt")

    demo = create_demo()
    demo.launch()