Timtical commited on
Commit
b02edaf
·
verified ·
1 Parent(s): 6b73d37

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +9 -12
train_model.py CHANGED
@@ -16,12 +16,11 @@ def train_model(
16
  hf_token: str
17
  ):
18
  try:
19
- # Memory and seed setup
20
  gc.collect()
21
- torch.cuda.empty_cache()
 
22
  set_seed(42)
23
 
24
- # Unzip uploaded data
25
  instance_data_dir = Path("instance_data")
26
  if instance_data_dir.exists():
27
  shutil.rmtree(instance_data_dir)
@@ -30,26 +29,24 @@ def train_model(
30
 
31
  print(f"✅ Data extracted to: {instance_data_dir}")
32
 
33
- # Load pre-trained model
34
  model_id = "CompVis/stable-diffusion-v1-4"
35
  pipe = StableDiffusionPipeline.from_pretrained(
36
  model_id,
37
- torch_dtype=torch.float16,
38
- revision="fp16",
39
  use_auth_token=hf_token
40
  )
41
- pipe.to("cuda")
42
 
43
- # Placeholder training simulation
44
- print(f"🚧 Starting training for {max_train_steps} steps with LR={learning_rate}...")
 
 
45
  for step in range(int(max_train_steps)):
46
  if step % 100 == 0 or step == int(max_train_steps) - 1:
47
- print(f"Training step {step + 1}/{max_train_steps}")
48
 
49
- # Save model
50
  os.makedirs(output_dir, exist_ok=True)
51
  pipe.save_pretrained(output_dir)
52
-
53
  return f"🎉 Training completed. Model saved to: {output_dir}"
54
 
55
  except Exception as e:
 
16
  hf_token: str
17
  ):
18
  try:
 
19
  gc.collect()
20
+ if torch.cuda.is_available():
21
+ torch.cuda.empty_cache()
22
  set_seed(42)
23
 
 
24
  instance_data_dir = Path("instance_data")
25
  if instance_data_dir.exists():
26
  shutil.rmtree(instance_data_dir)
 
29
 
30
  print(f"✅ Data extracted to: {instance_data_dir}")
31
 
 
32
  model_id = "CompVis/stable-diffusion-v1-4"
33
  pipe = StableDiffusionPipeline.from_pretrained(
34
  model_id,
35
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
36
+ revision="fp16" if torch.cuda.is_available() else "main",
37
  use_auth_token=hf_token
38
  )
 
39
 
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ pipe.to(device)
42
+
43
+ print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
44
  for step in range(int(max_train_steps)):
45
  if step % 100 == 0 or step == int(max_train_steps) - 1:
46
+ print(f"Step {step + 1}/{max_train_steps}")
47
 
 
48
  os.makedirs(output_dir, exist_ok=True)
49
  pipe.save_pretrained(output_dir)
 
50
  return f"🎉 Training completed. Model saved to: {output_dir}"
51
 
52
  except Exception as e: