Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -436,6 +436,106 @@ def load_model(checkpoint_path=None):
|
|
| 436 |
return True
|
| 437 |
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# ============== Gradio Interface ==============
|
| 440 |
|
| 441 |
def generate_chart(prompt, num_steps, guidance_scale, seed):
|
|
@@ -592,6 +692,30 @@ def create_demo():
|
|
| 592 |
""")
|
| 593 |
|
| 594 |
with gr.Tabs():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
# Generation Tab
|
| 596 |
with gr.TabItem("π¨ Generate"):
|
| 597 |
with gr.Row():
|
|
|
|
| 436 |
return True
|
| 437 |
|
| 438 |
|
| 439 |
+
def generate_dataset_ui(num_samples, image_size):
|
| 440 |
+
"""Generate training dataset."""
|
| 441 |
+
try:
|
| 442 |
+
import os
|
| 443 |
+
import json
|
| 444 |
+
import random
|
| 445 |
+
from PIL import Image
|
| 446 |
+
import matplotlib.pyplot as plt
|
| 447 |
+
from matplotlib.patches import Rectangle
|
| 448 |
+
import io
|
| 449 |
+
|
| 450 |
+
output_dir = "./dataset"
|
| 451 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 452 |
+
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
| 453 |
+
|
| 454 |
+
bg_color = "#1a1a2e"
|
| 455 |
+
bullish_color = "#00ff88"
|
| 456 |
+
bearish_color = "#ff4466"
|
| 457 |
+
num_candles = 20
|
| 458 |
+
|
| 459 |
+
def generate_candles(pattern, vol):
|
| 460 |
+
candles = []
|
| 461 |
+
price = 100 if pattern != "bearish" else 150
|
| 462 |
+
|
| 463 |
+
for i in range(num_candles):
|
| 464 |
+
if pattern == "bullish":
|
| 465 |
+
trend = random.uniform(0.5, 2.0)
|
| 466 |
+
o = price + random.gauss(0, vol)
|
| 467 |
+
c = o + random.uniform(0, vol*2) + trend
|
| 468 |
+
elif pattern == "bearish":
|
| 469 |
+
trend = random.uniform(0.5, 2.0)
|
| 470 |
+
o = price + random.gauss(0, vol)
|
| 471 |
+
c = o - random.uniform(0, vol*2) - trend
|
| 472 |
+
else: # sideways
|
| 473 |
+
o = price + random.gauss(0, vol)
|
| 474 |
+
c = o + random.gauss(0, vol)
|
| 475 |
+
|
| 476 |
+
h = max(o, c) + random.uniform(0, vol)
|
| 477 |
+
l = min(o, c) - random.uniform(0, vol)
|
| 478 |
+
candles.append({"o": o, "h": h, "l": l, "c": c})
|
| 479 |
+
price = c
|
| 480 |
+
return candles
|
| 481 |
+
|
| 482 |
+
def render(candles):
|
| 483 |
+
fig, ax = plt.subplots(figsize=(image_size/100, image_size/100), dpi=100)
|
| 484 |
+
fig.patch.set_facecolor(bg_color)
|
| 485 |
+
ax.set_facecolor(bg_color)
|
| 486 |
+
|
| 487 |
+
highs = [c["h"] for c in candles]
|
| 488 |
+
lows = [c["l"] for c in candles]
|
| 489 |
+
price_min, price_max = min(lows)*0.98, max(highs)*1.02
|
| 490 |
+
|
| 491 |
+
for i, c in enumerate(candles):
|
| 492 |
+
color = bullish_color if c["c"] >= c["o"] else bearish_color
|
| 493 |
+
ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
|
| 494 |
+
body_bottom = min(c["o"], c["c"])
|
| 495 |
+
body_height = abs(c["c"] - c["o"]) or 0.1
|
| 496 |
+
rect = Rectangle((i-0.3, body_bottom), 0.6, body_height, facecolor=color)
|
| 497 |
+
ax.add_patch(rect)
|
| 498 |
+
|
| 499 |
+
ax.set_xlim(-1, len(candles))
|
| 500 |
+
ax.set_ylim(price_min, price_max)
|
| 501 |
+
ax.axis("off")
|
| 502 |
+
|
| 503 |
+
buf = io.BytesIO()
|
| 504 |
+
plt.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
|
| 505 |
+
plt.close(fig)
|
| 506 |
+
buf.seek(0)
|
| 507 |
+
|
| 508 |
+
img = Image.open(buf).convert("RGB")
|
| 509 |
+
return img.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
| 510 |
+
|
| 511 |
+
patterns = ["bullish", "bearish", "sideways"]
|
| 512 |
+
volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
|
| 513 |
+
labels = {}
|
| 514 |
+
|
| 515 |
+
for i in range(int(num_samples)):
|
| 516 |
+
pattern = random.choice(patterns)
|
| 517 |
+
vol_name = random.choice(list(volatilities.keys()))
|
| 518 |
+
vol = volatilities[vol_name]
|
| 519 |
+
|
| 520 |
+
candles = generate_candles(pattern, vol)
|
| 521 |
+
img = render(candles)
|
| 522 |
+
|
| 523 |
+
filename = f"chart_{i:06d}.png"
|
| 524 |
+
img.save(os.path.join(output_dir, "images", filename))
|
| 525 |
+
labels[filename] = f"{pattern} trend {vol_name} volatility"
|
| 526 |
+
|
| 527 |
+
if i % 500 == 0:
|
| 528 |
+
print(f"Generated {i}/{num_samples}")
|
| 529 |
+
|
| 530 |
+
with open(os.path.join(output_dir, "labels.json"), "w") as f:
|
| 531 |
+
json.dump(labels, f)
|
| 532 |
+
|
| 533 |
+
return f"β
Generated {num_samples} samples in ./dataset"
|
| 534 |
+
|
| 535 |
+
except Exception as e:
|
| 536 |
+
return f"β Failed: {str(e)}"
|
| 537 |
+
|
| 538 |
+
|
| 539 |
# ============== Gradio Interface ==============
|
| 540 |
|
| 541 |
def generate_chart(prompt, num_steps, guidance_scale, seed):
|
|
|
|
| 692 |
""")
|
| 693 |
|
| 694 |
with gr.Tabs():
|
| 695 |
+
# Data Generation Tab
|
| 696 |
+
with gr.TabItem("π Generate Data"):
|
| 697 |
+
gr.Markdown("""
|
| 698 |
+
### Generate Training Dataset
|
| 699 |
+
|
| 700 |
+
Create synthetic candlestick chart images for training.
|
| 701 |
+
**Run this first before training!**
|
| 702 |
+
""")
|
| 703 |
+
|
| 704 |
+
with gr.Row():
|
| 705 |
+
with gr.Column():
|
| 706 |
+
num_samples = gr.Slider(1000, 50000, value=10000, step=1000, label="Number of Samples")
|
| 707 |
+
data_image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
|
| 708 |
+
generate_data_btn = gr.Button("π Generate Dataset", variant="primary")
|
| 709 |
+
|
| 710 |
+
with gr.Column():
|
| 711 |
+
data_status = gr.Textbox(label="Status", lines=5, interactive=False)
|
| 712 |
+
|
| 713 |
+
generate_data_btn.click(
|
| 714 |
+
generate_dataset_ui,
|
| 715 |
+
inputs=[num_samples, data_image_size],
|
| 716 |
+
outputs=[data_status]
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
# Generation Tab
|
| 720 |
with gr.TabItem("π¨ Generate"):
|
| 721 |
with gr.Row():
|