| """ | |
| Candlestick Chart Diffusion Model - Hugging Face Spaces App | |
| Generates candlestick chart images from text prompts | |
| """ | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| from pathlib import Path | |
| import math | |
| from tqdm import tqdm | |
| import json | |
| import random | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from einops import rearrange | |
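
# Pipeline overview: SimpleTextEncoder embeds a prompt, ConditionalUNet
# predicts noise conditioned on it, GaussianDiffusion implements the forward
# and reverse processes, and the Gradio UI below wires up dataset generation,
# training, sampling, and checkpoint loading.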

# ============== Model Components ==============
class SinusoidalPositionEmbeddings(nn.Module):
    """Standard sinusoidal timestep embedding, as used in DDPM/Transformers."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
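
# The residual blocks condition on the timestep via FiLM-style modulation:
# the time embedding is projected to a per-channel (scale, shift) pair and
# applied as h = h * (1 + scale) + shift between the two convolutions.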
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(groups, in_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels * 2)
        )
        self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, time_emb):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        time_emb = self.time_mlp(time_emb)
        time_emb = rearrange(time_emb, "b c -> b c 1 1")
        scale, shift = time_emb.chunk(2, dim=1)
        h = h * (1 + scale) + shift
        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return h + self.residual_conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        b, c, h, w = x.shape
        x_norm = self.norm(x)
        qkv = self.qkv(x_norm)
        q, k, v = qkv.chunk(3, dim=1)
        q = rearrange(q, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        k = rearrange(k, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        v = rearrange(v, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        attn = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = rearrange(out, "b heads (h w) d -> b (heads d) h w", h=h, w=w)
        return x + self.proj(out)
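
# Cross-attention injects the text conditioning: queries come from the image
# feature map, while keys and values come from the encoded prompt tokens, so
# every spatial position can attend to every token of the caption.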
class CrossAttentionBlock(nn.Module):
    def __init__(self, channels, context_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.norm_context = nn.LayerNorm(context_dim)
        self.to_q = nn.Conv2d(channels, channels, 1)
        self.to_k = nn.Linear(context_dim, channels)
        self.to_v = nn.Linear(context_dim, channels)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, context):
        b, c, h, w = x.shape
        x_norm = self.norm(x)
        context = self.norm_context(context)
        q = self.to_q(x_norm)
        k = self.to_k(context)
        v = self.to_v(context)
        q = rearrange(q, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
        k = rearrange(k, "b n (heads d) -> b heads n d", heads=self.num_heads)
        v = rearrange(v, "b n (heads d) -> b heads n d", heads=self.num_heads)
        attn = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = rearrange(out, "b heads (h w) d -> b (heads d) h w", h=h, w=w)
        return x + self.proj(out)
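
# Each down block returns its pre-downsample activations as a skip tensor;
# the matching up block concatenates that skip onto its input (standard
# UNet skip connections) before running its residual blocks.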
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, context_dim, has_attn=True, downsample=True):
        super().__init__()
        self.res1 = ResidualBlock(in_ch, out_ch, time_dim)
        self.res2 = ResidualBlock(out_ch, out_ch, time_dim)
        self.attn = AttentionBlock(out_ch) if has_attn else nn.Identity()
        self.cross_attn = CrossAttentionBlock(out_ch, context_dim) if has_attn else None
        self.downsample = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1) if downsample else nn.Identity()

    def forward(self, x, time_emb, context):
        x = self.res1(x, time_emb)
        x = self.res2(x, time_emb)
        if not isinstance(self.attn, nn.Identity):
            x = self.attn(x)
            x = self.cross_attn(x, context)
        skip = x
        x = self.downsample(x)
        return x, skip

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, context_dim, has_attn=True, upsample=True):
        super().__init__()
        self.res1 = ResidualBlock(in_ch + out_ch, out_ch, time_dim)
        self.res2 = ResidualBlock(out_ch, out_ch, time_dim)
        self.attn = AttentionBlock(out_ch) if has_attn else nn.Identity()
        self.cross_attn = CrossAttentionBlock(out_ch, context_dim) if has_attn else None
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(out_ch, out_ch, 3, padding=1)
        ) if upsample else nn.Identity()

    def forward(self, x, skip, time_emb, context):
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, time_emb)
        x = self.res2(x, time_emb)
        if not isinstance(self.attn, nn.Identity):
            x = self.attn(x)
            x = self.cross_attn(x, context)
        x = self.upsample(x)
        return x

class ConditionalUNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, base_ch=64, channel_mults=(1, 2, 4), context_dim=256):
        super().__init__()
        time_dim = base_ch * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(base_ch),
            nn.Linear(base_ch, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )
        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        # Downsampling
        self.down_blocks = nn.ModuleList()
        channels = [base_ch]
        in_ch_block = base_ch
        for i, mult in enumerate(channel_mults):
            out_ch_block = base_ch * mult
            is_last = i == len(channel_mults) - 1
            has_attn = mult >= 2
            self.down_blocks.append(
                DownBlock(in_ch_block, out_ch_block, time_dim, context_dim, has_attn, not is_last)
            )
            channels.append(out_ch_block)
            in_ch_block = out_ch_block
        # Middle
        self.mid_res1 = ResidualBlock(in_ch_block, in_ch_block, time_dim)
        self.mid_attn = AttentionBlock(in_ch_block)
        self.mid_cross = CrossAttentionBlock(in_ch_block, context_dim)
        self.mid_res2 = ResidualBlock(in_ch_block, in_ch_block, time_dim)
        # Upsampling
        self.up_blocks = nn.ModuleList()
        for i, mult in enumerate(reversed(channel_mults)):
            out_ch_block = base_ch * mult
            is_last = i == len(channel_mults) - 1
            has_attn = mult >= 2
            self.up_blocks.append(
                UpBlock(in_ch_block, out_ch_block, time_dim, context_dim, has_attn, not is_last)
            )
            in_ch_block = out_ch_block
        self.norm_out = nn.GroupNorm(8, base_ch)
        self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
        self.channels = channels

    def forward(self, x, time, context):
        t = self.time_mlp(time)
        x = self.conv_in(x)
        skips = []
        for block in self.down_blocks:
            x, skip = block(x, t, context)
            skips.append(skip)
        x = self.mid_res1(x, t)
        x = self.mid_attn(x)
        x = self.mid_cross(x, context)
        x = self.mid_res2(x, t)
        for block in self.up_blocks:
            skip = skips.pop()
            x = block(x, skip, t, context)
        x = F.silu(self.norm_out(x))
        return self.conv_out(x)

# ============== Text Encoder ==============
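# A deliberately tiny character-level encoder: prompts are lowercased,
# tokenized per character over a fixed alphabet, and run through a 2-layer
# Transformer. No pretrained language model is needed, which keeps the
# Space self-contained.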
class SimpleTextEncoder(nn.Module):
    def __init__(self, vocab_size=200, embed_dim=256, max_len=64):
        super().__init__()
        self.max_len = max_len
        self.embed_dim = embed_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=512, batch_first=True),
            num_layers=2
        )
        self.norm = nn.LayerNorm(embed_dim)
        chars = " abcdefghijklmnopqrstuvwxyz0123456789-_.,;:!?()[]{}'\"/\\@#$%^&*+=<>~`"
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
        self.char_to_idx["<pad>"] = 0

    def tokenize(self, texts, device):
        batch = []
        for text in texts:
            text = text.lower()[:self.max_len]
            # Unknown characters fall back to the pad index 0.
            tokens = [self.char_to_idx.get(c, 0) for c in text]
            tokens += [0] * (self.max_len - len(tokens))
            batch.append(tokens)
        return torch.tensor(batch, device=device)

    def forward(self, texts, device):
        tokens = self.tokenize(texts, device)
        pos = torch.arange(self.max_len, device=device).unsqueeze(0)
        x = self.embed(tokens) + self.pos_embed(pos)
        x = self.transformer(x)
        return self.norm(x)

    def get_uncond(self, batch_size, device):
        # Empty prompts serve as the unconditional context for classifier-free guidance.
        return self.forward([""] * batch_size, device)

# ============== Diffusion ==============
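# The noise schedule is the cosine schedule of Nichol & Dhariwal (2021):
#   alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2),  s = 0.008
# with betas derived as beta_t = 1 - alpha_bar(t) / alpha_bar(t-1) and
# clamped to [1e-4, 0.999].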
class GaussianDiffusion:
    def __init__(self, timesteps=1000, device="cuda"):
        self.timesteps = timesteps
        self.device = device
        betas = self._cosine_schedule(timesteps)
        alphas = 1 - betas
        alpha_cum = torch.cumprod(alphas, dim=0)
        self.betas = betas.to(device)
        self.alphas = alphas.to(device)
        self.alpha_cum = alpha_cum.to(device)
        self.sqrt_alpha_cum = torch.sqrt(alpha_cum).to(device)
        self.sqrt_one_minus_alpha_cum = torch.sqrt(1 - alpha_cum).to(device)

    def _cosine_schedule(self, timesteps, s=0.008):
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alpha_cum = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alpha_cum = alpha_cum / alpha_cum[0]
        betas = 1 - (alpha_cum[1:] / alpha_cum[:-1])
        return torch.clamp(betas, 0.0001, 0.999)
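
    # Forward process: q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0,
    # (1 - alpha_bar_t) * I), sampled in closed form below.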
    def add_noise(self, x, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x)
        sqrt_alpha = self.sqrt_alpha_cum[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alpha_cum[t].view(-1, 1, 1, 1)
        return sqrt_alpha * x + sqrt_one_minus * noise, noise

    def loss(self, model, x, context):
        # Epsilon-prediction objective: the model learns to predict the noise
        # added at a uniformly sampled timestep.
        batch_size = x.shape[0]
        t = torch.randint(0, self.timesteps, (batch_size,), device=self.device)
        noise = torch.randn_like(x)
        x_noisy, _ = self.add_noise(x, t, noise)
        pred = model(x_noisy, t.float(), context)
        return F.mse_loss(pred, noise)
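
    # Sampling uses classifier-free guidance:
    #   eps_hat = eps_uncond + w * (eps_cond - eps_uncond)
    # on a timestep grid subsampled to `steps` points. Note that this applies
    # the single-step DDPM update at strided timesteps, which is a coarse
    # approximation rather than a properly respaced DDIM sampler.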
    def sample(self, model, context, context_uncond=None, shape=(1, 3, 128, 128),
               steps=50, guidance_scale=7.5, progress_callback=None):
        x = torch.randn(shape, device=self.device)
        step_size = self.timesteps // steps
        timesteps = list(range(0, self.timesteps, step_size))[::-1]
        for i, t in enumerate(timesteps):
            t_batch = torch.full((shape[0],), t, device=self.device, dtype=torch.long)
            pred = model(x, t_batch.float(), context)
            if guidance_scale > 1.0 and context_uncond is not None:
                pred_uncond = model(x, t_batch.float(), context_uncond)
                pred = pred_uncond + guidance_scale * (pred - pred_uncond)
            alpha = self.alphas[t]
            beta = self.betas[t]
            x = (1 / torch.sqrt(alpha)) * (x - (beta / self.sqrt_one_minus_alpha_cum[t]) * pred)
            if t > 0:
                noise = torch.randn_like(x)
                x = x + torch.sqrt(beta) * noise
            if progress_callback:
                progress_callback((i + 1) / len(timesteps))
        return x

# ============== Dataset ==============
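# The dataset expects data_dir/images/*.png plus data_dir/labels.json mapping
# filename -> caption, e.g. {"chart_000000.png": "bullish trend low volatility"}
# (the format produced by the Generate Data tab). 10% of captions are replaced
# with the empty string during training so the model also learns the
# unconditional branch needed for classifier-free guidance.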
class ChartDataset(Dataset):
    def __init__(self, data_dir, image_size=128, split="train"):
        self.data_dir = Path(data_dir)
        self.image_size = image_size
        with open(self.data_dir / "labels.json") as f:
            self.labels = json.load(f)
        all_files = sorted(self.labels.keys())
        split_idx = int(len(all_files) * 0.9)
        self.files = all_files[:split_idx] if split == "train" else all_files[split_idx:]
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # scale to [-1, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        image = Image.open(self.data_dir / "images" / filename).convert("RGB")
        image = self.transform(image)
        text = self.labels[filename]
        if random.random() < 0.1:
            text = ""  # caption dropout for classifier-free guidance
        return image, text

def collate_fn(batch):
    images = torch.stack([b[0] for b in batch])
    texts = [b[1] for b in batch]
    return images, texts

# ============== Global State ==============
MODEL = None
TEXT_ENCODER = None
DIFFUSION = None
DEVICE = None
CONFIG = None

def load_model(checkpoint_path=None):
    global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")
    # Default config
    CONFIG = {
        "base_channels": 64,
        "channel_mults": (1, 2, 4),
        "context_dim": 256,
        "image_size": 128,
        "timesteps": 1000
    }
    # Load checkpoint if it exists
    checkpoint = None
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        if "config" in checkpoint:
            CONFIG.update(checkpoint["config"])
    # Create models
    TEXT_ENCODER = SimpleTextEncoder(embed_dim=CONFIG["context_dim"]).to(DEVICE)
    MODEL = ConditionalUNet(
        base_ch=CONFIG["base_channels"],
        channel_mults=CONFIG["channel_mults"],
        context_dim=CONFIG["context_dim"]
    ).to(DEVICE)
    # Load weights if available
    if checkpoint is not None:
        MODEL.load_state_dict(checkpoint["model_state_dict"])
        if "text_encoder_state_dict" in checkpoint:
            TEXT_ENCODER.load_state_dict(checkpoint["text_encoder_state_dict"])
        print("Model weights loaded!")
    MODEL.eval()
    DIFFUSION = GaussianDiffusion(timesteps=CONFIG["timesteps"], device=DEVICE)
    num_params = sum(p.numel() for p in MODEL.parameters())
    print(f"Model parameters: {num_params:,}")
    return True

def generate_dataset_ui(num_samples, image_size):
    """Generate a synthetic candlestick training dataset under ./dataset."""
    try:
        # Lazy import; "Agg" keeps matplotlib headless-safe inside Gradio workers.
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from matplotlib.patches import Rectangle
        import io
        num_samples = int(num_samples)
        image_size = int(image_size)  # Gradio sliders may return floats
        output_dir = "./dataset"
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
        bg_color = "#1a1a2e"
        bullish_color = "#00ff88"
        bearish_color = "#ff4466"
        num_candles = 20

        def generate_candles(pattern, vol):
            candles = []
            price = 100 if pattern != "bearish" else 150
            for i in range(num_candles):
                if pattern == "bullish":
                    trend = random.uniform(0.5, 2.0)
                    o = price + random.gauss(0, vol)
                    c = o + random.uniform(0, vol * 2) + trend
                elif pattern == "bearish":
                    trend = random.uniform(0.5, 2.0)
                    o = price + random.gauss(0, vol)
                    c = o - random.uniform(0, vol * 2) - trend
                else:  # sideways
                    o = price + random.gauss(0, vol)
                    c = o + random.gauss(0, vol)
                h = max(o, c) + random.uniform(0, vol)
                l = min(o, c) - random.uniform(0, vol)
                candles.append({"o": o, "h": h, "l": l, "c": c})
                price = c
            return candles

        def render(candles):
            fig, ax = plt.subplots(figsize=(image_size / 100, image_size / 100), dpi=100)
            fig.patch.set_facecolor(bg_color)
            ax.set_facecolor(bg_color)
            highs = [c["h"] for c in candles]
            lows = [c["l"] for c in candles]
            price_min, price_max = min(lows) * 0.98, max(highs) * 1.02
            for i, c in enumerate(candles):
                color = bullish_color if c["c"] >= c["o"] else bearish_color
                ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)  # wick
                body_bottom = min(c["o"], c["c"])
                body_height = abs(c["c"] - c["o"]) or 0.1  # avoid zero-height bodies
                rect = Rectangle((i - 0.3, body_bottom), 0.6, body_height, facecolor=color)
                ax.add_patch(rect)
            ax.set_xlim(-1, len(candles))
            ax.set_ylim(price_min, price_max)
            ax.axis("off")
            buf = io.BytesIO()
            plt.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
            plt.close(fig)
            buf.seek(0)
            img = Image.open(buf).convert("RGB")
            return img.resize((image_size, image_size), Image.Resampling.LANCZOS)

        patterns = ["bullish", "bearish", "sideways"]
        volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
        labels = {}
        for i in range(num_samples):
            pattern = random.choice(patterns)
            vol_name = random.choice(list(volatilities.keys()))
            vol = volatilities[vol_name]
            candles = generate_candles(pattern, vol)
            img = render(candles)
            filename = f"chart_{i:06d}.png"
            img.save(os.path.join(output_dir, "images", filename))
            labels[filename] = f"{pattern} trend {vol_name} volatility"
            if i % 500 == 0:
                print(f"Generated {i}/{num_samples}")
        with open(os.path.join(output_dir, "labels.json"), "w") as f:
            json.dump(labels, f)
        return f"✅ Generated {num_samples} samples in ./dataset"
    except Exception as e:
        return f"❌ Failed: {str(e)}"

# ============== Gradio Interface ==============
def generate_chart(prompt, num_steps, guidance_scale, seed):
    global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
    if MODEL is None:
        return None, "❌ Model not loaded! Train first or load a checkpoint."
    if not prompt.strip():
        return None, "❌ Please enter a prompt!"
    try:
        num_steps = int(num_steps)  # Gradio sliders/numbers may return floats
        seed = int(seed)
        if seed >= 0:
            torch.manual_seed(seed)
            if DEVICE.type == "cuda":
                torch.cuda.manual_seed(seed)
        with torch.no_grad():
            context = TEXT_ENCODER([prompt], DEVICE)
            context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)
            samples = DIFFUSION.sample(
                MODEL, context, context_uncond,
                shape=(1, 3, CONFIG["image_size"], CONFIG["image_size"]),
                steps=num_steps,
                guidance_scale=guidance_scale,
                progress_callback=None
            )
        # Convert from [-1, 1] to a PIL image
        samples = (samples + 1) / 2
        samples = samples.clamp(0, 1)
        samples = (samples * 255).to(torch.uint8)
        img_array = samples[0].permute(1, 2, 0).cpu().numpy()
        img = Image.fromarray(img_array)
        return img, "✅ Generated successfully!"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"

def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name):
    global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
    try:
        # Setup (Gradio sliders may return floats)
        epochs = int(epochs)
        batch_size = int(batch_size)
        image_size = int(image_size)
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        CONFIG = {
            "base_channels": 64,
            "channel_mults": (1, 2, 4),
            "context_dim": 256,
            "image_size": image_size,
            "timesteps": 1000
        }
        # Create models
        TEXT_ENCODER = SimpleTextEncoder(embed_dim=CONFIG["context_dim"]).to(DEVICE)
        MODEL = ConditionalUNet(
            base_ch=CONFIG["base_channels"],
            channel_mults=CONFIG["channel_mults"],
            context_dim=CONFIG["context_dim"]
        ).to(DEVICE)
        DIFFUSION = GaussianDiffusion(timesteps=CONFIG["timesteps"], device=DEVICE)
        num_params = sum(p.numel() for p in MODEL.parameters())
        # Dataset
        train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=2, pin_memory=True, drop_last=True, collate_fn=collate_fn
        )
        # Optimizer over both the UNet and the text encoder
        optimizer = torch.optim.AdamW(
            list(MODEL.parameters()) + list(TEXT_ENCODER.parameters()),
            lr=learning_rate
        )
        # Training
        MODEL.train()
        TEXT_ENCODER.train()
        logs = [f"🚀 Training started on {DEVICE}"]
        logs.append(f"📊 Model parameters: {num_params:,}")
        logs.append(f"📊 Training samples: {len(train_dataset)}")
        logs.append("-" * 40)
        for epoch in range(epochs):
            epoch_loss = 0
            for images, texts in train_loader:
                images = images.to(DEVICE)
                context = TEXT_ENCODER(texts, DEVICE)
                optimizer.zero_grad()
                loss = DIFFUSION.loss(MODEL, images, context)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
                optimizer.step()
                epoch_loss += loss.item()
            avg_loss = epoch_loss / len(train_loader)
            logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
        # Save model
        MODEL.eval()
        TEXT_ENCODER.eval()
        os.makedirs("checkpoints", exist_ok=True)
        save_path = f"checkpoints/{save_name}.pt"
        torch.save({
            "model_state_dict": MODEL.state_dict(),
            "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
            "config": CONFIG
        }, save_path)
        logs.append("-" * 40)
        logs.append(f"✅ Model saved to {save_path}")
        return "\n".join(logs)
    except Exception as e:
        return f"❌ Training failed: {str(e)}"

def load_checkpoint(checkpoint_file):
    if checkpoint_file is None:
        return "❌ No file selected"
    try:
        # gr.File may hand back a tempfile object or a plain path string,
        # depending on the Gradio version.
        path = checkpoint_file if isinstance(checkpoint_file, str) else checkpoint_file.name
        load_model(path)
        return f"✅ Model loaded from {path}"
    except Exception as e:
        return f"❌ Failed to load: {str(e)}"

# ============== Gradio UI ==============
def create_demo():
    with gr.Blocks(title="Candlestick Chart Generator") as demo:
        gr.Markdown("""
        # 📈 Candlestick Chart Diffusion Generator
        Generate candlestick chart images from text descriptions using a diffusion model.

        **Steps:**
        1. Upload your dataset (or use the generator script to create one)
        2. Train the model
        3. Generate charts from text prompts!
        """)
        with gr.Tabs():
            # Data Generation Tab
            with gr.TabItem("📊 Generate Data"):
                gr.Markdown("""
                ### Generate Training Dataset
                Create synthetic candlestick chart images for training.
                **Run this first before training!**
                """)
                with gr.Row():
                    with gr.Column():
                        num_samples = gr.Slider(1000, 50000, value=10000, step=1000, label="Number of Samples")
                        data_image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
                        generate_data_btn = gr.Button("📊 Generate Dataset", variant="primary")
                    with gr.Column():
                        data_status = gr.Textbox(label="Status", lines=5, interactive=False)
                generate_data_btn.click(
                    generate_dataset_ui,
                    inputs=[num_samples, data_image_size],
                    outputs=[data_status]
                )
            # Generation Tab
            with gr.TabItem("🎨 Generate"):
                with gr.Row():
                    with gr.Column(scale=1):
                        prompt_input = gr.Textbox(
                            label="Prompt",
                            placeholder="e.g., bullish trend with high volatility",
                            lines=2
                        )
                        with gr.Row():
                            num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
                            guidance = gr.Slider(1, 20, value=7.5, step=0.5, label="Guidance Scale")
                        seed_input = gr.Number(label="Seed (-1 for random)", value=-1)
                        generate_btn = gr.Button("🎨 Generate", variant="primary")
                        gen_status = gr.Textbox(label="Status", interactive=False)
                        gr.Markdown("### Example Prompts")
                        gr.Examples(
                            examples=[
                                ["bullish trend with high volatility"],
                                ["bearish reversal pattern"],
                                ["double bottom formation low volatility"],
                                ["sideways market consolidation"],
                                ["head and shoulders pattern"],
                                ["strong upward trend green candles"],
                            ],
                            inputs=[prompt_input]
                        )
                    with gr.Column(scale=1):
                        output_image = gr.Image(label="Generated Chart", type="pil")
                generate_btn.click(
                    generate_chart,
                    inputs=[prompt_input, num_steps, guidance, seed_input],
                    outputs=[output_image, gen_status]
                )
            # Training Tab
            with gr.TabItem("🏋️ Train"):
                gr.Markdown("""
                ### Training Configuration
                Upload your dataset folder containing:
                - `images/` folder with chart images
                - `labels.json` with text descriptions
                """)
                with gr.Row():
                    with gr.Column():
                        data_path = gr.Textbox(label="Dataset Path", value="./dataset")
                        epochs = gr.Slider(1, 200, value=50, step=1, label="Epochs")
                        batch_size = gr.Slider(1, 64, value=16, step=1, label="Batch Size")
                        learning_rate = gr.Number(label="Learning Rate", value=1e-4)
                        image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
                        save_name = gr.Textbox(label="Model Name", value="candlestick_model")
                        train_btn = gr.Button("🚀 Start Training", variant="primary")
                    with gr.Column():
                        train_logs = gr.Textbox(label="Training Logs", lines=20, interactive=False)
                train_btn.click(
                    train_model,
                    inputs=[data_path, epochs, batch_size, learning_rate, image_size, save_name],
                    outputs=[train_logs]
                )
            # Load Model Tab
            with gr.TabItem("📂 Load Model"):
                gr.Markdown("### Load a trained checkpoint")
                checkpoint_upload = gr.File(label="Upload Checkpoint (.pt file)")
                load_btn = gr.Button("Load Model")
                load_status = gr.Textbox(label="Status", interactive=False)
                load_btn.click(
                    load_checkpoint,
                    inputs=[checkpoint_upload],
                    outputs=[load_status]
                )
        gr.Markdown("""
        ---
        ### Tips
        - **Training**: Use at least 5000 samples and 50+ epochs for good results
        - **Guidance Scale**: Higher values (7-12) follow prompts more closely
        - **Steps**: 50 steps is a good balance between speed and quality
        """)
    return demo

# ============== Main ==============
if __name__ == "__main__":
    # Try to load an existing checkpoint
    if os.path.exists("checkpoints/candlestick_model.pt"):
        load_model("checkpoints/candlestick_model.pt")
    demo = create_demo()
    demo.launch()