Spaces:
Runtime error
Runtime error
| """ | |
| Dataset Generator for Candlestick Charts | |
| Run this to create training data before training the model. | |
| Usage: | |
| python generate_data.py --num_samples 10000 --output_dir ./dataset | |
| """ | |
| import os | |
| import json | |
| import random | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Rectangle | |
| import io | |
class CandlestickGenerator:
    """Generate synthetic candlestick chart images with text descriptions.

    Each chart is built from a list of candle dicts with the keys ``"o"``,
    ``"h"``, ``"l"`` and ``"c"`` (open, high, low, close).  A pattern method
    produces the candles, :meth:`render` draws them with matplotlib and
    returns a PIL image, and :meth:`generate_sample` pairs the image with a
    randomly chosen description of the pattern.

    All randomness comes from the module-level ``random`` generator, so
    seeding ``random`` makes the generated candles reproducible.
    """

    # Starting prices of the rising and falling trend legs.
    _RISING_START = 100
    _FALLING_START = 150

    def __init__(self, image_size=(128, 128), num_candles=20):
        """Store rendering settings and register the pattern generators.

        Args:
            image_size: ``(width, height)`` of the returned images in pixels.
            num_candles: Number of candles drawn in every generated chart.
        """
        self.image_size = image_size
        self.num_candles = num_candles
        self.bg_color = "#1a1a2e"
        self.bullish_color = "#00ff88"
        self.bearish_color = "#ff4466"
        # Maps the pattern name used in labels to the method producing it.
        self.patterns = {
            "bullish_trend": self._bullish_trend,
            "bearish_trend": self._bearish_trend,
            "sideways": self._sideways,
            "volatile": self._volatile,
            "bullish_reversal": self._bullish_reversal,
            "bearish_reversal": self._bearish_reversal,
            "double_top": self._double_top,
            "double_bottom": self._double_bottom,
        }

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------
    @staticmethod
    def _candle(o, c, wick_max, wick_min=0.0):
        """Return a candle dict for open ``o`` and close ``c`` with random wicks.

        The upper wick is drawn before the lower one; both lengths are
        uniform in ``[wick_min, wick_max]``.
        """
        high = max(o, c) + random.uniform(wick_min, wick_max)
        low = min(o, c) - random.uniform(wick_min, wick_max)
        return {"o": o, "h": high, "l": low, "c": c}

    def _noisy_candle(self, level, vol):
        """Return a candle opening at ``level`` plus Gaussian noise.

        The close differs from the open by a uniform offset in
        ``[-vol, vol]``.
        """
        o = level + random.gauss(0, vol)
        c = o + random.uniform(-vol, vol)
        return self._candle(o, c, vol)

    def _trend_leg(self, n, vol, start, direction, max_drift, clamp):
        """Return ``n`` candles that move steadily up or down from ``start``.

        Args:
            n: Number of candles to produce.
            vol: Volatility scale for noise, body size and wicks.
            start: Price the first candle opens around.
            direction: ``1`` for a rising leg, ``-1`` for a falling leg.
            max_drift: Upper bound of the per-candle drift (lower bound 0.5).
            clamp: When true, 70% of candles additionally get their body
                clamped to at least 0.5 in the trend direction.
        """
        candles = []
        price = start
        for _ in range(n):
            drift = random.uniform(0.5, max_drift)
            o = price + random.gauss(0, vol)
            # Multiplying by +/-1 is exact, so this equals `o +/- body +/- drift`.
            c = o + direction * random.uniform(0, vol * 2) + direction * drift
            if clamp and random.random() < 0.7:
                # With non-negative vol the body already spans >= 0.5, so this
                # never changes c; the draw is kept so seeded runs still
                # consume the same random sequence.
                bound = o + direction * 0.5
                c = max(c, bound) if direction > 0 else min(c, bound)
            candles.append(self._candle(o, c, vol))
            price = c  # the next candle opens around this close
        return candles

    # ------------------------------------------------------------------
    # Pattern generators: each takes (n, vol) and returns n candle dicts
    # ------------------------------------------------------------------
    def _bullish_trend(self, n, vol):
        """Return ``n`` rising candles starting around 100."""
        return self._trend_leg(n, vol, self._RISING_START, 1, 2.0, clamp=True)

    def _bearish_trend(self, n, vol):
        """Return ``n`` falling candles starting around 150."""
        return self._trend_leg(n, vol, self._FALLING_START, -1, 2.0, clamp=True)

    def _sideways(self, n, vol):
        """Return ``n`` independent candles scattered around a base of 100."""
        base = 100
        candles = []
        for _ in range(n):
            center = base + random.gauss(0, vol * 2)
            o = center + random.gauss(0, vol)
            c = center + random.gauss(0, vol)
            candles.append(self._candle(o, c, vol))
        return candles

    def _volatile(self, n, vol):
        """Return ``n`` candles with large bodies in random directions.

        Noise, bodies and wicks are scaled by three times ``vol``.
        """
        high_vol = vol * 3
        candles = []
        price = 100
        for _ in range(n):
            direction = 1 if random.random() > 0.5 else -1
            move = random.uniform(high_vol, high_vol * 2) * direction
            o = price + random.gauss(0, high_vol)
            c = o + move
            candles.append(self._candle(o, c, high_vol, wick_min=high_vol * 0.5))
            price = c
        return candles

    def _bullish_reversal(self, n, vol):
        """Return a falling leg of ``n // 2`` candles followed by a rising leg."""
        mid = n // 2
        falling = self._bearish_trend(mid, vol)
        # For n < 2 the falling leg is empty; start the rising leg where the
        # falling leg would have started instead of indexing an empty list.
        start = falling[-1]["c"] if falling else self._FALLING_START
        rising = self._trend_leg(n - mid, vol, start, 1, 1.5, clamp=False)
        return falling + rising

    def _bearish_reversal(self, n, vol):
        """Return a rising leg of ``n // 2`` candles followed by a falling leg."""
        mid = n // 2
        rising = self._bullish_trend(mid, vol)
        # Same empty-leg guard as in _bullish_reversal.
        start = rising[-1]["c"] if rising else self._RISING_START
        falling = self._trend_leg(n - mid, vol, start, -1, 1.5, clamp=False)
        return rising + falling

    def _double_top(self, n, vol):
        """Return ``n`` candles that reach the 120 peak twice.

        Phase 1 rises from the base to the peak, phase 2 pulls back halfway,
        and phase 3 climbs back to the peak and then falls to the base.
        """
        third = n // 3
        base, peak = 100, 120
        midpoint = (base + peak) / 2
        candles = []
        for i in range(third):
            level = base + (peak - base) * (i / third)
            candles.append(self._noisy_candle(level, vol))
        for i in range(third):
            level = peak - (peak - base) * 0.5 * (i / third)
            candles.append(self._noisy_candle(level, vol))
        remaining = n - 2 * third
        for i in range(remaining):
            prog = i / remaining
            if prog < 0.5:
                level = midpoint + (peak - midpoint) * (prog * 2)
            else:
                level = peak - (peak - base) * ((prog - 0.5) * 2)
            candles.append(self._noisy_candle(level, vol))
        return candles

    def _double_bottom(self, n, vol):
        """Return ``n`` candles that reach the 100 bottom twice.

        Mirror image of :meth:`_double_top`: a fall to the bottom, a bounce
        halfway up, a second fall to the bottom and a rise back to the base.
        """
        third = n // 3
        base, bottom = 120, 100
        midpoint = (base + bottom) / 2
        candles = []
        for i in range(third):
            level = base - (base - bottom) * (i / third)
            candles.append(self._noisy_candle(level, vol))
        for i in range(third):
            level = bottom + (base - bottom) * 0.5 * (i / third)
            candles.append(self._noisy_candle(level, vol))
        remaining = n - 2 * third
        for i in range(remaining):
            prog = i / remaining
            if prog < 0.5:
                level = midpoint - (midpoint - bottom) * (prog * 2)
            else:
                level = bottom + (base - bottom) * ((prog - 0.5) * 2)
            candles.append(self._noisy_candle(level, vol))
        return candles

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------
    @staticmethod
    def _price_limits(candles):
        """Return ``(lower, upper)`` y-axis limits with a 2% outward margin.

        The margin is always applied away from the data.  A plain
        ``min * 0.98`` / ``max * 1.02`` would shrink the range and clip
        candles whenever a price is negative.
        """
        lowest = min(c["l"] for c in candles)
        highest = max(c["h"] for c in candles)
        lower = lowest * 0.98 if lowest >= 0 else lowest * 1.02
        upper = highest * 1.02 if highest >= 0 else highest * 0.98
        return lower, upper

    def render(self, candles):
        """Draw ``candles`` and return the chart as an RGB PIL image.

        Args:
            candles: Non-empty list of candle dicts (keys ``o``, ``h``,
                ``l``, ``c``).

        Returns:
            A ``PIL.Image.Image`` in RGB mode resized to ``self.image_size``.

        Raises:
            ValueError: If ``candles`` is empty.
        """
        if not candles:
            raise ValueError("cannot render an empty list of candles")

        price_min, price_max = self._price_limits(candles)
        width = 0.6
        fig, ax = plt.subplots(
            figsize=(self.image_size[0] / 100, self.image_size[1] / 100),
            dpi=100,
        )
        try:
            fig.patch.set_facecolor(self.bg_color)
            ax.set_facecolor(self.bg_color)
            for i, c in enumerate(candles):
                color = self.bullish_color if c["c"] >= c["o"] else self.bearish_color
                # Wick: vertical line from low to high.
                ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
                body_bottom = min(c["o"], c["c"])
                # Give flat candles a small height so the body stays visible.
                body_height = abs(c["c"] - c["o"]) or 0.1
                ax.add_patch(
                    Rectangle(
                        (i - width / 2, body_bottom),
                        width,
                        body_height,
                        facecolor=color,
                        edgecolor=color,
                    )
                )
            ax.set_xlim(-1, len(candles))
            ax.set_ylim(price_min, price_max)
            ax.axis("off")
            buf = io.BytesIO()
            fig.savefig(
                buf,
                format="png",
                facecolor=self.bg_color,
                bbox_inches="tight",
                pad_inches=0.1,
            )
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
        # The tight bounding box changes the pixel size, so resize at the end.
        return img.resize(self.image_size, Image.Resampling.LANCZOS)

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    def generate_sample(self):
        """Return ``(image, description)`` for one randomly chosen pattern.

        A pattern and a volatility level are picked at random, the candles
        are generated and rendered, and one of the pattern's descriptions is
        chosen at random as the label.
        """
        pattern = random.choice(list(self.patterns.keys()))
        vol_name = random.choice(["low", "medium", "high"])
        vol_map = {"low": 1.0, "medium": 3.0, "high": 6.0}
        candles = self.patterns[pattern](self.num_candles, vol_map[vol_name])
        image = self.render(candles)
        descriptions = {
            "bullish_trend": [f"bullish trend {vol_name} volatility", f"upward trending market {vol_name} movement", "strong buying pressure"],
            "bearish_trend": [f"bearish trend {vol_name} volatility", f"downward trending market {vol_name} movement", "strong selling pressure"],
            "sideways": [f"sideways market {vol_name} volatility", "range-bound trading", "consolidation pattern"],
            "volatile": ["highly volatile market", "erratic price movement", "choppy market conditions"],
            "bullish_reversal": [f"bullish reversal {vol_name} volatility", "v-shaped recovery", "trend change bearish to bullish"],
            "bearish_reversal": [f"bearish reversal {vol_name} volatility", "inverted v pattern", "trend change bullish to bearish"],
            "double_top": [f"double top pattern {vol_name} volatility", "m-shaped reversal", "bearish double top"],
            "double_bottom": [f"double bottom pattern {vol_name} volatility", "w-shaped reversal", "bullish double bottom"],
        }
        description = random.choice(descriptions[pattern])
        return image, description
