Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -478,25 +478,33 @@ def generate_dataset_ui(num_samples, image_size):
|
|
| 478 |
import os
|
| 479 |
import json
|
| 480 |
import random
|
| 481 |
-
from PIL import Image
|
| 482 |
-
import matplotlib.pyplot as plt
|
| 483 |
-
from matplotlib.patches import Rectangle
|
| 484 |
-
import io
|
| 485 |
|
| 486 |
output_dir = "./dataset"
|
| 487 |
os.makedirs(output_dir, exist_ok=True)
|
| 488 |
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
| 489 |
|
| 490 |
-
bg_color = "#1a1a2e"
|
| 491 |
-
bullish_color = "#00ff88"
|
| 492 |
-
bearish_color = "#ff4466"
|
| 493 |
num_candles = 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
candles = []
|
| 497 |
price = 100 if pattern != "bearish" else 150
|
| 498 |
|
| 499 |
-
for
|
| 500 |
if pattern == "bullish":
|
| 501 |
trend = random.uniform(0.5, 2.0)
|
| 502 |
o = price + random.gauss(0, vol)
|
|
@@ -505,7 +513,7 @@ def generate_dataset_ui(num_samples, image_size):
|
|
| 505 |
trend = random.uniform(0.5, 2.0)
|
| 506 |
o = price + random.gauss(0, vol)
|
| 507 |
c = o - random.uniform(0, vol*2) - trend
|
| 508 |
-
else:
|
| 509 |
o = price + random.gauss(0, vol)
|
| 510 |
c = o + random.gauss(0, vol)
|
| 511 |
|
|
@@ -513,63 +521,62 @@ def generate_dataset_ui(num_samples, image_size):
|
|
| 513 |
l = min(o, c) - random.uniform(0, vol)
|
| 514 |
candles.append({"o": o, "h": h, "l": l, "c": c})
|
| 515 |
price = c
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
fig.patch.set_facecolor(bg_color)
|
| 521 |
-
ax.set_facecolor(bg_color)
|
| 522 |
|
| 523 |
highs = [c["h"] for c in candles]
|
| 524 |
lows = [c["l"] for c in candles]
|
| 525 |
-
price_min, price_max = min(lows)
|
| 526 |
-
|
| 527 |
-
for i, c in enumerate(candles):
|
| 528 |
-
color = bullish_color if c["c"] >= c["o"] else bearish_color
|
| 529 |
-
ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
|
| 530 |
-
body_bottom = min(c["o"], c["c"])
|
| 531 |
-
body_height = abs(c["c"] - c["o"]) or 0.1
|
| 532 |
-
rect = Rectangle((i-0.3, body_bottom), 0.6, body_height, facecolor=color)
|
| 533 |
-
ax.add_patch(rect)
|
| 534 |
-
|
| 535 |
-
ax.set_xlim(-1, len(candles))
|
| 536 |
-
ax.set_ylim(price_min, price_max)
|
| 537 |
-
ax.axis("off")
|
| 538 |
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
img = Image.open(buf).convert("RGB")
|
| 545 |
-
return img.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
| 546 |
-
|
| 547 |
-
patterns = ["bullish", "bearish", "sideways"]
|
| 548 |
-
volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
|
| 549 |
-
labels = {}
|
| 550 |
-
|
| 551 |
-
for i in range(int(num_samples)):
|
| 552 |
-
pattern = random.choice(patterns)
|
| 553 |
-
vol_name = random.choice(list(volatilities.keys()))
|
| 554 |
-
vol = volatilities[vol_name]
|
| 555 |
|
| 556 |
-
|
| 557 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
|
|
|
|
| 559 |
filename = f"chart_{i:06d}.png"
|
| 560 |
img.save(os.path.join(output_dir, "images", filename))
|
| 561 |
labels[filename] = f"{pattern} trend {vol_name} volatility"
|
| 562 |
-
|
| 563 |
-
if i % 500 == 0:
|
| 564 |
-
print(f"Generated {i}/{num_samples}")
|
| 565 |
|
|
|
|
| 566 |
with open(os.path.join(output_dir, "labels.json"), "w") as f:
|
| 567 |
json.dump(labels, f)
|
| 568 |
|
| 569 |
-
return f"✅ Generated {num_samples} samples in ./dataset"
|
| 570 |
|
| 571 |
except Exception as e:
|
| 572 |
-
|
|
|
|
| 573 |
|
| 574 |
|
| 575 |
# ============== Gradio Interface ==============
|
|
|
|
| 478 |
import os
|
| 479 |
import json
|
| 480 |
import random
|
| 481 |
+
from PIL import Image, ImageDraw
|
|
|
|
|
|
|
|
|
|
| 482 |
|
| 483 |
output_dir = "./dataset"
|
| 484 |
os.makedirs(output_dir, exist_ok=True)
|
| 485 |
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
| 486 |
|
|
|
|
|
|
|
|
|
|
| 487 |
num_candles = 20
|
| 488 |
+
labels = {}
|
| 489 |
+
|
| 490 |
+
# Colors
|
| 491 |
+
bg_color = (26, 26, 46)
|
| 492 |
+
bullish_color = (0, 255, 136)
|
| 493 |
+
bearish_color = (255, 68, 102)
|
| 494 |
+
|
| 495 |
+
patterns = ["bullish", "bearish", "sideways"]
|
| 496 |
+
volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
|
| 497 |
|
| 498 |
+
for i in range(int(num_samples)):
|
| 499 |
+
# Generate candle data
|
| 500 |
+
pattern = random.choice(patterns)
|
| 501 |
+
vol_name = random.choice(list(volatilities.keys()))
|
| 502 |
+
vol = volatilities[vol_name]
|
| 503 |
+
|
| 504 |
candles = []
|
| 505 |
price = 100 if pattern != "bearish" else 150
|
| 506 |
|
| 507 |
+
for j in range(num_candles):
|
| 508 |
if pattern == "bullish":
|
| 509 |
trend = random.uniform(0.5, 2.0)
|
| 510 |
o = price + random.gauss(0, vol)
|
|
|
|
| 513 |
trend = random.uniform(0.5, 2.0)
|
| 514 |
o = price + random.gauss(0, vol)
|
| 515 |
c = o - random.uniform(0, vol*2) - trend
|
| 516 |
+
else:
|
| 517 |
o = price + random.gauss(0, vol)
|
| 518 |
c = o + random.gauss(0, vol)
|
| 519 |
|
|
|
|
| 521 |
l = min(o, c) - random.uniform(0, vol)
|
| 522 |
candles.append({"o": o, "h": h, "l": l, "c": c})
|
| 523 |
price = c
|
| 524 |
+
|
| 525 |
+
# Render with PIL (no matplotlib)
|
| 526 |
+
img = Image.new('RGB', (image_size, image_size), bg_color)
|
| 527 |
+
draw = ImageDraw.Draw(img)
|
|
|
|
|
|
|
| 528 |
|
| 529 |
highs = [c["h"] for c in candles]
|
| 530 |
lows = [c["l"] for c in candles]
|
| 531 |
+
price_min, price_max = min(lows), max(highs)
|
| 532 |
+
price_range = price_max - price_min or 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
|
| 534 |
+
padding = 10
|
| 535 |
+
chart_width = image_size - 2 * padding
|
| 536 |
+
chart_height = image_size - 2 * padding
|
| 537 |
+
candle_width = chart_width // (num_candles + 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
|
| 539 |
+
for j, c in enumerate(candles):
|
| 540 |
+
x = padding + (j + 1) * candle_width
|
| 541 |
+
|
| 542 |
+
# Scale prices to pixels
|
| 543 |
+
def scale_y(p):
|
| 544 |
+
return padding + chart_height - int((p - price_min) / price_range * chart_height)
|
| 545 |
+
|
| 546 |
+
y_open = scale_y(c["o"])
|
| 547 |
+
y_close = scale_y(c["c"])
|
| 548 |
+
y_high = scale_y(c["h"])
|
| 549 |
+
y_low = scale_y(c["l"])
|
| 550 |
+
|
| 551 |
+
color = bullish_color if c["c"] >= c["o"] else bearish_color
|
| 552 |
+
|
| 553 |
+
# Draw wick
|
| 554 |
+
draw.line([(x, y_high), (x, y_low)], fill=color, width=1)
|
| 555 |
+
|
| 556 |
+
# Draw body
|
| 557 |
+
body_top = min(y_open, y_close)
|
| 558 |
+
body_bottom = max(y_open, y_close)
|
| 559 |
+
if body_bottom - body_top < 2:
|
| 560 |
+
body_bottom = body_top + 2
|
| 561 |
+
draw.rectangle(
|
| 562 |
+
[(x - candle_width//3, body_top), (x + candle_width//3, body_bottom)],
|
| 563 |
+
fill=color
|
| 564 |
+
)
|
| 565 |
|
| 566 |
+
# Save image
|
| 567 |
filename = f"chart_{i:06d}.png"
|
| 568 |
img.save(os.path.join(output_dir, "images", filename))
|
| 569 |
labels[filename] = f"{pattern} trend {vol_name} volatility"
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
+
# Save labels
|
| 572 |
with open(os.path.join(output_dir, "labels.json"), "w") as f:
|
| 573 |
json.dump(labels, f)
|
| 574 |
|
| 575 |
+
return f"✅ Generated {int(num_samples)} samples in ./dataset"
|
| 576 |
|
| 577 |
except Exception as e:
|
| 578 |
+
import traceback
|
| 579 |
+
return f"❌ Failed: {str(e)}\n{traceback.format_exc()}"
|
| 580 |
|
| 581 |
|
| 582 |
# ============== Gradio Interface ==============
|