Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -89,22 +89,7 @@ u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
|
|
| 89 |
yolo_world_global = None
|
| 90 |
yolo_world_model_path = os.path.join(CACHE_DIR, "yolov8s_world.pt") # Adjust path as needed
|
| 91 |
|
| 92 |
-
|
| 93 |
-
"""Lazy load YOLOWorld model"""
|
| 94 |
-
global yolo_world_global
|
| 95 |
-
if yolo_world_global is None:
|
| 96 |
-
logger.info("Loading YOLOWorld model...")
|
| 97 |
-
if os.path.exists(yolo_world_model_path):
|
| 98 |
-
try:
|
| 99 |
-
yolo_world_global = YOLOWorld(yolo_world_model_path)
|
| 100 |
-
logger.info("YOLOWorld model loaded successfully")
|
| 101 |
-
except Exception as e:
|
| 102 |
-
logger.error(f"Failed to load YOLOWorld: {e}")
|
| 103 |
-
yolo_world_global = None
|
| 104 |
-
else:
|
| 105 |
-
logger.warning("YOLOWorld model file not found, will raise error if used")
|
| 106 |
-
yolo_world_global = None
|
| 107 |
-
return yolo_world_global
|
| 108 |
|
| 109 |
# Device configuration
|
| 110 |
device = "cpu"
|
|
@@ -123,6 +108,11 @@ def ensure_model_files():
|
|
| 123 |
shutil.copy("u2netp.pth", u2net_model_path)
|
| 124 |
else:
|
| 125 |
raise FileNotFoundError("u2netp.pth model file not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
ensure_model_files()
|
| 128 |
|
|
@@ -144,7 +134,22 @@ def get_paper_detector():
|
|
| 144 |
logger.warning("Paper model file not found, using fallback detection")
|
| 145 |
paper_detector_global = None
|
| 146 |
return paper_detector_global
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
def get_u2net():
|
| 149 |
"""Lazy load U2NETP model"""
|
| 150 |
global u2net_global
|
|
|
|
| 89 |
yolo_world_global = None
|
| 90 |
yolo_world_model_path = os.path.join(CACHE_DIR, "yolov8s_world.pt") # Adjust path as needed
|
| 91 |
|
| 92 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# Device configuration
|
| 95 |
device = "cpu"
|
|
|
|
| 108 |
shutil.copy("u2netp.pth", u2net_model_path)
|
| 109 |
else:
|
| 110 |
raise FileNotFoundError("u2netp.pth model file not found")
|
| 111 |
+
if not os.path.exists(yolo_world_model_path):
|
| 112 |
+
if os.path.exists("yolov8s-world.pt"): # Adjust to match your file name
|
| 113 |
+
shutil.copy("yolov8s-world.pt", yolo_world_model_path)
|
| 114 |
+
else:
|
| 115 |
+
logger.warning("yolov8s-world.pt model file not found - falling back to full image processing")
|
| 116 |
|
| 117 |
ensure_model_files()
|
| 118 |
|
|
|
|
| 134 |
logger.warning("Paper model file not found, using fallback detection")
|
| 135 |
paper_detector_global = None
|
| 136 |
return paper_detector_global
|
| 137 |
+
def get_yolo_world():
|
| 138 |
+
"""Lazy load YOLOWorld model"""
|
| 139 |
+
global yolo_world_global
|
| 140 |
+
if yolo_world_global is None:
|
| 141 |
+
logger.info("Loading YOLOWorld model...")
|
| 142 |
+
if os.path.exists(yolo_world_model_path):
|
| 143 |
+
try:
|
| 144 |
+
yolo_world_global = YOLOWorld(yolo_world_model_path)
|
| 145 |
+
logger.info("YOLOWorld model loaded successfully")
|
| 146 |
+
except Exception as e:
|
| 147 |
+
logger.error(f"Failed to load YOLOWorld: {e}")
|
| 148 |
+
yolo_world_global = None
|
| 149 |
+
else:
|
| 150 |
+
logger.warning("YOLOWorld model file not found, will raise error if used")
|
| 151 |
+
yolo_world_global = None
|
| 152 |
+
return yolo_world_global
|
| 153 |
def get_u2net():
|
| 154 |
"""Lazy load U2NETP model"""
|
| 155 |
global u2net_global
|