Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -648,12 +648,26 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
| 649 |
|
| 650 |
if (epoch + 1) % 10 == 0:
|
|
|
|
| 651 |
torch.save({
|
| 652 |
"model_state_dict": MODEL.state_dict(),
|
| 653 |
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 654 |
"config": CONFIG
|
| 655 |
-
},
|
| 656 |
logs.append(f"💾 Saved checkpoint at epoch {epoch+1}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
|
| 658 |
# Save model
|
| 659 |
MODEL.eval()
|
|
|
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
| 649 |
|
| 650 |
if (epoch + 1) % 10 == 0:
|
| 651 |
+
ckpt_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
|
| 652 |
torch.save({
|
| 653 |
"model_state_dict": MODEL.state_dict(),
|
| 654 |
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 655 |
"config": CONFIG
|
| 656 |
+
}, ckpt_path)
|
| 657 |
logs.append(f"💾 Saved checkpoint at epoch {epoch+1}")
|
| 658 |
+
try:
|
| 659 |
+
from huggingface_hub import HfApi
|
| 660 |
+
api = HfApi()
|
| 661 |
+
api.upload_file(
|
| 662 |
+
path_or_fileobj=ckpt_path,
|
| 663 |
+
path_in_repo=ckpt_path,
|
| 664 |
+
repo_id="Spanicin/candlestick-diffusion",
|
| 665 |
+
repo_type="space",
|
| 666 |
+
token=os.environ.get("HF_TOKEN")
|
| 667 |
+
)
|
| 668 |
+
logs.append(f"☁️ Uploaded to repo")
|
| 669 |
+
except Exception as e:
|
| 670 |
+
logs.append(f"⚠️ Upload failed: {e}")
|
| 671 |
|
| 672 |
# Save model
|
| 673 |
MODEL.eval()
|