Spaces:
Runtime error
Runtime error
Update train_model.py
Browse files- train_model.py +24 -3
train_model.py
CHANGED
|
@@ -11,7 +11,28 @@ try:
|
|
| 11 |
except ImportError as e:
|
| 12 |
raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
|
| 13 |
|
| 14 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def ensure_repo_exists(repo_id: str, hf_token: str):
|
| 17 |
api = HfApi()
|
|
@@ -38,6 +59,8 @@ def train_model(
|
|
| 38 |
dataset_repo_id: str = "generated-images"
|
| 39 |
):
|
| 40 |
try:
|
|
|
|
|
|
|
| 41 |
gc.collect()
|
| 42 |
if torch.cuda.is_available():
|
| 43 |
torch.cuda.empty_cache()
|
|
@@ -47,7 +70,6 @@ def train_model(
|
|
| 47 |
set_seed(seed)
|
| 48 |
print(f"π§ Using random seed: {seed}")
|
| 49 |
|
| 50 |
-
# Validate precision mode
|
| 51 |
if precision not in ["fp16", "fp32"]:
|
| 52 |
return f"β Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."
|
| 53 |
|
|
@@ -81,7 +103,6 @@ def train_model(
|
|
| 81 |
os.makedirs(output_dir, exist_ok=True)
|
| 82 |
pipe.save_pretrained(output_dir)
|
| 83 |
|
| 84 |
-
# Ensure dataset repo exists on Hugging Face
|
| 85 |
ensure_repo_exists(dataset_repo_id, hf_token)
|
| 86 |
|
| 87 |
return f"π Training completed. Model saved to: {output_dir}"
|
|
|
|
| 11 |
except ImportError as e:
|
| 12 |
raise ImportError("huggingface_hub is missing or incompatible. Please ensure it's installed and up to date.") from e
|
| 13 |
|
| 14 |
+
# Optional: Version compatibility check
|
| 15 |
+
REQUIRED_HF_HUB_VERSION = "0.22.0"
|
| 16 |
+
REQUIRED_DIFFUSERS_VERSION = "0.25.0"
|
| 17 |
+
REQUIRED_ACCELERATE_VERSION = "0.27.2"
|
| 18 |
+
|
| 19 |
+
def check_versions():
|
| 20 |
+
import importlib.metadata as metadata
|
| 21 |
+
try:
|
| 22 |
+
hf_hub_version = metadata.version("huggingface_hub")
|
| 23 |
+
diffusers_version = metadata.version("diffusers")
|
| 24 |
+
accelerate_version = metadata.version("accelerate")
|
| 25 |
+
|
| 26 |
+
print(f"π Versions: huggingface_hub={hf_hub_version}, diffusers={diffusers_version}, accelerate={accelerate_version}")
|
| 27 |
+
|
| 28 |
+
if hf_hub_version < REQUIRED_HF_HUB_VERSION:
|
| 29 |
+
raise RuntimeError(f"huggingface_hub must be >= {REQUIRED_HF_HUB_VERSION}")
|
| 30 |
+
if diffusers_version < REQUIRED_DIFFUSERS_VERSION:
|
| 31 |
+
raise RuntimeError(f"diffusers must be >= {REQUIRED_DIFFUSERS_VERSION}")
|
| 32 |
+
if accelerate_version < REQUIRED_ACCELERATE_VERSION:
|
| 33 |
+
raise RuntimeError(f"accelerate must be >= {REQUIRED_ACCELERATE_VERSION}")
|
| 34 |
+
except Exception as e:
|
| 35 |
+
raise RuntimeError(f"β Version check failed: {e}")
|
| 36 |
|
| 37 |
def ensure_repo_exists(repo_id: str, hf_token: str):
|
| 38 |
api = HfApi()
|
|
|
|
| 59 |
dataset_repo_id: str = "generated-images"
|
| 60 |
):
|
| 61 |
try:
|
| 62 |
+
check_versions()
|
| 63 |
+
|
| 64 |
gc.collect()
|
| 65 |
if torch.cuda.is_available():
|
| 66 |
torch.cuda.empty_cache()
|
|
|
|
| 70 |
set_seed(seed)
|
| 71 |
print(f"π§ Using random seed: {seed}")
|
| 72 |
|
|
|
|
| 73 |
if precision not in ["fp16", "fp32"]:
|
| 74 |
return f"β Training failed: Invalid precision mode '{precision}'. Choose 'fp16' or 'fp32'."
|
| 75 |
|
|
|
|
| 103 |
os.makedirs(output_dir, exist_ok=True)
|
| 104 |
pipe.save_pretrained(output_dir)
|
| 105 |
|
|
|
|
| 106 |
ensure_repo_exists(dataset_repo_id, hf_token)
|
| 107 |
|
| 108 |
return f"π Training completed. Model saved to: {output_dir}"
|