def generate_dataset(output_dir, num_samples=10000, image_size=128):
    """Write ``num_samples`` chart images and their labels to ``output_dir``.

    Images are saved as ``images/chart_NNNNNN.png`` and a ``labels.json``
    file mapping each image filename to its description is written once all
    images have been generated.

    Args:
        output_dir: Directory that receives ``images/`` and ``labels.json``;
            created if missing.
        num_samples: Number of samples to generate.
        image_size: Edge length in pixels of the square images.
    """
    images_dir = os.path.join(output_dir, "images")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(images_dir, exist_ok=True)

    chart_source = CandlestickGenerator(image_size=(image_size, image_size))
    captions = {}

    print(f"Generating {num_samples} samples...")
    for index in tqdm(range(num_samples)):
        chart, caption = chart_source.generate_sample()
        name = f"chart_{index:06d}.png"
        chart.save(os.path.join(images_dir, name))
        captions[name] = caption

    # The label file is written in one go after the generation loop.
    labels_path = os.path.join(output_dir, "labels.json")
    with open(labels_path, "w") as handle:
        json.dump(captions, handle, indent=2)

    print(f"✅ Dataset saved to {output_dir}")
    print(f" - {num_samples} images")
    print(" - Labels in labels.json")
if __name__ == "__main__":
    # Command-line entry point: read the three options and build the dataset.
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ("--num_samples", int, 10000),
        ("--output_dir", str, "./dataset"),
        ("--image_size", int, 128),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    options = cli.parse_args()
    generate_dataset(
        output_dir=options.output_dir,
        num_samples=options.num_samples,
        image_size=options.image_size,
    )