bh4vay committed on
Commit
f232606
·
verified ·
1 Parent(s): cbd648d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -19,8 +19,8 @@ def load_text_model():
19
  # Load model with FP16 or 8-bit quantization
20
  model = AutoModelForCausalLM.from_pretrained(
21
  model_name,
22
- torch_dtype=torch.float16 if device == "cuda" else torch.float32, # Reduce VRAM usage
23
- low_cpu_mem_usage=True # Optimize memory
24
  ).to(device)
25
 
26
  st.write("✅ Text model loaded successfully!")
@@ -40,7 +40,7 @@ def load_image_model():
40
  model_id = "runwayml/stable-diffusion-v1-5"
41
  model = StableDiffusionPipeline.from_pretrained(
42
  model_id,
43
- torch_dtype=torch.float16 if device == "cuda" else torch.float32 # Reduce VRAM usage
44
  ).to(device)
45
  model.enable_attention_slicing() # Optimize GPU memory
46
  st.write("✅ Image model loaded successfully!")
 
19
  # Load model with FP16 or 8-bit quantization
20
  model = AutoModelForCausalLM.from_pretrained(
21
  model_name,
22
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
23
+ low_cpu_mem_usage=True
24
  ).to(device)
25
 
26
  st.write("✅ Text model loaded successfully!")
 
40
  model_id = "runwayml/stable-diffusion-v1-5"
41
  model = StableDiffusionPipeline.from_pretrained(
42
  model_id,
43
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
44
  ).to(device)
45
  model.enable_attention_slicing() # Optimize GPU memory
46
  st.write("✅ Image model loaded successfully!")