Timtical commited on
Commit
24767bd
·
verified ·
1 Parent(s): fb097b4

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +14 -0
train_model.py CHANGED
@@ -5,6 +5,16 @@ import gc
5
  from pathlib import Path
6
  from diffusers import StableDiffusionPipeline
7
  from accelerate.utils import set_seed
 
 
 
 
 
 
 
 
 
 
8
 
9
  def train_model(
10
  instance_token: str,
@@ -56,6 +66,10 @@ def train_model(
56
 
57
  os.makedirs(output_dir, exist_ok=True)
58
  pipe.save_pretrained(output_dir)
 
 
 
 
59
  return f"🎉 Training completed. Model saved to: {output_dir}"
60
 
61
  except Exception as e:
 
5
  from pathlib import Path
6
  from diffusers import StableDiffusionPipeline
7
  from accelerate.utils import set_seed
8
+ from huggingface_hub import HfApi, RepositoryNotFoundError, create_repo
9
+
10
+ def ensure_repo_exists(repo_id: str, hf_token: str):
11
+ api = HfApi()
12
+ try:
13
+ api.repo_info(repo_id, token=hf_token)
14
+ print(f"ℹ️ Repo '{repo_id}' already exists.")
15
+ except RepositoryNotFoundError:
16
+ create_repo(repo_id=repo_id, token=hf_token, repo_type="dataset", private=False)
17
+ print(f"✅ Repo '{repo_id}' created.")
18
 
19
  def train_model(
20
  instance_token: str,
 
66
 
67
  os.makedirs(output_dir, exist_ok=True)
68
  pipe.save_pretrained(output_dir)
69
+
70
+ # Ensure dataset repo exists on Hugging Face
71
+ ensure_repo_exists("generated-images", hf_token)
72
+
73
  return f"🎉 Training completed. Model saved to: {output_dir}"
74
 
75
  except Exception as e: