Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -390,42 +390,6 @@ DEVICE = None
|
|
| 390 |
CONFIG = None
|
| 391 |
|
| 392 |
|
| 393 |
-
def save_to_hub(save_name, repo_id=None):
|
| 394 |
-
"""Save model checkpoint to HuggingFace Hub for persistence."""
|
| 395 |
-
global MODEL, TEXT_ENCODER, CONFIG
|
| 396 |
-
|
| 397 |
-
if MODEL is None:
|
| 398 |
-
return "β No model loaded to save"
|
| 399 |
-
|
| 400 |
-
try:
|
| 401 |
-
from huggingface_hub import HfApi, upload_file
|
| 402 |
-
import tempfile
|
| 403 |
-
|
| 404 |
-
# Save to temp file
|
| 405 |
-
with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
|
| 406 |
-
torch.save({
|
| 407 |
-
"model_state_dict": MODEL.state_dict(),
|
| 408 |
-
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 409 |
-
"config": CONFIG
|
| 410 |
-
}, f.name)
|
| 411 |
-
temp_path = f.name
|
| 412 |
-
|
| 413 |
-
# Upload to Hub (same Space repo)
|
| 414 |
-
api = HfApi()
|
| 415 |
-
api.upload_file(
|
| 416 |
-
path_or_fileobj=temp_path,
|
| 417 |
-
path_in_repo=f"checkpoints/{save_name}.pt",
|
| 418 |
-
repo_id=repo_id or "Spanicin/candlestick-diffusion",
|
| 419 |
-
repo_type="space"
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
os.unlink(temp_path)
|
| 423 |
-
return f"β
Model saved to Hub: checkpoints/{save_name}.pt"
|
| 424 |
-
|
| 425 |
-
except Exception as e:
|
| 426 |
-
return f"β Failed to save to Hub: {str(e)}"
|
| 427 |
-
|
| 428 |
-
|
| 429 |
def load_model(checkpoint_path=None):
|
| 430 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 431 |
|
|
@@ -478,33 +442,25 @@ def generate_dataset_ui(num_samples, image_size):
|
|
| 478 |
import os
|
| 479 |
import json
|
| 480 |
import random
|
| 481 |
-
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 482 |
|
| 483 |
output_dir = "./dataset"
|
| 484 |
os.makedirs(output_dir, exist_ok=True)
|
| 485 |
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
| 486 |
|
|
|
|
|
|
|
|
|
|
| 487 |
num_candles = 20
|
| 488 |
-
labels = {}
|
| 489 |
-
|
| 490 |
-
# Colors
|
| 491 |
-
bg_color = (26, 26, 46)
|
| 492 |
-
bullish_color = (0, 255, 136)
|
| 493 |
-
bearish_color = (255, 68, 102)
|
| 494 |
-
|
| 495 |
-
patterns = ["bullish", "bearish", "sideways"]
|
| 496 |
-
volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
|
| 497 |
|
| 498 |
-
|
| 499 |
-
# Generate candle data
|
| 500 |
-
pattern = random.choice(patterns)
|
| 501 |
-
vol_name = random.choice(list(volatilities.keys()))
|
| 502 |
-
vol = volatilities[vol_name]
|
| 503 |
-
|
| 504 |
candles = []
|
| 505 |
price = 100 if pattern != "bearish" else 150
|
| 506 |
|
| 507 |
-
for
|
| 508 |
if pattern == "bullish":
|
| 509 |
trend = random.uniform(0.5, 2.0)
|
| 510 |
o = price + random.gauss(0, vol)
|
|
@@ -513,7 +469,7 @@ def generate_dataset_ui(num_samples, image_size):
|
|
| 513 |
trend = random.uniform(0.5, 2.0)
|
| 514 |
o = price + random.gauss(0, vol)
|
| 515 |
c = o - random.uniform(0, vol*2) - trend
|
| 516 |
-
else:
|
| 517 |
o = price + random.gauss(0, vol)
|
| 518 |
c = o + random.gauss(0, vol)
|
| 519 |
|
|
@@ -521,62 +477,63 @@ def generate_dataset_ui(num_samples, image_size):
|
|
| 521 |
l = min(o, c) - random.uniform(0, vol)
|
| 522 |
candles.append({"o": o, "h": h, "l": l, "c": c})
|
| 523 |
price = c
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
|
|
|
|
|
|
| 528 |
|
| 529 |
highs = [c["h"] for c in candles]
|
| 530 |
lows = [c["l"] for c in candles]
|
| 531 |
-
price_min, price_max = min(lows), max(highs)
|
| 532 |
-
price_range = price_max - price_min or 1
|
| 533 |
-
|
| 534 |
-
padding = 10
|
| 535 |
-
chart_width = image_size - 2 * padding
|
| 536 |
-
chart_height = image_size - 2 * padding
|
| 537 |
-
candle_width = chart_width // (num_candles + 2)
|
| 538 |
|
| 539 |
-
for
|
| 540 |
-
x = padding + (j + 1) * candle_width
|
| 541 |
-
|
| 542 |
-
# Scale prices to pixels
|
| 543 |
-
def scale_y(p):
|
| 544 |
-
return padding + chart_height - int((p - price_min) / price_range * chart_height)
|
| 545 |
-
|
| 546 |
-
y_open = scale_y(c["o"])
|
| 547 |
-
y_close = scale_y(c["c"])
|
| 548 |
-
y_high = scale_y(c["h"])
|
| 549 |
-
y_low = scale_y(c["l"])
|
| 550 |
-
|
| 551 |
color = bullish_color if c["c"] >= c["o"] else bearish_color
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
-
# Save image
|
| 567 |
filename = f"chart_{i:06d}.png"
|
| 568 |
img.save(os.path.join(output_dir, "images", filename))
|
| 569 |
labels[filename] = f"{pattern} trend {vol_name} volatility"
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
-
# Save labels
|
| 572 |
with open(os.path.join(output_dir, "labels.json"), "w") as f:
|
| 573 |
json.dump(labels, f)
|
| 574 |
|
| 575 |
-
return f"β
Generated {
|
| 576 |
|
| 577 |
except Exception as e:
|
| 578 |
-
|
| 579 |
-
return f"β Failed: {str(e)}\n{traceback.format_exc()}"
|
| 580 |
|
| 581 |
|
| 582 |
# ============== Gradio Interface ==============
|
|
@@ -625,20 +582,12 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 625 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 626 |
|
| 627 |
try:
|
| 628 |
-
# Clear GPU memory
|
| 629 |
-
import gc
|
| 630 |
-
gc.collect()
|
| 631 |
-
if torch.cuda.is_available():
|
| 632 |
-
torch.cuda.empty_cache()
|
| 633 |
-
|
| 634 |
# Setup
|
| 635 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 636 |
-
|
| 637 |
-
# Use smaller model for T4 GPU
|
| 638 |
CONFIG = {
|
| 639 |
-
"base_channels":
|
| 640 |
-
"channel_mults": (1, 2, 4),
|
| 641 |
-
"context_dim":
|
| 642 |
"image_size": image_size,
|
| 643 |
"timesteps": 1000
|
| 644 |
}
|
|
@@ -674,14 +623,13 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 674 |
logs = [f"π Training started on {DEVICE}"]
|
| 675 |
logs.append(f"π Model parameters: {num_params:,}")
|
| 676 |
logs.append(f"π Training samples: {len(train_dataset)}")
|
| 677 |
-
logs.append(f"πΌοΈ Image size: {image_size}x{image_size}")
|
| 678 |
-
logs.append(f"π¦ Batch size: {batch_size}")
|
| 679 |
logs.append("-" * 40)
|
| 680 |
|
| 681 |
-
|
|
|
|
|
|
|
|
|
|
| 682 |
epoch_loss = 0
|
| 683 |
-
batch_count = 0
|
| 684 |
-
|
| 685 |
for images, texts in train_loader:
|
| 686 |
images = images.to(DEVICE)
|
| 687 |
context = TEXT_ENCODER(texts, DEVICE)
|
|
@@ -693,28 +641,10 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 693 |
optimizer.step()
|
| 694 |
|
| 695 |
epoch_loss += loss.item()
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
# Clear cache periodically
|
| 699 |
-
if batch_count % 50 == 0:
|
| 700 |
-
if torch.cuda.is_available():
|
| 701 |
-
torch.cuda.empty_cache()
|
| 702 |
|
| 703 |
-
avg_loss = epoch_loss /
|
| 704 |
-
logs.append(f"Epoch {epoch+1}/{
|
| 705 |
-
|
| 706 |
-
# Save checkpoint every 10 epochs
|
| 707 |
-
if (epoch + 1) % 10 == 0:
|
| 708 |
-
print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
|
| 709 |
-
checkpoint_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
|
| 710 |
-
os.makedirs("checkpoints", exist_ok=True)
|
| 711 |
-
torch.save({
|
| 712 |
-
"model_state_dict": MODEL.state_dict(),
|
| 713 |
-
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 714 |
-
"config": CONFIG,
|
| 715 |
-
"epoch": epoch + 1
|
| 716 |
-
}, checkpoint_path)
|
| 717 |
-
logs.append(f"πΎ Checkpoint saved: {checkpoint_path}")
|
| 718 |
|
| 719 |
# Save model
|
| 720 |
MODEL.eval()
|
|
@@ -729,26 +659,10 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 729 |
logs.append("-" * 40)
|
| 730 |
logs.append(f"β
Model saved to {save_path}")
|
| 731 |
|
| 732 |
-
# Also try to save to Hub for persistence
|
| 733 |
-
try:
|
| 734 |
-
from huggingface_hub import HfApi
|
| 735 |
-
api = HfApi()
|
| 736 |
-
api.upload_file(
|
| 737 |
-
path_or_fileobj=save_path,
|
| 738 |
-
path_in_repo=f"checkpoints/{save_name}.pt",
|
| 739 |
-
repo_id="Spanicin/candlestick-diffusion",
|
| 740 |
-
repo_type="space"
|
| 741 |
-
)
|
| 742 |
-
logs.append(f"βοΈ Model uploaded to Hub (persistent)")
|
| 743 |
-
except Exception as hub_error:
|
| 744 |
-
logs.append(f"β οΈ Could not upload to Hub: {hub_error}")
|
| 745 |
-
logs.append(" Model saved locally but may be lost on restart")
|
| 746 |
-
|
| 747 |
return "\n".join(logs)
|
| 748 |
|
| 749 |
except Exception as e:
|
| 750 |
-
|
| 751 |
-
return f"β Training failed: {str(e)}\n{traceback.format_exc()}"
|
| 752 |
|
| 753 |
|
| 754 |
def load_checkpoint(checkpoint_file):
|
|
|
|
| 390 |
CONFIG = None
|
| 391 |
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
def load_model(checkpoint_path=None):
|
| 394 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 395 |
|
|
|
|
| 442 |
import os
|
| 443 |
import json
|
| 444 |
import random
|
| 445 |
+
from PIL import Image
|
| 446 |
+
import matplotlib.pyplot as plt
|
| 447 |
+
from matplotlib.patches import Rectangle
|
| 448 |
+
import io
|
| 449 |
|
| 450 |
output_dir = "./dataset"
|
| 451 |
os.makedirs(output_dir, exist_ok=True)
|
| 452 |
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
| 453 |
|
| 454 |
+
bg_color = "#1a1a2e"
|
| 455 |
+
bullish_color = "#00ff88"
|
| 456 |
+
bearish_color = "#ff4466"
|
| 457 |
num_candles = 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
+
def generate_candles(pattern, vol):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
candles = []
|
| 461 |
price = 100 if pattern != "bearish" else 150
|
| 462 |
|
| 463 |
+
for i in range(num_candles):
|
| 464 |
if pattern == "bullish":
|
| 465 |
trend = random.uniform(0.5, 2.0)
|
| 466 |
o = price + random.gauss(0, vol)
|
|
|
|
| 469 |
trend = random.uniform(0.5, 2.0)
|
| 470 |
o = price + random.gauss(0, vol)
|
| 471 |
c = o - random.uniform(0, vol*2) - trend
|
| 472 |
+
else: # sideways
|
| 473 |
o = price + random.gauss(0, vol)
|
| 474 |
c = o + random.gauss(0, vol)
|
| 475 |
|
|
|
|
| 477 |
l = min(o, c) - random.uniform(0, vol)
|
| 478 |
candles.append({"o": o, "h": h, "l": l, "c": c})
|
| 479 |
price = c
|
| 480 |
+
return candles
|
| 481 |
+
|
| 482 |
+
def render(candles):
|
| 483 |
+
fig, ax = plt.subplots(figsize=(image_size/100, image_size/100), dpi=100)
|
| 484 |
+
fig.patch.set_facecolor(bg_color)
|
| 485 |
+
ax.set_facecolor(bg_color)
|
| 486 |
|
| 487 |
highs = [c["h"] for c in candles]
|
| 488 |
lows = [c["l"] for c in candles]
|
| 489 |
+
price_min, price_max = min(lows)*0.98, max(highs)*1.02
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
+
for i, c in enumerate(candles):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
color = bullish_color if c["c"] >= c["o"] else bearish_color
|
| 493 |
+
ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
|
| 494 |
+
body_bottom = min(c["o"], c["c"])
|
| 495 |
+
body_height = abs(c["c"] - c["o"]) or 0.1
|
| 496 |
+
rect = Rectangle((i-0.3, body_bottom), 0.6, body_height, facecolor=color)
|
| 497 |
+
ax.add_patch(rect)
|
| 498 |
+
|
| 499 |
+
ax.set_xlim(-1, len(candles))
|
| 500 |
+
ax.set_ylim(price_min, price_max)
|
| 501 |
+
ax.axis("off")
|
| 502 |
+
|
| 503 |
+
buf = io.BytesIO()
|
| 504 |
+
plt.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
|
| 505 |
+
plt.close(fig)
|
| 506 |
+
buf.seek(0)
|
| 507 |
+
|
| 508 |
+
img = Image.open(buf).convert("RGB")
|
| 509 |
+
return img.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
| 510 |
+
|
| 511 |
+
patterns = ["bullish", "bearish", "sideways"]
|
| 512 |
+
volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
|
| 513 |
+
labels = {}
|
| 514 |
+
|
| 515 |
+
for i in range(int(num_samples)):
|
| 516 |
+
pattern = random.choice(patterns)
|
| 517 |
+
vol_name = random.choice(list(volatilities.keys()))
|
| 518 |
+
vol = volatilities[vol_name]
|
| 519 |
+
|
| 520 |
+
candles = generate_candles(pattern, vol)
|
| 521 |
+
img = render(candles)
|
| 522 |
|
|
|
|
| 523 |
filename = f"chart_{i:06d}.png"
|
| 524 |
img.save(os.path.join(output_dir, "images", filename))
|
| 525 |
labels[filename] = f"{pattern} trend {vol_name} volatility"
|
| 526 |
+
|
| 527 |
+
if i % 500 == 0:
|
| 528 |
+
print(f"Generated {i}/{num_samples}")
|
| 529 |
|
|
|
|
| 530 |
with open(os.path.join(output_dir, "labels.json"), "w") as f:
|
| 531 |
json.dump(labels, f)
|
| 532 |
|
| 533 |
+
return f"β
Generated {num_samples} samples in ./dataset"
|
| 534 |
|
| 535 |
except Exception as e:
|
| 536 |
+
return f"β Failed: {str(e)}"
|
|
|
|
| 537 |
|
| 538 |
|
| 539 |
# ============== Gradio Interface ==============
|
|
|
|
| 582 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 583 |
|
| 584 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
# Setup
|
| 586 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
| 587 |
CONFIG = {
|
| 588 |
+
"base_channels": 64,
|
| 589 |
+
"channel_mults": (1, 2, 4),
|
| 590 |
+
"context_dim": 256,
|
| 591 |
"image_size": image_size,
|
| 592 |
"timesteps": 1000
|
| 593 |
}
|
|
|
|
| 623 |
logs = [f"π Training started on {DEVICE}"]
|
| 624 |
logs.append(f"π Model parameters: {num_params:,}")
|
| 625 |
logs.append(f"π Training samples: {len(train_dataset)}")
|
|
|
|
|
|
|
| 626 |
logs.append("-" * 40)
|
| 627 |
|
| 628 |
+
total_steps = epochs * len(train_loader)
|
| 629 |
+
current_step = 0
|
| 630 |
+
|
| 631 |
+
for epoch in range(epochs):
|
| 632 |
epoch_loss = 0
|
|
|
|
|
|
|
| 633 |
for images, texts in train_loader:
|
| 634 |
images = images.to(DEVICE)
|
| 635 |
context = TEXT_ENCODER(texts, DEVICE)
|
|
|
|
| 641 |
optimizer.step()
|
| 642 |
|
| 643 |
epoch_loss += loss.item()
|
| 644 |
+
current_step += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
+
avg_loss = epoch_loss / len(train_loader)
|
| 647 |
+
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
# Save model
|
| 650 |
MODEL.eval()
|
|
|
|
| 659 |
logs.append("-" * 40)
|
| 660 |
logs.append(f"β
Model saved to {save_path}")
|
| 661 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
return "\n".join(logs)
|
| 663 |
|
| 664 |
except Exception as e:
|
| 665 |
+
return f"β Training failed: {str(e)}"
|
|
|
|
| 666 |
|
| 667 |
|
| 668 |
def load_checkpoint(checkpoint_file):
|