Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -607,7 +607,7 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 607 |
train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
|
| 608 |
train_loader = DataLoader(
|
| 609 |
train_dataset, batch_size=batch_size, shuffle=True,
|
| 610 |
-
num_workers=
|
| 611 |
)
|
| 612 |
|
| 613 |
# Optimizer
|
|
@@ -642,12 +642,9 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 642 |
|
| 643 |
epoch_loss += loss.item()
|
| 644 |
current_step += 1
|
| 645 |
-
print(f"Step {current_step}, loss: {loss.item():.4f}")
|
| 646 |
|
| 647 |
avg_loss = epoch_loss / len(train_loader)
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
| 649 |
-
|
| 650 |
-
|
| 651 |
|
| 652 |
# Save model
|
| 653 |
MODEL.eval()
|
|
@@ -658,19 +655,6 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 658 |
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 659 |
"config": CONFIG
|
| 660 |
}, save_path)
|
| 661 |
-
|
| 662 |
-
try:
|
| 663 |
-
from huggingface_hub import HfApi
|
| 664 |
-
api = HfApi()
|
| 665 |
-
api.upload_file(
|
| 666 |
-
path_or_fileobj=save_path,
|
| 667 |
-
path_in_repo=f"checkpoints/{save_name}.pt",
|
| 668 |
-
repo_id="Spanicin/candlestick-diffusion",
|
| 669 |
-
repo_type="space"
|
| 670 |
-
)
|
| 671 |
-
logs.append("☁️ Checkpoint uploaded to repo")
|
| 672 |
-
except Exception as e:
|
| 673 |
-
logs.append(f"⚠️ Upload failed: {e}")
|
| 674 |
|
| 675 |
logs.append("-" * 40)
|
| 676 |
logs.append(f"✅ Model saved to {save_path}")
|
|
|
|
| 607 |
train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
|
| 608 |
train_loader = DataLoader(
|
| 609 |
train_dataset, batch_size=batch_size, shuffle=True,
|
| 610 |
+
num_workers=2, pin_memory=True, drop_last=True, collate_fn=collate_fn
|
| 611 |
)
|
| 612 |
|
| 613 |
# Optimizer
|
|
|
|
| 642 |
|
| 643 |
epoch_loss += loss.item()
|
| 644 |
current_step += 1
|
|
|
|
| 645 |
|
| 646 |
avg_loss = epoch_loss / len(train_loader)
|
| 647 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
|
|
|
|
|
|
| 648 |
|
| 649 |
# Save model
|
| 650 |
MODEL.eval()
|
|
|
|
| 655 |
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 656 |
"config": CONFIG
|
| 657 |
}, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
|
| 659 |
logs.append("-" * 40)
|
| 660 |
logs.append(f"✅ Model saved to {save_path}")
|