Timtical commited on
Commit
5807b1c
·
verified ·
1 Parent(s): a53bf22

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +10 -11
train_model.py CHANGED
@@ -13,13 +13,16 @@ def train_model(
13
  output_dir: str,
14
  max_train_steps: int,
15
  learning_rate: float,
16
- hf_token: str
 
 
17
  ):
18
  try:
19
  gc.collect()
20
  if torch.cuda.is_available():
21
  torch.cuda.empty_cache()
22
- set_seed(42)
 
23
 
24
  instance_data_dir = Path("instance_data")
25
  if instance_data_dir.exists():
@@ -30,30 +33,26 @@ def train_model(
30
  print(f"✅ Data extracted to: {instance_data_dir}")
31
 
32
  model_id = "CompVis/stable-diffusion-v1-4"
 
 
 
33
  pipe = StableDiffusionPipeline.from_pretrained(
34
  model_id,
35
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
36
- revision="fp16" if torch.cuda.is_available() else "main",
37
  use_auth_token=hf_token
38
  )
39
 
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
41
  pipe.to(device)
42
 
43
- # This is where real training logic would go in production
44
  print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
45
  for step in range(int(max_train_steps)):
46
  if step % 100 == 0 or step == int(max_train_steps) - 1:
47
  print(f"Step {step + 1}/{max_train_steps}")
48
 
49
- # Save the model to the desired directory
50
  os.makedirs(output_dir, exist_ok=True)
51
  pipe.save_pretrained(output_dir)
52
-
53
- # Create a success flag file
54
- with open(os.path.join(output_dir, "training_complete.txt"), "w") as f:
55
- f.write("Training complete")
56
-
57
  return f"🎉 Training completed. Model saved to: {output_dir}"
58
 
59
  except Exception as e:
 
13
  output_dir: str,
14
  max_train_steps: int,
15
  learning_rate: float,
16
+ hf_token: str,
17
+ seed: int = 42,
18
+ precision: str = "fp16"
19
  ):
20
  try:
21
  gc.collect()
22
  if torch.cuda.is_available():
23
  torch.cuda.empty_cache()
24
+
25
+ set_seed(seed)
26
 
27
  instance_data_dir = Path("instance_data")
28
  if instance_data_dir.exists():
 
33
  print(f"✅ Data extracted to: {instance_data_dir}")
34
 
35
  model_id = "CompVis/stable-diffusion-v1-4"
36
+ torch_dtype = torch.float16 if precision == "fp16" else torch.float32
37
+ revision = "fp16" if precision == "fp16" else "main"
38
+
39
  pipe = StableDiffusionPipeline.from_pretrained(
40
  model_id,
41
+ torch_dtype=torch_dtype,
42
+ revision=revision,
43
  use_auth_token=hf_token
44
  )
45
 
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
  pipe.to(device)
48
 
 
49
  print(f"🚧 Simulating training for {max_train_steps} steps at LR={learning_rate}")
50
  for step in range(int(max_train_steps)):
51
  if step % 100 == 0 or step == int(max_train_steps) - 1:
52
  print(f"Step {step + 1}/{max_train_steps}")
53
 
 
54
  os.makedirs(output_dir, exist_ok=True)
55
  pipe.save_pretrained(output_dir)
 
 
 
 
 
56
  return f"🎉 Training completed. Model saved to: {output_dir}"
57
 
58
  except Exception as e: