Commit
·
05a48c5
1
Parent(s):
1452c34
Made functions use gpu
Browse files
app.py
CHANGED
|
@@ -68,16 +68,18 @@ def init_model():
|
|
| 68 |
dtype=torch.float32
|
| 69 |
)
|
| 70 |
|
| 71 |
-
return model, pipeline
|
| 72 |
|
| 73 |
# Global variables to store model and pipeline
|
| 74 |
model = None
|
| 75 |
pipeline = None
|
| 76 |
|
|
|
|
| 77 |
def get_model():
|
| 78 |
-
|
|
|
|
| 79 |
if model is None or pipeline is None:
|
| 80 |
-
model, pipeline = init_model()
|
| 81 |
return model, pipeline
|
| 82 |
|
| 83 |
rembg_session = rembg.new_session()
|
|
@@ -144,6 +146,7 @@ def add_random_background(image, color):
|
|
| 144 |
background = Image.new("RGBA", image.size, color)
|
| 145 |
return Image.alpha_composite(background, image)
|
| 146 |
|
|
|
|
| 147 |
def preprocess_image(input_image, background_choice, foreground_ratio, back_groud_color):
|
| 148 |
"""Preprocess the input image"""
|
| 149 |
try:
|
|
@@ -169,6 +172,7 @@ def preprocess_image(input_image, background_choice, foreground_ratio, back_grou
|
|
| 169 |
print(f"Error in preprocess_image: {str(e)}")
|
| 170 |
raise e
|
| 171 |
|
|
|
|
| 172 |
def gen_image(processed_image, seed, scale, step):
|
| 173 |
"""Generate the 3D model"""
|
| 174 |
try:
|
|
|
|
| 68 |
dtype=torch.float32
|
| 69 |
)
|
| 70 |
|
| 71 |
+
return model, pipeline, args
|
| 72 |
|
| 73 |
# Global variables to store model and pipeline
|
| 74 |
model = None
|
| 75 |
pipeline = None
|
| 76 |
|
| 77 |
+
@spaces.GPU
|
| 78 |
def get_model():
|
| 79 |
+
"""Lazy initialization of model and pipeline"""
|
| 80 |
+
global model, pipeline, args
|
| 81 |
if model is None or pipeline is None:
|
| 82 |
+
model, pipeline, args = init_model()
|
| 83 |
return model, pipeline
|
| 84 |
|
| 85 |
rembg_session = rembg.new_session()
|
|
|
|
| 146 |
background = Image.new("RGBA", image.size, color)
|
| 147 |
return Image.alpha_composite(background, image)
|
| 148 |
|
| 149 |
+
@spaces.GPU
|
| 150 |
def preprocess_image(input_image, background_choice, foreground_ratio, back_groud_color):
|
| 151 |
"""Preprocess the input image"""
|
| 152 |
try:
|
|
|
|
| 172 |
print(f"Error in preprocess_image: {str(e)}")
|
| 173 |
raise e
|
| 174 |
|
| 175 |
+
@spaces.GPU
|
| 176 |
def gen_image(processed_image, seed, scale, step):
|
| 177 |
"""Generate the 3D model"""
|
| 178 |
try:
|