Spanicin commited on
Commit
d3932f4
·
verified ·
1 Parent(s): 8cf26cc

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +35 -6
  2. app.py +702 -0
  3. generate_data.py +283 -0
  4. requirements.txt +8 -0
README.md CHANGED
@@ -1,12 +1,41 @@
1
  ---
2
- title: Candlestick Diffusion
3
- emoji: 📚
4
- colorFrom: blue
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.0.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Candlestick Chart Diffusion
3
+ emoji: 📈
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # Candlestick Chart Diffusion Generator
14
+
15
+ Generate candlestick chart images from text descriptions using a conditional diffusion model.
16
+
17
+ ## Features
18
+
19
+ - 🎨 **Generate** candlestick charts from text prompts
20
+ - 🏋️ **Train** your own model on custom data
21
+ - 📂 **Load** pre-trained checkpoints
22
+
23
+ ## Example Prompts
24
+
25
+ - "bullish trend with high volatility"
26
+ - "bearish reversal pattern"
27
+ - "double bottom formation"
28
+ - "sideways market consolidation"
29
+ - "head and shoulders pattern"
30
+
31
+ ## How to Use
32
+
33
+ 1. **Generate Dataset**: Use the data generator script to create training data
34
+ 2. **Train Model**: Upload dataset and train for 50+ epochs
35
+ 3. **Generate Charts**: Enter a text prompt and generate!
36
+
37
+ ## Model Architecture
38
+
39
+ - **U-Net** with cross-attention for text conditioning
40
+ - **Diffusion** with cosine noise schedule
41
+ - **Text Encoder** with transformer layers
app.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Candlestick Chart Diffusion Model - Hugging Face Spaces App
3
+ Generates candlestick chart images from text prompts
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import gradio as gr
11
+ from PIL import Image
12
+ import numpy as np
13
+ from pathlib import Path
14
+ import math
15
+ from tqdm import tqdm
16
+ import json
17
+ import random
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from torchvision import transforms
20
+ from einops import rearrange
21
+
22
+ # ============== Model Components ==============
23
+
24
+ class SinusoidalPositionEmbeddings(nn.Module):
25
+ def __init__(self, dim):
26
+ super().__init__()
27
+ self.dim = dim
28
+
29
+ def forward(self, time):
30
+ device = time.device
31
+ half_dim = self.dim // 2
32
+ embeddings = math.log(10000) / (half_dim - 1)
33
+ embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
34
+ embeddings = time[:, None] * embeddings[None, :]
35
+ embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
36
+ return embeddings
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ def __init__(self, in_channels, out_channels, time_emb_dim, groups=8):
41
+ super().__init__()
42
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
43
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
44
+ self.norm1 = nn.GroupNorm(groups, in_channels)
45
+ self.norm2 = nn.GroupNorm(groups, out_channels)
46
+ self.time_mlp = nn.Sequential(
47
+ nn.SiLU(),
48
+ nn.Linear(time_emb_dim, out_channels * 2)
49
+ )
50
+ self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
51
+
52
+ def forward(self, x, time_emb):
53
+ h = F.silu(self.norm1(x))
54
+ h = self.conv1(h)
55
+ time_emb = self.time_mlp(time_emb)
56
+ time_emb = rearrange(time_emb, "b c -> b c 1 1")
57
+ scale, shift = time_emb.chunk(2, dim=1)
58
+ h = h * (1 + scale) + shift
59
+ h = F.silu(self.norm2(h))
60
+ h = self.conv2(h)
61
+ return h + self.residual_conv(x)
62
+
63
+
64
+ class AttentionBlock(nn.Module):
65
+ def __init__(self, channels, num_heads=4):
66
+ super().__init__()
67
+ self.num_heads = num_heads
68
+ self.head_dim = channels // num_heads
69
+ self.norm = nn.GroupNorm(8, channels)
70
+ self.qkv = nn.Conv2d(channels, channels * 3, 1)
71
+ self.proj = nn.Conv2d(channels, channels, 1)
72
+ self.scale = self.head_dim ** -0.5
73
+
74
+ def forward(self, x):
75
+ b, c, h, w = x.shape
76
+ x_norm = self.norm(x)
77
+ qkv = self.qkv(x_norm)
78
+ q, k, v = qkv.chunk(3, dim=1)
79
+ q = rearrange(q, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
80
+ k = rearrange(k, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
81
+ v = rearrange(v, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
82
+ attn = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
83
+ attn = F.softmax(attn, dim=-1)
84
+ out = torch.einsum("bhij,bhjd->bhid", attn, v)
85
+ out = rearrange(out, "b heads (h w) d -> b (heads d) h w", h=h, w=w)
86
+ return x + self.proj(out)
87
+
88
+
89
+ class CrossAttentionBlock(nn.Module):
90
+ def __init__(self, channels, context_dim, num_heads=4):
91
+ super().__init__()
92
+ self.num_heads = num_heads
93
+ self.head_dim = channels // num_heads
94
+ self.norm = nn.GroupNorm(8, channels)
95
+ self.norm_context = nn.LayerNorm(context_dim)
96
+ self.to_q = nn.Conv2d(channels, channels, 1)
97
+ self.to_k = nn.Linear(context_dim, channels)
98
+ self.to_v = nn.Linear(context_dim, channels)
99
+ self.proj = nn.Conv2d(channels, channels, 1)
100
+ self.scale = self.head_dim ** -0.5
101
+
102
+ def forward(self, x, context):
103
+ b, c, h, w = x.shape
104
+ x_norm = self.norm(x)
105
+ context = self.norm_context(context)
106
+ q = self.to_q(x_norm)
107
+ k = self.to_k(context)
108
+ v = self.to_v(context)
109
+ q = rearrange(q, "b (heads d) h w -> b heads (h w) d", heads=self.num_heads)
110
+ k = rearrange(k, "b n (heads d) -> b heads n d", heads=self.num_heads)
111
+ v = rearrange(v, "b n (heads d) -> b heads n d", heads=self.num_heads)
112
+ attn = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
113
+ attn = F.softmax(attn, dim=-1)
114
+ out = torch.einsum("bhij,bhjd->bhid", attn, v)
115
+ out = rearrange(out, "b heads (h w) d -> b (heads d) h w", h=h, w=w)
116
+ return x + self.proj(out)
117
+
118
+
119
+ class DownBlock(nn.Module):
120
+ def __init__(self, in_ch, out_ch, time_dim, context_dim, has_attn=True, downsample=True):
121
+ super().__init__()
122
+ self.res1 = ResidualBlock(in_ch, out_ch, time_dim)
123
+ self.res2 = ResidualBlock(out_ch, out_ch, time_dim)
124
+ self.attn = AttentionBlock(out_ch) if has_attn else nn.Identity()
125
+ self.cross_attn = CrossAttentionBlock(out_ch, context_dim) if has_attn else None
126
+ self.downsample = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1) if downsample else nn.Identity()
127
+
128
+ def forward(self, x, time_emb, context):
129
+ x = self.res1(x, time_emb)
130
+ x = self.res2(x, time_emb)
131
+ if not isinstance(self.attn, nn.Identity):
132
+ x = self.attn(x)
133
+ x = self.cross_attn(x, context)
134
+ skip = x
135
+ x = self.downsample(x)
136
+ return x, skip
137
+
138
+
139
+ class UpBlock(nn.Module):
140
+ def __init__(self, in_ch, out_ch, time_dim, context_dim, has_attn=True, upsample=True):
141
+ super().__init__()
142
+ self.res1 = ResidualBlock(in_ch + out_ch, out_ch, time_dim)
143
+ self.res2 = ResidualBlock(out_ch, out_ch, time_dim)
144
+ self.attn = AttentionBlock(out_ch) if has_attn else nn.Identity()
145
+ self.cross_attn = CrossAttentionBlock(out_ch, context_dim) if has_attn else None
146
+ self.upsample = nn.Sequential(
147
+ nn.Upsample(scale_factor=2, mode="nearest"),
148
+ nn.Conv2d(out_ch, out_ch, 3, padding=1)
149
+ ) if upsample else nn.Identity()
150
+
151
+ def forward(self, x, skip, time_emb, context):
152
+ x = torch.cat([x, skip], dim=1)
153
+ x = self.res1(x, time_emb)
154
+ x = self.res2(x, time_emb)
155
+ if not isinstance(self.attn, nn.Identity):
156
+ x = self.attn(x)
157
+ x = self.cross_attn(x, context)
158
+ x = self.upsample(x)
159
+ return x
160
+
161
+
162
+ class ConditionalUNet(nn.Module):
163
+ def __init__(self, in_ch=3, out_ch=3, base_ch=64, channel_mults=(1, 2, 4), context_dim=256):
164
+ super().__init__()
165
+ time_dim = base_ch * 4
166
+
167
+ self.time_mlp = nn.Sequential(
168
+ SinusoidalPositionEmbeddings(base_ch),
169
+ nn.Linear(base_ch, time_dim),
170
+ nn.SiLU(),
171
+ nn.Linear(time_dim, time_dim)
172
+ )
173
+
174
+ self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
175
+
176
+ # Downsampling
177
+ self.down_blocks = nn.ModuleList()
178
+ channels = [base_ch]
179
+ in_ch_block = base_ch
180
+
181
+ for i, mult in enumerate(channel_mults):
182
+ out_ch_block = base_ch * mult
183
+ is_last = i == len(channel_mults) - 1
184
+ has_attn = mult >= 2
185
+ self.down_blocks.append(
186
+ DownBlock(in_ch_block, out_ch_block, time_dim, context_dim, has_attn, not is_last)
187
+ )
188
+ channels.append(out_ch_block)
189
+ in_ch_block = out_ch_block
190
+
191
+ # Middle
192
+ self.mid_res1 = ResidualBlock(in_ch_block, in_ch_block, time_dim)
193
+ self.mid_attn = AttentionBlock(in_ch_block)
194
+ self.mid_cross = CrossAttentionBlock(in_ch_block, context_dim)
195
+ self.mid_res2 = ResidualBlock(in_ch_block, in_ch_block, time_dim)
196
+
197
+ # Upsampling
198
+ self.up_blocks = nn.ModuleList()
199
+ for i, mult in enumerate(reversed(channel_mults)):
200
+ out_ch_block = base_ch * mult
201
+ is_last = i == len(channel_mults) - 1
202
+ has_attn = mult >= 2
203
+ self.up_blocks.append(
204
+ UpBlock(in_ch_block, out_ch_block, time_dim, context_dim, has_attn, not is_last)
205
+ )
206
+ in_ch_block = out_ch_block
207
+
208
+ self.norm_out = nn.GroupNorm(8, base_ch)
209
+ self.conv_out = nn.Conv2d(base_ch, 3, 3, padding=1)
210
+ self.channels = channels
211
+
212
+ def forward(self, x, time, context):
213
+ t = self.time_mlp(time)
214
+ x = self.conv_in(x)
215
+
216
+ skips = []
217
+ for block in self.down_blocks:
218
+ x, skip = block(x, t, context)
219
+ skips.append(skip)
220
+
221
+ x = self.mid_res1(x, t)
222
+ x = self.mid_attn(x)
223
+ x = self.mid_cross(x, context)
224
+ x = self.mid_res2(x, t)
225
+
226
+ for block in self.up_blocks:
227
+ skip = skips.pop()
228
+ x = block(x, skip, t, context)
229
+
230
+ x = F.silu(self.norm_out(x))
231
+ return self.conv_out(x)
232
+
233
+
234
+ # ============== Text Encoder ==============
235
+
236
+ class SimpleTextEncoder(nn.Module):
237
+ def __init__(self, vocab_size=200, embed_dim=256, max_len=64):
238
+ super().__init__()
239
+ self.max_len = max_len
240
+ self.embed_dim = embed_dim
241
+ self.embed = nn.Embedding(vocab_size, embed_dim)
242
+ self.pos_embed = nn.Embedding(max_len, embed_dim)
243
+ self.transformer = nn.TransformerEncoder(
244
+ nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=512, batch_first=True),
245
+ num_layers=2
246
+ )
247
+ self.norm = nn.LayerNorm(embed_dim)
248
+
249
+ chars = " abcdefghijklmnopqrstuvwxyz0123456789-_.,;:!?()[]{}'\"/\\@#$%^&*+=<>~`"
250
+ self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
251
+ self.char_to_idx["<pad>"] = 0
252
+
253
+ def tokenize(self, texts, device):
254
+ batch = []
255
+ for text in texts:
256
+ text = text.lower()[:self.max_len]
257
+ tokens = [self.char_to_idx.get(c, 0) for c in text]
258
+ tokens += [0] * (self.max_len - len(tokens))
259
+ batch.append(tokens)
260
+ return torch.tensor(batch, device=device)
261
+
262
+ def forward(self, texts, device):
263
+ tokens = self.tokenize(texts, device)
264
+ pos = torch.arange(self.max_len, device=device).unsqueeze(0)
265
+ x = self.embed(tokens) + self.pos_embed(pos)
266
+ x = self.transformer(x)
267
+ return self.norm(x)
268
+
269
+ def get_uncond(self, batch_size, device):
270
+ return self.forward([""] * batch_size, device)
271
+
272
+
273
+ # ============== Diffusion ==============
274
+
275
+ class GaussianDiffusion:
276
+ def __init__(self, timesteps=1000, device="cuda"):
277
+ self.timesteps = timesteps
278
+ self.device = device
279
+
280
+ betas = self._cosine_schedule(timesteps)
281
+ alphas = 1 - betas
282
+ alpha_cum = torch.cumprod(alphas, dim=0)
283
+
284
+ self.betas = betas.to(device)
285
+ self.alphas = alphas.to(device)
286
+ self.alpha_cum = alpha_cum.to(device)
287
+ self.sqrt_alpha_cum = torch.sqrt(alpha_cum).to(device)
288
+ self.sqrt_one_minus_alpha_cum = torch.sqrt(1 - alpha_cum).to(device)
289
+
290
+ def _cosine_schedule(self, timesteps, s=0.008):
291
+ steps = timesteps + 1
292
+ x = torch.linspace(0, timesteps, steps)
293
+ alpha_cum = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
294
+ alpha_cum = alpha_cum / alpha_cum[0]
295
+ betas = 1 - (alpha_cum[1:] / alpha_cum[:-1])
296
+ return torch.clamp(betas, 0.0001, 0.999)
297
+
298
+ def add_noise(self, x, t, noise=None):
299
+ if noise is None:
300
+ noise = torch.randn_like(x)
301
+ sqrt_alpha = self.sqrt_alpha_cum[t].view(-1, 1, 1, 1)
302
+ sqrt_one_minus = self.sqrt_one_minus_alpha_cum[t].view(-1, 1, 1, 1)
303
+ return sqrt_alpha * x + sqrt_one_minus * noise, noise
304
+
305
+ def loss(self, model, x, context):
306
+ batch_size = x.shape[0]
307
+ t = torch.randint(0, self.timesteps, (batch_size,), device=self.device)
308
+ noise = torch.randn_like(x)
309
+ x_noisy, _ = self.add_noise(x, t, noise)
310
+ pred = model(x_noisy, t.float(), context)
311
+ return F.mse_loss(pred, noise)
312
+
313
+ @torch.no_grad()
314
+ def sample(self, model, context, context_uncond=None, shape=(1, 3, 128, 128),
315
+ steps=50, guidance_scale=7.5, progress_callback=None):
316
+ x = torch.randn(shape, device=self.device)
317
+ step_size = self.timesteps // steps
318
+ timesteps = list(range(0, self.timesteps, step_size))[::-1]
319
+
320
+ for i, t in enumerate(timesteps):
321
+ t_batch = torch.full((shape[0],), t, device=self.device, dtype=torch.long)
322
+
323
+ pred = model(x, t_batch.float(), context)
324
+
325
+ if guidance_scale > 1.0 and context_uncond is not None:
326
+ pred_uncond = model(x, t_batch.float(), context_uncond)
327
+ pred = pred_uncond + guidance_scale * (pred - pred_uncond)
328
+
329
+ alpha = self.alphas[t]
330
+ alpha_cum = self.alpha_cum[t]
331
+ beta = self.betas[t]
332
+
333
+ x = (1 / torch.sqrt(alpha)) * (x - (beta / self.sqrt_one_minus_alpha_cum[t]) * pred)
334
+
335
+ if t > 0:
336
+ noise = torch.randn_like(x)
337
+ x = x + torch.sqrt(beta) * noise
338
+
339
+ if progress_callback:
340
+ progress_callback((i + 1) / len(timesteps))
341
+
342
+ return x
343
+
344
+
345
+ # ============== Dataset ==============
346
+
347
+ class ChartDataset(Dataset):
348
+ def __init__(self, data_dir, image_size=128, split="train"):
349
+ self.data_dir = Path(data_dir)
350
+ self.image_size = image_size
351
+
352
+ with open(self.data_dir / "labels.json") as f:
353
+ self.labels = json.load(f)
354
+
355
+ all_files = sorted(list(self.labels.keys()))
356
+ split_idx = int(len(all_files) * 0.9)
357
+ self.files = all_files[:split_idx] if split == "train" else all_files[split_idx:]
358
+
359
+ self.transform = transforms.Compose([
360
+ transforms.Resize((image_size, image_size)),
361
+ transforms.ToTensor(),
362
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
363
+ ])
364
+
365
+ def __len__(self):
366
+ return len(self.files)
367
+
368
+ def __getitem__(self, idx):
369
+ filename = self.files[idx]
370
+ image = Image.open(self.data_dir / "images" / filename).convert("RGB")
371
+ image = self.transform(image)
372
+ text = self.labels[filename]
373
+ if random.random() < 0.1:
374
+ text = ""
375
+ return image, text
376
+
377
+
378
+ def collate_fn(batch):
379
+ images = torch.stack([b[0] for b in batch])
380
+ texts = [b[1] for b in batch]
381
+ return images, texts
382
+
383
+
384
+ # ============== Global State ==============
385
+
386
+ MODEL = None
387
+ TEXT_ENCODER = None
388
+ DIFFUSION = None
389
+ DEVICE = None
390
+ CONFIG = None
391
+
392
+
393
+ def load_model(checkpoint_path=None):
394
+ global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
395
+
396
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
397
+ print(f"Using device: {DEVICE}")
398
+
399
+ # Default config
400
+ CONFIG = {
401
+ "base_channels": 64,
402
+ "channel_mults": (1, 2, 4),
403
+ "context_dim": 256,
404
+ "image_size": 128,
405
+ "timesteps": 1000
406
+ }
407
+
408
+ # Load checkpoint if exists
409
+ if checkpoint_path and os.path.exists(checkpoint_path):
410
+ print(f"Loading checkpoint from {checkpoint_path}")
411
+ checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
412
+ if "config" in checkpoint:
413
+ CONFIG.update(checkpoint["config"])
414
+
415
+ # Create models
416
+ TEXT_ENCODER = SimpleTextEncoder(embed_dim=CONFIG["context_dim"]).to(DEVICE)
417
+ MODEL = ConditionalUNet(
418
+ base_ch=CONFIG["base_channels"],
419
+ channel_mults=CONFIG["channel_mults"],
420
+ context_dim=CONFIG["context_dim"]
421
+ ).to(DEVICE)
422
+
423
+ # Load weights if available
424
+ if checkpoint_path and os.path.exists(checkpoint_path):
425
+ MODEL.load_state_dict(checkpoint["model_state_dict"])
426
+ if "text_encoder_state_dict" in checkpoint:
427
+ TEXT_ENCODER.load_state_dict(checkpoint["text_encoder_state_dict"])
428
+ print("Model weights loaded!")
429
+
430
+ MODEL.eval()
431
+ DIFFUSION = GaussianDiffusion(timesteps=CONFIG["timesteps"], device=DEVICE)
432
+
433
+ num_params = sum(p.numel() for p in MODEL.parameters())
434
+ print(f"Model parameters: {num_params:,}")
435
+
436
+ return True
437
+
438
+
439
+ # ============== Gradio Interface ==============
440
+
441
+ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress()):
442
+ global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
443
+
444
+ if MODEL is None:
445
+ return None, "❌ Model not loaded! Train first or load a checkpoint."
446
+
447
+ if not prompt.strip():
448
+ return None, "❌ Please enter a prompt!"
449
+
450
+ try:
451
+ if seed >= 0:
452
+ torch.manual_seed(seed)
453
+ if DEVICE.type == "cuda":
454
+ torch.cuda.manual_seed(seed)
455
+
456
+ def update_progress(p):
457
+ progress(p, desc="Generating...")
458
+
459
+ with torch.no_grad():
460
+ context = TEXT_ENCODER([prompt], DEVICE)
461
+ context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)
462
+
463
+ samples = DIFFUSION.sample(
464
+ MODEL, context, context_uncond,
465
+ shape=(1, 3, CONFIG["image_size"], CONFIG["image_size"]),
466
+ steps=num_steps,
467
+ guidance_scale=guidance_scale,
468
+ progress_callback=update_progress
469
+ )
470
+
471
+ # Convert to image
472
+ samples = (samples + 1) / 2
473
+ samples = samples.clamp(0, 1)
474
+ samples = (samples * 255).to(torch.uint8)
475
+ img_array = samples[0].permute(1, 2, 0).cpu().numpy()
476
+ img = Image.fromarray(img_array)
477
+
478
+ return img, f"✅ Generated successfully!"
479
+
480
+ except Exception as e:
481
+ return None, f"❌ Error: {str(e)}"
482
+
483
+
484
+ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name, progress=gr.Progress()):
485
+ global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
486
+
487
+ try:
488
+ # Setup
489
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
490
+ CONFIG = {
491
+ "base_channels": 64,
492
+ "channel_mults": (1, 2, 4),
493
+ "context_dim": 256,
494
+ "image_size": image_size,
495
+ "timesteps": 1000
496
+ }
497
+
498
+ # Create models
499
+ TEXT_ENCODER = SimpleTextEncoder(embed_dim=CONFIG["context_dim"]).to(DEVICE)
500
+ MODEL = ConditionalUNet(
501
+ base_ch=CONFIG["base_channels"],
502
+ channel_mults=CONFIG["channel_mults"],
503
+ context_dim=CONFIG["context_dim"]
504
+ ).to(DEVICE)
505
+ DIFFUSION = GaussianDiffusion(timesteps=CONFIG["timesteps"], device=DEVICE)
506
+
507
+ num_params = sum(p.numel() for p in MODEL.parameters())
508
+
509
+ # Dataset
510
+ train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
511
+ train_loader = DataLoader(
512
+ train_dataset, batch_size=batch_size, shuffle=True,
513
+ num_workers=2, pin_memory=True, drop_last=True, collate_fn=collate_fn
514
+ )
515
+
516
+ # Optimizer
517
+ optimizer = torch.optim.AdamW(
518
+ list(MODEL.parameters()) + list(TEXT_ENCODER.parameters()),
519
+ lr=learning_rate
520
+ )
521
+
522
+ # Training
523
+ MODEL.train()
524
+ TEXT_ENCODER.train()
525
+
526
+ logs = [f"🚀 Training started on {DEVICE}"]
527
+ logs.append(f"📊 Model parameters: {num_params:,}")
528
+ logs.append(f"📁 Training samples: {len(train_dataset)}")
529
+ logs.append("-" * 40)
530
+
531
+ total_steps = epochs * len(train_loader)
532
+ current_step = 0
533
+
534
+ for epoch in range(epochs):
535
+ epoch_loss = 0
536
+ for images, texts in train_loader:
537
+ images = images.to(DEVICE)
538
+ context = TEXT_ENCODER(texts, DEVICE)
539
+
540
+ optimizer.zero_grad()
541
+ loss = DIFFUSION.loss(MODEL, images, context)
542
+ loss.backward()
543
+ torch.nn.utils.clip_grad_norm_(MODEL.parameters(), 1.0)
544
+ optimizer.step()
545
+
546
+ epoch_loss += loss.item()
547
+ current_step += 1
548
+ progress(current_step / total_steps, desc=f"Epoch {epoch+1}/{epochs}")
549
+
550
+ avg_loss = epoch_loss / len(train_loader)
551
+ logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
552
+
553
+ # Save model
554
+ MODEL.eval()
555
+ os.makedirs("checkpoints", exist_ok=True)
556
+ save_path = f"checkpoints/{save_name}.pt"
557
+ torch.save({
558
+ "model_state_dict": MODEL.state_dict(),
559
+ "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
560
+ "config": CONFIG
561
+ }, save_path)
562
+
563
+ logs.append("-" * 40)
564
+ logs.append(f"✅ Model saved to {save_path}")
565
+
566
+ return "\n".join(logs)
567
+
568
+ except Exception as e:
569
+ return f"❌ Training failed: {str(e)}"
570
+
571
+
572
+ def load_checkpoint(checkpoint_file):
573
+ if checkpoint_file is None:
574
+ return "❌ No file selected"
575
+
576
+ try:
577
+ load_model(checkpoint_file.name)
578
+ return f"✅ Model loaded from {checkpoint_file.name}"
579
+ except Exception as e:
580
+ return f"❌ Failed to load: {str(e)}"
581
+
582
+
583
+ # ============== Gradio UI ==============
584
+
585
+ def create_demo():
586
+ with gr.Blocks(title="Candlestick Chart Generator", theme=gr.themes.Soft()) as demo:
587
+ gr.Markdown("""
588
+ # 📈 Candlestick Chart Diffusion Generator
589
+
590
+ Generate candlestick chart images from text descriptions using a diffusion model.
591
+
592
+ **Steps:**
593
+ 1. Upload your dataset (or use the generator script to create one)
594
+ 2. Train the model
595
+ 3. Generate charts from text prompts!
596
+ """)
597
+
598
+ with gr.Tabs():
599
+ # Generation Tab
600
+ with gr.TabItem("🎨 Generate"):
601
+ with gr.Row():
602
+ with gr.Column(scale=1):
603
+ prompt_input = gr.Textbox(
604
+ label="Prompt",
605
+ placeholder="e.g., bullish trend with high volatility",
606
+ lines=2
607
+ )
608
+
609
+ with gr.Row():
610
+ num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
611
+ guidance = gr.Slider(1, 20, value=7.5, step=0.5, label="Guidance Scale")
612
+
613
+ seed_input = gr.Number(label="Seed (-1 for random)", value=-1)
614
+ generate_btn = gr.Button("🎨 Generate", variant="primary")
615
+ gen_status = gr.Textbox(label="Status", interactive=False)
616
+
617
+ gr.Markdown("### Example Prompts")
618
+ gr.Examples(
619
+ examples=[
620
+ ["bullish trend with high volatility"],
621
+ ["bearish reversal pattern"],
622
+ ["double bottom formation low volatility"],
623
+ ["sideways market consolidation"],
624
+ ["head and shoulders pattern"],
625
+ ["strong upward trend green candles"],
626
+ ],
627
+ inputs=[prompt_input]
628
+ )
629
+
630
+ with gr.Column(scale=1):
631
+ output_image = gr.Image(label="Generated Chart", type="pil")
632
+
633
+ generate_btn.click(
634
+ generate_chart,
635
+ inputs=[prompt_input, num_steps, guidance, seed_input],
636
+ outputs=[output_image, gen_status]
637
+ )
638
+
639
+ # Training Tab
640
+ with gr.TabItem("🏋️ Train"):
641
+ gr.Markdown("""
642
+ ### Training Configuration
643
+
644
+ Upload your dataset folder containing:
645
+ - `images/` folder with chart images
646
+ - `labels.json` with text descriptions
647
+ """)
648
+
649
+ with gr.Row():
650
+ with gr.Column():
651
+ data_path = gr.Textbox(label="Dataset Path", value="./dataset")
652
+ epochs = gr.Slider(1, 200, value=50, step=1, label="Epochs")
653
+ batch_size = gr.Slider(1, 64, value=16, step=1, label="Batch Size")
654
+ learning_rate = gr.Number(label="Learning Rate", value=1e-4)
655
+ image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
656
+ save_name = gr.Textbox(label="Model Name", value="candlestick_model")
657
+
658
+ train_btn = gr.Button("🚀 Start Training", variant="primary")
659
+
660
+ with gr.Column():
661
+ train_logs = gr.Textbox(label="Training Logs", lines=20, interactive=False)
662
+
663
+ train_btn.click(
664
+ train_model,
665
+ inputs=[data_path, epochs, batch_size, learning_rate, image_size, save_name],
666
+ outputs=[train_logs]
667
+ )
668
+
669
+ # Load Model Tab
670
+ with gr.TabItem("📂 Load Model"):
671
+ gr.Markdown("### Load a trained checkpoint")
672
+
673
+ checkpoint_upload = gr.File(label="Upload Checkpoint (.pt file)")
674
+ load_btn = gr.Button("Load Model")
675
+ load_status = gr.Textbox(label="Status", interactive=False)
676
+
677
+ load_btn.click(
678
+ load_checkpoint,
679
+ inputs=[checkpoint_upload],
680
+ outputs=[load_status]
681
+ )
682
+
683
+ gr.Markdown("""
684
+ ---
685
+ ### Tips
686
+ - **Training**: Use at least 5000 samples and 50+ epochs for good results
687
+ - **Guidance Scale**: Higher values (7-12) follow prompts more closely
688
+ - **Steps**: 50 steps is a good balance between speed and quality
689
+ """)
690
+
691
+ return demo
692
+
693
+
694
+ # ============== Main ==============
695
+
696
+ if __name__ == "__main__":
697
+ # Try to load existing checkpoint
698
+ if os.path.exists("checkpoints/candlestick_model.pt"):
699
+ load_model("checkpoints/candlestick_model.pt")
700
+
701
+ demo = create_demo()
702
+ demo.launch()
generate_data.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Generator for Candlestick Charts
3
+ Run this to create training data before training the model.
4
+
5
+ Usage:
6
+ python generate_data.py --num_samples 10000 --output_dir ./dataset
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import random
12
+ import argparse
13
+ import numpy as np
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib.patches import Rectangle
18
+ import io
19
+
20
+
21
+ class CandlestickGenerator:
22
+ def __init__(self, image_size=(128, 128), num_candles=20):
23
+ self.image_size = image_size
24
+ self.num_candles = num_candles
25
+ self.bg_color = "#1a1a2e"
26
+ self.bullish_color = "#00ff88"
27
+ self.bearish_color = "#ff4466"
28
+
29
+ self.patterns = {
30
+ "bullish_trend": self._bullish_trend,
31
+ "bearish_trend": self._bearish_trend,
32
+ "sideways": self._sideways,
33
+ "volatile": self._volatile,
34
+ "bullish_reversal": self._bullish_reversal,
35
+ "bearish_reversal": self._bearish_reversal,
36
+ "double_top": self._double_top,
37
+ "double_bottom": self._double_bottom,
38
+ }
39
+
40
+ def _bullish_trend(self, n, vol):
41
+ candles = []
42
+ price = 100
43
+ for i in range(n):
44
+ trend = random.uniform(0.5, 2.0)
45
+ noise = random.gauss(0, vol)
46
+ o = price + noise
47
+ c = o + random.uniform(0, vol * 2) + trend
48
+ if random.random() < 0.7:
49
+ c = max(c, o + 0.5)
50
+ h = max(o, c) + random.uniform(0, vol)
51
+ l = min(o, c) - random.uniform(0, vol)
52
+ candles.append({"o": o, "h": h, "l": l, "c": c})
53
+ price = c
54
+ return candles
55
+
56
+ def _bearish_trend(self, n, vol):
57
+ candles = []
58
+ price = 150
59
+ for i in range(n):
60
+ trend = random.uniform(0.5, 2.0)
61
+ noise = random.gauss(0, vol)
62
+ o = price + noise
63
+ c = o - random.uniform(0, vol * 2) - trend
64
+ if random.random() < 0.7:
65
+ c = min(c, o - 0.5)
66
+ h = max(o, c) + random.uniform(0, vol)
67
+ l = min(o, c) - random.uniform(0, vol)
68
+ candles.append({"o": o, "h": h, "l": l, "c": c})
69
+ price = c
70
+ return candles
71
+
72
+ def _sideways(self, n, vol):
73
+ candles = []
74
+ base = 100
75
+ for i in range(n):
76
+ center = base + random.gauss(0, vol * 2)
77
+ o = center + random.gauss(0, vol)
78
+ c = center + random.gauss(0, vol)
79
+ h = max(o, c) + random.uniform(0, vol)
80
+ l = min(o, c) - random.uniform(0, vol)
81
+ candles.append({"o": o, "h": h, "l": l, "c": c})
82
+ return candles
83
+
84
+ def _volatile(self, n, vol):
85
+ candles = []
86
+ price = 100
87
+ high_vol = vol * 3
88
+ for i in range(n):
89
+ direction = 1 if random.random() > 0.5 else -1
90
+ move = random.uniform(high_vol, high_vol * 2) * direction
91
+ o = price + random.gauss(0, high_vol)
92
+ c = o + move
93
+ h = max(o, c) + random.uniform(high_vol * 0.5, high_vol)
94
+ l = min(o, c) - random.uniform(high_vol * 0.5, high_vol)
95
+ candles.append({"o": o, "h": h, "l": l, "c": c})
96
+ price = c
97
+ return candles
98
+
99
+ def _bullish_reversal(self, n, vol):
100
+ mid = n // 2
101
+ part1 = self._bearish_trend(mid, vol)
102
+ last = part1[-1]["c"]
103
+ part2 = []
104
+ price = last
105
+ for i in range(n - mid):
106
+ trend = random.uniform(0.5, 1.5)
107
+ o = price + random.gauss(0, vol)
108
+ c = o + random.uniform(0, vol * 2) + trend
109
+ h = max(o, c) + random.uniform(0, vol)
110
+ l = min(o, c) - random.uniform(0, vol)
111
+ part2.append({"o": o, "h": h, "l": l, "c": c})
112
+ price = c
113
+ return part1 + part2
114
+
115
+ def _bearish_reversal(self, n, vol):
116
+ mid = n // 2
117
+ part1 = self._bullish_trend(mid, vol)
118
+ last = part1[-1]["c"]
119
+ part2 = []
120
+ price = last
121
+ for i in range(n - mid):
122
+ trend = random.uniform(0.5, 1.5)
123
+ o = price + random.gauss(0, vol)
124
+ c = o - random.uniform(0, vol * 2) - trend
125
+ h = max(o, c) + random.uniform(0, vol)
126
+ l = min(o, c) - random.uniform(0, vol)
127
+ part2.append({"o": o, "h": h, "l": l, "c": c})
128
+ price = c
129
+ return part1 + part2
130
+
131
+ def _double_top(self, n, vol):
132
+ third = n // 3
133
+ candles = []
134
+ base, peak = 100, 120
135
+
136
+ for i in range(third):
137
+ p = base + (peak - base) * (i / third) + random.gauss(0, vol)
138
+ o, c = p, p + random.uniform(-vol, vol)
139
+ h = max(o, c) + random.uniform(0, vol)
140
+ l = min(o, c) - random.uniform(0, vol)
141
+ candles.append({"o": o, "h": h, "l": l, "c": c})
142
+
143
+ for i in range(third):
144
+ p = peak - (peak - base) * 0.5 * (i / third) + random.gauss(0, vol)
145
+ o, c = p, p + random.uniform(-vol, vol)
146
+ h = max(o, c) + random.uniform(0, vol)
147
+ l = min(o, c) - random.uniform(0, vol)
148
+ candles.append({"o": o, "h": h, "l": l, "c": c})
149
+
150
+ for i in range(n - 2 * third):
151
+ prog = i / (n - 2 * third)
152
+ if prog < 0.5:
153
+ p = (base + peak) / 2 + (peak - (base + peak) / 2) * (prog * 2)
154
+ else:
155
+ p = peak - (peak - base) * ((prog - 0.5) * 2)
156
+ p += random.gauss(0, vol)
157
+ o, c = p, p + random.uniform(-vol, vol)
158
+ h = max(o, c) + random.uniform(0, vol)
159
+ l = min(o, c) - random.uniform(0, vol)
160
+ candles.append({"o": o, "h": h, "l": l, "c": c})
161
+
162
+ return candles
163
+
164
+ def _double_bottom(self, n, vol):
165
+ third = n // 3
166
+ candles = []
167
+ base, bottom = 120, 100
168
+
169
+ for i in range(third):
170
+ p = base - (base - bottom) * (i / third) + random.gauss(0, vol)
171
+ o, c = p, p + random.uniform(-vol, vol)
172
+ h = max(o, c) + random.uniform(0, vol)
173
+ l = min(o, c) - random.uniform(0, vol)
174
+ candles.append({"o": o, "h": h, "l": l, "c": c})
175
+
176
+ for i in range(third):
177
+ p = bottom + (base - bottom) * 0.5 * (i / third) + random.gauss(0, vol)
178
+ o, c = p, p + random.uniform(-vol, vol)
179
+ h = max(o, c) + random.uniform(0, vol)
180
+ l = min(o, c) - random.uniform(0, vol)
181
+ candles.append({"o": o, "h": h, "l": l, "c": c})
182
+
183
+ for i in range(n - 2 * third):
184
+ prog = i / (n - 2 * third)
185
+ if prog < 0.5:
186
+ p = (base + bottom) / 2 - ((base + bottom) / 2 - bottom) * (prog * 2)
187
+ else:
188
+ p = bottom + (base - bottom) * ((prog - 0.5) * 2)
189
+ p += random.gauss(0, vol)
190
+ o, c = p, p + random.uniform(-vol, vol)
191
+ h = max(o, c) + random.uniform(0, vol)
192
+ l = min(o, c) - random.uniform(0, vol)
193
+ candles.append({"o": o, "h": h, "l": l, "c": c})
194
+
195
+ return candles
196
+
197
+ def render(self, candles):
198
+ fig, ax = plt.subplots(figsize=(self.image_size[0]/100, self.image_size[1]/100), dpi=100)
199
+ fig.patch.set_facecolor(self.bg_color)
200
+ ax.set_facecolor(self.bg_color)
201
+
202
+ highs = [c["h"] for c in candles]
203
+ lows = [c["l"] for c in candles]
204
+ price_min = min(lows) * 0.98
205
+ price_max = max(highs) * 1.02
206
+
207
+ width = 0.6
208
+ for i, c in enumerate(candles):
209
+ color = self.bullish_color if c["c"] >= c["o"] else self.bearish_color
210
+ ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
211
+ body_bottom = min(c["o"], c["c"])
212
+ body_height = abs(c["c"] - c["o"]) or 0.1
213
+ rect = Rectangle((i - width/2, body_bottom), width, body_height,
214
+ facecolor=color, edgecolor=color)
215
+ ax.add_patch(rect)
216
+
217
+ ax.set_xlim(-1, len(candles))
218
+ ax.set_ylim(price_min, price_max)
219
+ ax.axis("off")
220
+
221
+ buf = io.BytesIO()
222
+ plt.savefig(buf, format="png", facecolor=self.bg_color,
223
+ bbox_inches="tight", pad_inches=0.1)
224
+ plt.close(fig)
225
+ buf.seek(0)
226
+
227
+ img = Image.open(buf).convert("RGB")
228
+ img = img.resize(self.image_size, Image.Resampling.LANCZOS)
229
+ return img
230
+
231
+ def generate_sample(self):
232
+ pattern = random.choice(list(self.patterns.keys()))
233
+ vol_name = random.choice(["low", "medium", "high"])
234
+ vol_map = {"low": 1.0, "medium": 3.0, "high": 6.0}
235
+
236
+ candles = self.patterns[pattern](self.num_candles, vol_map[vol_name])
237
+ image = self.render(candles)
238
+
239
+ descriptions = {
240
+ "bullish_trend": [f"bullish trend {vol_name} volatility", f"upward trending market {vol_name} movement", "strong buying pressure"],
241
+ "bearish_trend": [f"bearish trend {vol_name} volatility", f"downward trending market {vol_name} movement", "strong selling pressure"],
242
+ "sideways": [f"sideways market {vol_name} volatility", "range-bound trading", "consolidation pattern"],
243
+ "volatile": ["highly volatile market", "erratic price movement", "choppy market conditions"],
244
+ "bullish_reversal": [f"bullish reversal {vol_name} volatility", "v-shaped recovery", "trend change bearish to bullish"],
245
+ "bearish_reversal": [f"bearish reversal {vol_name} volatility", "inverted v pattern", "trend change bullish to bearish"],
246
+ "double_top": [f"double top pattern {vol_name} volatility", "m-shaped reversal", "bearish double top"],
247
+ "double_bottom": [f"double bottom pattern {vol_name} volatility", "w-shaped reversal", "bullish double bottom"],
248
+ }
249
+
250
+ description = random.choice(descriptions[pattern])
251
+ return image, description
252
+
253
+
254
+ def generate_dataset(output_dir, num_samples=10000, image_size=128):
255
+ os.makedirs(output_dir, exist_ok=True)
256
+ os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
257
+
258
+ generator = CandlestickGenerator(image_size=(image_size, image_size))
259
+ labels = {}
260
+
261
+ print(f"Generating {num_samples} samples...")
262
+ for i in tqdm(range(num_samples)):
263
+ image, description = generator.generate_sample()
264
+ filename = f"chart_{i:06d}.png"
265
+ image.save(os.path.join(output_dir, "images", filename))
266
+ labels[filename] = description
267
+
268
+ with open(os.path.join(output_dir, "labels.json"), "w") as f:
269
+ json.dump(labels, f, indent=2)
270
+
271
+ print(f"✅ Dataset saved to {output_dir}")
272
+ print(f" - {num_samples} images")
273
+ print(f" - Labels in labels.json")
274
+
275
+
276
+ if __name__ == "__main__":
277
+ parser = argparse.ArgumentParser()
278
+ parser.add_argument("--num_samples", type=int, default=10000)
279
+ parser.add_argument("--output_dir", type=str, default="./dataset")
280
+ parser.add_argument("--image_size", type=int, default=128)
281
+ args = parser.parse_args()
282
+
283
+ generate_dataset(args.output_dir, args.num_samples, args.image_size)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=4.0.0
4
+ Pillow>=9.5.0
5
+ numpy>=1.24.0
6
+ einops>=0.6.1
7
+ tqdm>=4.65.0
8
+ matplotlib>=3.7.0