infinity1096 commited on
Commit
b91a473
·
1 Parent(s): 7b12cd5

add spaces

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -11,6 +11,7 @@ import subprocess
11
  import sys
12
  import subprocess
13
  import importlib
 
14
 
15
  try:
16
  import uniception
@@ -38,10 +39,6 @@ from uniflowmatch.utils.viz import warp_image_with_flow
38
  model = None
39
  USE_REFINEMENT_MODEL = False
40
 
41
- use_gpu = torch.cuda.is_available()
42
- if use_gpu:
43
- print("Using GPU for processing.")
44
-
45
  def initialize_model(use_refinement: bool = False):
46
  """Initialize the model - call this once at startup"""
47
  global model, USE_REFINEMENT_MODEL
@@ -59,17 +56,13 @@ def initialize_model(use_refinement: bool = False):
59
  if hasattr(model, "eval"):
60
  model.eval()
61
 
62
- if use_gpu:
63
- print("Moving model to GPU...")
64
- model = model.to("cuda")
65
-
66
  print("Model loaded successfully!")
67
  return True
68
  except Exception as e:
69
  print(f"Error loading model: {e}")
70
  return False
71
 
72
-
73
  def process_images(source_image, target_image, model_type_choice):
74
  """
75
  Process two uploaded images and return visualizations
@@ -86,6 +79,9 @@ def process_images(source_image, target_image, model_type_choice):
86
  if model is None:
87
  return None, None, None, "Model not loaded. Please restart the application."
88
 
 
 
 
89
  try:
90
  # Convert PIL images to numpy arrays
91
  source_np = np.array(source_image)
 
11
  import sys
12
  import subprocess
13
  import importlib
14
+ import spaces
15
 
16
  try:
17
  import uniception
 
39
  model = None
40
  USE_REFINEMENT_MODEL = False
41
 
 
 
 
 
42
  def initialize_model(use_refinement: bool = False):
43
  """Initialize the model - call this once at startup"""
44
  global model, USE_REFINEMENT_MODEL
 
56
  if hasattr(model, "eval"):
57
  model.eval()
58
 
 
 
 
 
59
  print("Model loaded successfully!")
60
  return True
61
  except Exception as e:
62
  print(f"Error loading model: {e}")
63
  return False
64
 
65
+ @spaces.GPU
66
  def process_images(source_image, target_image, model_type_choice):
67
  """
68
  Process two uploaded images and return visualizations
 
79
  if model is None:
80
  return None, None, None, "Model not loaded. Please restart the application."
81
 
82
+ model = model.to("cuda" if torch.cuda.is_available() else "cpu")
83
+ use_gpu = torch.cuda.is_available()
84
+
85
  try:
86
  # Convert PIL images to numpy arrays
87
  source_np = np.array(source_image)