Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -119,8 +119,8 @@ def unload_and_reload_models():
|
|
| 119 |
new_birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 120 |
"zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
|
| 121 |
)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
# new_u2net = U2NETP(3, 1)
|
| 125 |
# new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
|
| 126 |
# new_u2net.to(device)
|
|
@@ -455,13 +455,10 @@ def predict(
|
|
| 455 |
except Exception:
|
| 456 |
raise ValueError("Invalid base64 image data")
|
| 457 |
|
| 458 |
-
#
|
| 459 |
if isinstance(image, np.ndarray):
|
| 460 |
pil_image = Image.fromarray(image)
|
| 461 |
-
|
| 462 |
-
dark_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
|
| 463 |
-
# Enhance sharpness.
|
| 464 |
-
enhanced_image = ImageEnhance.Sharpness(dark_image).enhance(1)
|
| 465 |
image = np.array(enhanced_image)
|
| 466 |
|
| 467 |
# ---------------------
|
|
@@ -682,13 +679,6 @@ def predict(
|
|
| 682 |
cv2.LINE_AA
|
| 683 |
)
|
| 684 |
|
| 685 |
-
# Restore brightness for display purposes:
|
| 686 |
-
# Since we reduced brightness by 0.5 during preprocessing,
|
| 687 |
-
# we apply an enhancement factor of 2.0 here to bring it back.
|
| 688 |
-
display_img = Image.fromarray(output_img)
|
| 689 |
-
display_img = ImageEnhance.Brightness(display_img).enhance(2.0)
|
| 690 |
-
output_img = np.array(display_img)
|
| 691 |
-
|
| 692 |
outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
|
| 693 |
print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
|
| 694 |
|
|
|
|
| 119 |
new_birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 120 |
"zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
|
| 121 |
)
|
| 122 |
+
new_birefnet.to(device)
|
| 123 |
+
new_birefnet.eval()
|
| 124 |
# new_u2net = U2NETP(3, 1)
|
| 125 |
# new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
|
| 126 |
# new_u2net.to(device)
|
|
|
|
| 455 |
except Exception:
|
| 456 |
raise ValueError("Invalid base64 image data")
|
| 457 |
|
| 458 |
+
# Apply brightness and sharpness enhancement.
|
| 459 |
if isinstance(image, np.ndarray):
|
| 460 |
pil_image = Image.fromarray(image)
|
| 461 |
+
enhanced_image = ImageEnhance.Sharpness(Bright).enhance(0.5)
|
|
|
|
|
|
|
|
|
|
| 462 |
image = np.array(enhanced_image)
|
| 463 |
|
| 464 |
# ---------------------
|
|
|
|
| 679 |
cv2.LINE_AA
|
| 680 |
)
|
| 681 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
|
| 683 |
print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
|
| 684 |
|