# candlestick-diffusion / generate_data.py
# Author: Spanicin
# Commit: d3932f4 ("Upload 4 files", verified)
"""
Dataset Generator for Candlestick Charts
Run this to create training data before training the model.
Usage:
python generate_data.py --num_samples 10000 --output_dir ./dataset
"""
import os
import json
import random
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import io
class CandlestickGenerator:
    """Generate synthetic OHLC candlestick charts paired with text captions.

    Every pattern method takes ``(n, vol)`` and returns a list of ``n`` candle
    dicts with keys ``"o"``, ``"h"``, ``"l"``, ``"c"`` (open, high, low,
    close). ``vol`` scales the random gaps, body sizes and wick lengths.

    Args:
        image_size: ``(width, height)`` of the rendered image in pixels.
        num_candles: Number of candles drawn per chart.
    """

    # Volatility label -> noise scale passed to the pattern methods.
    VOLATILITY_LEVELS = {"low": 1.0, "medium": 3.0, "high": 6.0}

    def __init__(self, image_size=(128, 128), num_candles=20):
        self.image_size = image_size
        self.num_candles = num_candles
        self.bg_color = "#1a1a2e"
        self.bullish_color = "#00ff88"
        self.bearish_color = "#ff4466"
        # Pattern name -> bound method(n, vol) returning a list of candle dicts.
        self.patterns = {
            "bullish_trend": self._bullish_trend,
            "bearish_trend": self._bearish_trend,
            "sideways": self._sideways,
            "volatile": self._volatile,
            "bullish_reversal": self._bullish_reversal,
            "bearish_reversal": self._bearish_reversal,
            "double_top": self._double_top,
            "double_bottom": self._double_bottom,
        }

    # ------------------------------------------------------------------
    # Candle-construction helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _candle(o, c, wick_max, wick_min=0.0):
        """Return a candle dict for open *o* / close *c* with random wicks.

        The high and the low each extend beyond the body by an independent
        draw from ``uniform(wick_min, wick_max)`` (high drawn first), so for
        non-negative bounds ``l <= min(o, c) <= max(o, c) <= h``.
        """
        high = max(o, c) + random.uniform(wick_min, wick_max)
        low = min(o, c) - random.uniform(wick_min, wick_max)
        return {"o": o, "h": high, "l": low, "c": c}

    def _trend_leg(self, n, vol, start, direction, trend_range=(0.5, 2.0),
                   clamp_prob=0.0):
        """Return *n* candles drifting up (``direction=1``) or down (``-1``).

        Each candle opens at the previous close (initially *start*) plus a
        Gaussian gap with std *vol*, and closes
        ``uniform(0, 2*vol) + uniform(*trend_range)`` away from its open in
        *direction*. With a positive ``trend_range`` every candle therefore
        moves in *direction*.
        """
        candles = []
        price = start
        for _ in range(n):
            drift = random.uniform(*trend_range)
            o = price + random.gauss(0, vol)
            c = o + direction * (random.uniform(0, vol * 2) + drift)
            if clamp_prob and random.random() < clamp_prob:
                # NOTE(review): with trend_range starting at 0.5 the close is
                # already >= 0.5 past the open, so this clamp never changes c.
                # Kept from the original; revisit if some counter-trend
                # candles were intended.
                c = max(c, o + 0.5) if direction > 0 else min(c, o - 0.5)
            candles.append(self._candle(o, c, vol))
            price = c
        return candles

    def _candles_along(self, means, vol):
        """Return one candle per value in *means*.

        Open is ``mean + gauss(0, vol)``, close is within ``+-vol`` of the
        open, and wicks extend up to *vol* beyond the body.
        """
        candles = []
        for mean in means:
            o = mean + random.gauss(0, vol)
            c = o + random.uniform(-vol, vol)
            candles.append(self._candle(o, c, vol))
        return candles

    def _double_pattern(self, n, vol, base, extreme):
        """Shared path for double tops/bottoms.

        The path goes base -> extreme (``n // 3`` candles), back halfway
        toward base (``n // 3`` candles), then the remaining candles return
        to extreme and finish at base. For ``n < 3`` only the last segment
        is produced.
        """
        third = n // 3
        rest = n - 2 * third
        neckline = (base + extreme) / 2
        means = [base + (extreme - base) * (i / third) for i in range(third)]
        means += [extreme + (base - extreme) * 0.5 * (i / third)
                  for i in range(third)]
        for i in range(rest):
            prog = i / rest
            if prog < 0.5:
                means.append(neckline + (extreme - neckline) * (prog * 2))
            else:
                means.append(extreme + (base - extreme) * ((prog - 0.5) * 2))
        return self._candles_along(means, vol)

    # ------------------------------------------------------------------
    # Patterns
    # ------------------------------------------------------------------
    def _bullish_trend(self, n, vol):
        """Uptrend starting near 100; every candle closes above its open."""
        return self._trend_leg(n, vol, start=100, direction=1, clamp_prob=0.7)

    def _bearish_trend(self, n, vol):
        """Downtrend starting near 150; every candle closes below its open."""
        return self._trend_leg(n, vol, start=150, direction=-1, clamp_prob=0.7)

    def _sideways(self, n, vol):
        """Range-bound chart: each candle scatters independently around 100."""
        candles = []
        for _ in range(n):
            center = 100 + random.gauss(0, vol * 2)
            o = center + random.gauss(0, vol)
            c = center + random.gauss(0, vol)
            candles.append(self._candle(o, c, vol))
        return candles

    def _volatile(self, n, vol):
        """Random walk from 100 with large (3x *vol*) moves in random directions."""
        high_vol = vol * 3
        candles = []
        price = 100
        for _ in range(n):
            direction = 1 if random.random() > 0.5 else -1
            move = random.uniform(high_vol, high_vol * 2) * direction
            o = price + random.gauss(0, high_vol)
            c = o + move
            candles.append(
                self._candle(o, c, high_vol, wick_min=high_vol * 0.5))
            price = c
        return candles

    def _bullish_reversal(self, n, vol):
        """Downtrend for ``n // 2`` candles, then an uptrend from its last close."""
        first = self._bearish_trend(n // 2, vol)
        # For n < 2 the first leg is empty; start where a bearish leg would.
        start = first[-1]["c"] if first else 150
        second = self._trend_leg(n - n // 2, vol, start=start, direction=1,
                                 trend_range=(0.5, 1.5))
        return first + second

    def _bearish_reversal(self, n, vol):
        """Uptrend for ``n // 2`` candles, then a downtrend from its last close."""
        first = self._bullish_trend(n // 2, vol)
        # For n < 2 the first leg is empty; start where a bullish leg would.
        start = first[-1]["c"] if first else 100
        second = self._trend_leg(n - n // 2, vol, start=start, direction=-1,
                                 trend_range=(0.5, 1.5))
        return first + second

    def _double_top(self, n, vol):
        """M shape: 100 -> 120 -> 110 -> 120 -> 100."""
        return self._double_pattern(n, vol, base=100, extreme=120)

    def _double_bottom(self, n, vol):
        """W shape: 120 -> 100 -> 110 -> 100 -> 120."""
        return self._double_pattern(n, vol, base=120, extreme=100)

    # ------------------------------------------------------------------
    # Rendering / sampling
    # ------------------------------------------------------------------
    @staticmethod
    def _price_limits(lows, highs, margin=0.02):
        """Return ``(ymin, ymax)`` padded outward by *margin* of each bound's magnitude.

        For positive prices this matches ``(min(lows) * 0.98, max(highs) * 1.02)``
        (up to rounding). Unlike plain multiplication, it always pads
        *outward*, so negative prices are not clipped. The random-walk
        patterns can reach negative prices at high volatility.
        """
        lo, hi = min(lows), max(highs)
        ymin = lo - abs(lo) * margin
        ymax = hi + abs(hi) * margin
        if ymax <= ymin:
            # Degenerate range (every price exactly 0): avoid a zero-height axis.
            ymin, ymax = ymin - 1.0, ymax + 1.0
        return ymin, ymax

    def render(self, candles):
        """Draw *candles* and return the chart as an RGB ``PIL.Image``.

        Wicks span low..high. Bodies span open..close and are coloured
        bullish when close >= open, bearish otherwise. The figure is saved
        with a tight bounding box, then resized to ``self.image_size``.

        Raises:
            ValueError: if *candles* is empty.
        """
        if not candles:
            raise ValueError("render() requires at least one candle")
        width_px, height_px = self.image_size
        fig, ax = plt.subplots(figsize=(width_px / 100, height_px / 100), dpi=100)
        buf = io.BytesIO()
        try:
            fig.patch.set_facecolor(self.bg_color)
            ax.set_facecolor(self.bg_color)
            body_width = 0.6
            for x, candle in enumerate(candles):
                o, high, low, c = candle["o"], candle["h"], candle["l"], candle["c"]
                color = self.bullish_color if c >= o else self.bearish_color
                ax.plot([x, x], [low, high], color=color, linewidth=1)
                # Give doji (open == close) a sliver of height so they stay visible.
                body_height = abs(c - o) or 0.1
                ax.add_patch(Rectangle((x - body_width / 2, min(o, c)),
                                       body_width, body_height,
                                       facecolor=color, edgecolor=color))
            ax.set_xlim(-1, len(candles))
            ax.set_ylim(*self._price_limits([cd["l"] for cd in candles],
                                            [cd["h"] for cd in candles]))
            ax.axis("off")
            fig.savefig(buf, format="png", facecolor=self.bg_color,
                        bbox_inches="tight", pad_inches=0.1)
        finally:
            # Always release the figure, even if drawing fails; pyplot keeps
            # open figures alive until they are closed.
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
        return img.resize(self.image_size, Image.Resampling.LANCZOS)

    def generate_sample(self):
        """Pick a random pattern and volatility, render it, and caption it.

        Returns:
            ``(image, description)``: the rendered ``PIL.Image`` and one
            caption chosen at random from that pattern's templates. Captions
            for ``"volatile"`` do not mention the volatility level.
        """
        pattern = random.choice(list(self.patterns))
        vol_name = random.choice(list(self.VOLATILITY_LEVELS))
        candles = self.patterns[pattern](self.num_candles,
                                         self.VOLATILITY_LEVELS[vol_name])
        image = self.render(candles)
        descriptions = {
            "bullish_trend": [f"bullish trend {vol_name} volatility", f"upward trending market {vol_name} movement", "strong buying pressure"],
            "bearish_trend": [f"bearish trend {vol_name} volatility", f"downward trending market {vol_name} movement", "strong selling pressure"],
            "sideways": [f"sideways market {vol_name} volatility", "range-bound trading", "consolidation pattern"],
            "volatile": ["highly volatile market", "erratic price movement", "choppy market conditions"],
            "bullish_reversal": [f"bullish reversal {vol_name} volatility", "v-shaped recovery", "trend change bearish to bullish"],
            "bearish_reversal": [f"bearish reversal {vol_name} volatility", "inverted v pattern", "trend change bullish to bearish"],
            "double_top": [f"double top pattern {vol_name} volatility", "m-shaped reversal", "bearish double top"],
            "double_bottom": [f"double bottom pattern {vol_name} volatility", "w-shaped reversal", "bullish double bottom"],
        }
        description = random.choice(descriptions[pattern])
        return image, description
def generate_dataset(output_dir, num_samples=10000, image_size=128, seed=None):
    """Render *num_samples* random candlestick charts into *output_dir*.

    Files written::

        output_dir/
            images/chart_000000.png, chart_000001.png, ...
            labels.json   # {"chart_000000.png": "<description>", ...}

    Args:
        output_dir: Destination directory (created if missing).
        num_samples: Number of charts to generate.
        image_size: Side length in pixels of the square output images.
        seed: Optional seed for Python's ``random`` module so the same
            dataset can be regenerated; ``None`` leaves the RNG state as is.
    """
    if seed is not None:
        random.seed(seed)
    images_dir = os.path.join(output_dir, "images")
    # Creating images/ also creates output_dir as its parent.
    os.makedirs(images_dir, exist_ok=True)
    generator = CandlestickGenerator(image_size=(image_size, image_size))
    labels = {}
    print(f"Generating {num_samples} samples...")
    for i in tqdm(range(num_samples)):
        image, description = generator.generate_sample()
        filename = f"chart_{i:06d}.png"
        image.save(os.path.join(images_dir, filename))
        labels[filename] = description
    with open(os.path.join(output_dir, "labels.json"), "w", encoding="utf-8") as f:
        json.dump(labels, f, indent=2)
    print(f"✅ Dataset saved to {output_dir}")
    print(f" - {num_samples} images")
    print(" - Labels in labels.json")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a synthetic candlestick chart dataset.")
    parser.add_argument("--num_samples", type=int, default=10000,
                        help="number of charts to generate")
    parser.add_argument("--output_dir", type=str, default="./dataset",
                        help="directory to write images/ and labels.json into")
    parser.add_argument("--image_size", type=int, default=128,
                        help="side length of the square output images in pixels")
    parser.add_argument("--seed", type=int, default=None,
                        help="random seed for a reproducible dataset")
    args = parser.parse_args()
    generate_dataset(args.output_dir, args.num_samples, args.image_size,
                     seed=args.seed)