Timtical commited on
Commit
0003569
Β·
verified Β·
1 Parent(s): 7aa015e

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +24 -3
train_model.py CHANGED
@@ -11,7 +11,28 @@ try:
11
  except ImportError as e:
12
  raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
13
 
14
- # TODO: Validate version compatibility for diffusers, accelerate, and huggingface_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def ensure_repo_exists(repo_id: str, hf_token: str):
17
  api = HfApi()
@@ -38,6 +59,8 @@ def train_model(
38
  dataset_repo_id: str = "generated-images"
39
  ):
40
  try:
 
 
41
  gc.collect()
42
  if torch.cuda.is_available():
43
  torch.cuda.empty_cache()
@@ -47,7 +70,6 @@ def train_model(
47
  set_seed(seed)
48
  print(f"πŸ”§ Using random seed: {seed}")
49
 
50
- # Validate precision mode
51
  if precision not in ["fp16", "fp32"]:
52
  return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."
53
 
@@ -81,7 +103,6 @@ def train_model(
81
  os.makedirs(output_dir, exist_ok=True)
82
  pipe.save_pretrained(output_dir)
83
 
84
- # Ensure dataset repo exists on Hugging Face
85
  ensure_repo_exists(dataset_repo_id, hf_token)
86
 
87
  return f"πŸŽ‰ Training completed. Model saved to: {output_dir}"
 
11
  except ImportError as e:
12
  raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
13
 
14
+ # Optional: Version compatibility check
15
+ REQUIRED_HF_HUB_VERSION = "0.22.0"
16
+ REQUIRED_DIFFUSERS_VERSION = "0.25.0"
17
+ REQUIRED_ACCELERATE_VERSION = "0.27.2"
18
+
19
+ def check_versions():
20
+ import importlib.metadata as metadata
21
+ try:
22
+ hf_hub_version = metadata.version("huggingface_hub")
23
+ diffusers_version = metadata.version("diffusers")
24
+ accelerate_version = metadata.version("accelerate")
25
+
26
+ print(f"πŸ” Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")
27
+
28
+ if hf_hub_version < REQUIRED_HF_HUB_VERSION:
29
+ raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}")
30
+ if diffusers_version < REQUIRED_DIFFUSERS_VERSION:
31
+ raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}")
32
+ if accelerate_version < REQUIRED_ACCELERATE_VERSION:
33
+ raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}")
34
+ except Exception as e:
35
+ raise RuntimeError(f"❌ Version check failed: {e}")
36
 
37
  def ensure_repo_exists(repo_id: str, hf_token: str):
38
  api = HfApi()
 
59
  dataset_repo_id: str = "generated-images"
60
  ):
61
  try:
62
+ check_versions()
63
+
64
  gc.collect()
65
  if torch.cuda.is_available():
66
  torch.cuda.empty_cache()
 
70
  set_seed(seed)
71
  print(f"πŸ”§ Using random seed: {seed}")
72
 
 
73
  if precision not in ["fp16", "fp32"]:
74
  return f"❌ Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."
75
 
 
103
  os.makedirs(output_dir, exist_ok=True)
104
  pipe.save_pretrained(output_dir)
105
 
 
106
  ensure_repo_exists(dataset_repo_id, hf_token)
107
 
108
  return f"πŸŽ‰ Training completed. Model saved to: {output_dir}"