Krishnakanth1993 committed on
Commit 73a0eba · verified · 1 parent(s): 109939b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -85,7 +85,8 @@ def load_models():
     """Load all models once and cache them globally."""
     global vae, tokenizer, text_encoder, unet, scheduler, device
 
-    if vae is not None:
+    # Check if already loaded
+    if vae is not None and scheduler is not None:
         return
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -93,20 +94,30 @@ def load_models():
 
     model_id = "CompVis/stable-diffusion-v1-4"
 
-    print("Loading models...")
-    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
-    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
-    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
-
-    scheduler = LMSDiscreteScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=1000
-    )
-
-    print("Models loaded successfully!")
+    try:
+        print("Loading models... (this may take a few minutes on CPU)")
+
+        # Load with float16 on GPU, float32 on CPU
+        dtype = torch.float16 if device == "cuda" else torch.float32
+
+        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype).to(device)
+        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype).to(device)
+        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype).to(device)
+
+        # Initialize scheduler
+        scheduler = LMSDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            num_train_timesteps=1000
+        )
+
+        print("Models loaded successfully!")
+
+    except Exception as e:
+        print(f"Error loading models: {e}")
+        raise RuntimeError(f"Failed to load models: {e}")
 
 
 def generate_with_style(