Update app.py
app.py CHANGED
@@ -19,8 +19,8 @@ def load_text_model():
     # Load model with FP16 or 8-bit quantization
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-        low_cpu_mem_usage=True
+        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+        low_cpu_mem_usage=True
     ).to(device)
 
     st.write("✅ Text model loaded successfully!")
@@ -40,7 +40,7 @@ def load_image_model():
     model_id = "runwayml/stable-diffusion-v1-5"
     model = StableDiffusionPipeline.from_pretrained(
         model_id,
-        torch_dtype=torch.float16 if device == "cuda" else torch.float32
+        torch_dtype=torch.float16 if device == "cuda" else torch.float32
     ).to(device)
     model.enable_attention_slicing()  # Optimize GPU memory
     st.write("✅ Image model loaded successfully!")
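
For context, below is a minimal sketch of how these hunks might sit inside complete loader functions in app.py. The imports, device setup, caching decorator, and the model_name value are not visible in the diff, so they are assumptions here, not the app's actual code.

# Sketch only: surrounding definitions are assumed, since the diff shows
# just a few lines inside each function.
import torch
import streamlit as st
from transformers import AutoModelForCausalLM
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed device setup

@st.cache_resource  # assumed; keeps the model loaded across Streamlit reruns
def load_text_model():
    model_name = "distilgpt2"  # placeholder; the real checkpoint is set earlier in app.py
    # FP16 halves weight memory on GPU; low_cpu_mem_usage lowers peak RAM
    # while the checkpoint is materialized.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)
    st.write("✅ Text model loaded successfully!")
    return model

@st.cache_resource  # assumed
def load_image_model():
    model_id = "runwayml/stable-diffusion-v1-5"
    model = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    model.enable_attention_slicing()  # Optimize GPU memory
    st.write("✅ Image model loaded successfully!")
    return model

Note that enable_attention_slicing() trades some inference speed for a smaller peak VRAM footprint during image generation, which is why it is applied only to the diffusion pipeline.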