""" Dataset Generator for Candlestick Charts Run this to create training data before training the model. Usage: python generate_data.py --num_samples 10000 --output_dir ./dataset """ import os import json import random import argparse import numpy as np from PIL import Image from tqdm import tqdm import matplotlib.pyplot as plt from matplotlib.patches import Rectangle import io class CandlestickGenerator: def __init__(self, image_size=(128, 128), num_candles=20): self.image_size = image_size self.num_candles = num_candles self.bg_color = "#1a1a2e" self.bullish_color = "#00ff88" self.bearish_color = "#ff4466" self.patterns = { "bullish_trend": self._bullish_trend, "bearish_trend": self._bearish_trend, "sideways": self._sideways, "volatile": self._volatile, "bullish_reversal": self._bullish_reversal, "bearish_reversal": self._bearish_reversal, "double_top": self._double_top, "double_bottom": self._double_bottom, } def _bullish_trend(self, n, vol): candles = [] price = 100 for i in range(n): trend = random.uniform(0.5, 2.0) noise = random.gauss(0, vol) o = price + noise c = o + random.uniform(0, vol * 2) + trend if random.random() < 0.7: c = max(c, o + 0.5) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) price = c return candles def _bearish_trend(self, n, vol): candles = [] price = 150 for i in range(n): trend = random.uniform(0.5, 2.0) noise = random.gauss(0, vol) o = price + noise c = o - random.uniform(0, vol * 2) - trend if random.random() < 0.7: c = min(c, o - 0.5) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) price = c return candles def _sideways(self, n, vol): candles = [] base = 100 for i in range(n): center = base + random.gauss(0, vol * 2) o = center + random.gauss(0, vol) c = center + random.gauss(0, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) return candles def _volatile(self, n, vol): candles = [] price = 100 high_vol = vol * 3 for i in range(n): direction = 1 if random.random() > 0.5 else -1 move = random.uniform(high_vol, high_vol * 2) * direction o = price + random.gauss(0, high_vol) c = o + move h = max(o, c) + random.uniform(high_vol * 0.5, high_vol) l = min(o, c) - random.uniform(high_vol * 0.5, high_vol) candles.append({"o": o, "h": h, "l": l, "c": c}) price = c return candles def _bullish_reversal(self, n, vol): mid = n // 2 part1 = self._bearish_trend(mid, vol) last = part1[-1]["c"] part2 = [] price = last for i in range(n - mid): trend = random.uniform(0.5, 1.5) o = price + random.gauss(0, vol) c = o + random.uniform(0, vol * 2) + trend h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) part2.append({"o": o, "h": h, "l": l, "c": c}) price = c return part1 + part2 def _bearish_reversal(self, n, vol): mid = n // 2 part1 = self._bullish_trend(mid, vol) last = part1[-1]["c"] part2 = [] price = last for i in range(n - mid): trend = random.uniform(0.5, 1.5) o = price + random.gauss(0, vol) c = o - random.uniform(0, vol * 2) - trend h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) part2.append({"o": o, "h": h, "l": l, "c": c}) price = c return part1 + part2 def _double_top(self, n, vol): third = n // 3 candles = [] base, peak = 100, 120 for i in range(third): p = base + (peak - base) * (i / third) + random.gauss(0, vol) o, c = p, p + random.uniform(-vol, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) for i in range(third): p = peak - (peak - base) * 0.5 * (i / third) + random.gauss(0, vol) o, c = p, p + random.uniform(-vol, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) for i in range(n - 2 * third): prog = i / (n - 2 * third) if prog < 0.5: p = (base + peak) / 2 + (peak - (base + peak) / 2) * (prog * 2) else: p = peak - (peak - base) * ((prog - 0.5) * 2) p += random.gauss(0, vol) o, c = p, p + random.uniform(-vol, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) return candles def _double_bottom(self, n, vol): third = n // 3 candles = [] base, bottom = 120, 100 for i in range(third): p = base - (base - bottom) * (i / third) + random.gauss(0, vol) o, c = p, p + random.uniform(-vol, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) for i in range(third): p = bottom + (base - bottom) * 0.5 * (i / third) + random.gauss(0, vol) o, c = p, p + random.uniform(-vol, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) for i in range(n - 2 * third): prog = i / (n - 2 * third) if prog < 0.5: p = (base + bottom) / 2 - ((base + bottom) / 2 - bottom) * (prog * 2) else: p = bottom + (base - bottom) * ((prog - 0.5) * 2) p += random.gauss(0, vol) o, c = p, p + random.uniform(-vol, vol) h = max(o, c) + random.uniform(0, vol) l = min(o, c) - random.uniform(0, vol) candles.append({"o": o, "h": h, "l": l, "c": c}) return candles def render(self, candles): fig, ax = plt.subplots(figsize=(self.image_size[0]/100, self.image_size[1]/100), dpi=100) fig.patch.set_facecolor(self.bg_color) ax.set_facecolor(self.bg_color) highs = [c["h"] for c in candles] lows = [c["l"] for c in candles] price_min = min(lows) * 0.98 price_max = max(highs) * 1.02 width = 0.6 for i, c in enumerate(candles): color = self.bullish_color if c["c"] >= c["o"] else self.bearish_color ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1) body_bottom = min(c["o"], c["c"]) body_height = abs(c["c"] - c["o"]) or 0.1 rect = Rectangle((i - width/2, body_bottom), width, body_height, facecolor=color, edgecolor=color) ax.add_patch(rect) ax.set_xlim(-1, len(candles)) ax.set_ylim(price_min, price_max) ax.axis("off") buf = io.BytesIO() plt.savefig(buf, format="png", facecolor=self.bg_color, bbox_inches="tight", pad_inches=0.1) plt.close(fig) buf.seek(0) img = Image.open(buf).convert("RGB") img = img.resize(self.image_size, Image.Resampling.LANCZOS) return img def generate_sample(self): pattern = random.choice(list(self.patterns.keys())) vol_name = random.choice(["low", "medium", "high"]) vol_map = {"low": 1.0, "medium": 3.0, "high": 6.0} candles = self.patterns[pattern](self.num_candles, vol_map[vol_name]) image = self.render(candles) descriptions = { "bullish_trend": [f"bullish trend {vol_name} volatility", f"upward trending market {vol_name} movement", "strong buying pressure"], "bearish_trend": [f"bearish trend {vol_name} volatility", f"downward trending market {vol_name} movement", "strong selling pressure"], "sideways": [f"sideways market {vol_name} volatility", "range-bound trading", "consolidation pattern"], "volatile": ["highly volatile market", "erratic price movement", "choppy market conditions"], "bullish_reversal": [f"bullish reversal {vol_name} volatility", "v-shaped recovery", "trend change bearish to bullish"], "bearish_reversal": [f"bearish reversal {vol_name} volatility", "inverted v pattern", "trend change bullish to bearish"], "double_top": [f"double top pattern {vol_name} volatility", "m-shaped reversal", "bearish double top"], "double_bottom": [f"double bottom pattern {vol_name} volatility", "w-shaped reversal", "bullish double bottom"], } description = random.choice(descriptions[pattern]) return image, description def generate_dataset(output_dir, num_samples=10000, image_size=128): os.makedirs(output_dir, exist_ok=True) os.makedirs(os.path.join(output_dir, "images"), exist_ok=True) generator = CandlestickGenerator(image_size=(image_size, image_size)) labels = {} print(f"Generating {num_samples} samples...") for i in tqdm(range(num_samples)): image, description = generator.generate_sample() filename = f"chart_{i:06d}.png" image.save(os.path.join(output_dir, "images", filename)) labels[filename] = description with open(os.path.join(output_dir, "labels.json"), "w") as f: json.dump(labels, f, indent=2) print(f"✅ Dataset saved to {output_dir}") print(f" - {num_samples} images") print(f" - Labels in labels.json") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num_samples", type=int, default=10000) parser.add_argument("--output_dir", type=str, default="./dataset") parser.add_argument("--image_size", type=int, default=128) args = parser.parse_args() generate_dataset(args.output_dir, args.num_samples, args.image_size)