Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- app.py +13 -14
- model.onnx +3 -0
- requirements.txt +5 -1
- runtime.txt +1 -0
- scripts/gradio_app.py +77 -0
app.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
demo.launch()
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
if (Path("scripts") / "gradio_app.py").exists():
|
| 4 |
+
sys.path.append(str(Path("scripts").absolute()))
|
| 5 |
+
from gradio_app import main
|
| 6 |
+
if __name__ == "__main__":
|
| 7 |
+
os.system("python scripts/gradio_app.py --onnx model.onnx --channels 3 --img-size 256 --port 7860")
|
| 8 |
+
else:
|
| 9 |
+
import gradio as gr
|
| 10 |
+
def hello(x): return x
|
| 11 |
+
demo = gr.Interface(fn=hello, inputs="image", outputs="image", title="csPWS Fallback App")
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
demo.launch(server_port=7860)
|
|
|
model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:abe3d43a2d19a4b8eb46d7cbf591c63541b8085c00437cff30cde8797bcbc1d5
|
| 3 |
+
size 39925272
|
requirements.txt
CHANGED
|
@@ -1 +1,5 @@
|
|
| 1 |
-
gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
onnxruntime>=1.18.0
|
| 3 |
+
onnx>=1.15.0
|
| 4 |
+
numpy
|
| 5 |
+
pillow
|
runtime.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python-3.12
|
scripts/gradio_app.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse, os, numpy as np, torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torchvision.transforms.functional as F
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
|
| 7 |
+
def pick_providers():
|
| 8 |
+
avail = ort.get_available_providers()
|
| 9 |
+
return (["CoreMLExecutionProvider","CPUExecutionProvider"]
|
| 10 |
+
if "CoreMLExecutionProvider" in avail else ["CPUExecutionProvider"])
|
| 11 |
+
|
| 12 |
+
def load_torch_model(ckpt_path, channels=3, img_size=256):
|
| 13 |
+
from cspws_seg.models.unet_attn import UNetAttn
|
| 14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends,'mps') and torch.backends.mps.is_available()) else "cpu"))
|
| 15 |
+
model = UNetAttn(in_ch=channels, out_ch=1).to(device)
|
| 16 |
+
state = torch.load(ckpt_path, map_location=device)
|
| 17 |
+
model.load_state_dict(state.get("model", state)); model.eval()
|
| 18 |
+
return model, device, img_size
|
| 19 |
+
|
| 20 |
+
def predict_torch(model, device, img: Image.Image, img_size=256):
|
| 21 |
+
img = img.convert("RGB")
|
| 22 |
+
img = F.resize(img, [img_size, img_size], antialias=True)
|
| 23 |
+
t = F.to_tensor(img).unsqueeze(0).to(device)
|
| 24 |
+
with torch.inference_mode():
|
| 25 |
+
logits = model(t)[0,0].detach().cpu().numpy()
|
| 26 |
+
mask = (1/(1+np.exp(-logits)) > 0.5).astype(np.uint8)*255
|
| 27 |
+
return Image.fromarray(mask)
|
| 28 |
+
|
| 29 |
+
def load_onnx_session(onnx_path):
|
| 30 |
+
prov = pick_providers()
|
| 31 |
+
sess = ort.InferenceSession(onnx_path, providers=prov)
|
| 32 |
+
input_name = sess.get_inputs()[0].name
|
| 33 |
+
return sess, input_name
|
| 34 |
+
|
| 35 |
+
def predict_onnx(sess, input_name, img: Image.Image, img_size=256):
|
| 36 |
+
img = img.convert("RGB").resize((img_size,img_size))
|
| 37 |
+
x = np.transpose(np.array(img, dtype=np.float32)/255.0, (2,0,1))[None,...]
|
| 38 |
+
logits = sess.run(None, {input_name: x})[0][0,0]
|
| 39 |
+
mask = (1/(1+np.exp(-logits)) > 0.5).astype(np.uint8)*255
|
| 40 |
+
return Image.fromarray(mask)
|
| 41 |
+
|
| 42 |
+
def main():
|
| 43 |
+
import gradio as gr
|
| 44 |
+
ap = argparse.ArgumentParser()
|
| 45 |
+
ap.add_argument("--ckpt", default="runs/cspws_unet_attn/best.pt")
|
| 46 |
+
ap.add_argument("--onnx", default=None)
|
| 47 |
+
ap.add_argument("--img-size", type=int, default=256)
|
| 48 |
+
ap.add_argument("--port", type=int, default=7860)
|
| 49 |
+
ap.add_argument("--share", action="store_true")
|
| 50 |
+
args = ap.parse_args()
|
| 51 |
+
|
| 52 |
+
if args.onnx and os.path.exists(args.onnx):
|
| 53 |
+
sess, in_name = load_onnx_session(args.onnx)
|
| 54 |
+
def _fn(img): return predict_onnx(sess, in_name, img, args.img_size)
|
| 55 |
+
title = "csPWS Segmentation — ONNX (CoreML/CPU auto)"
|
| 56 |
+
elif os.path.exists(args.ckpt):
|
| 57 |
+
model, device, img_size = load_torch_model(args.ckpt, 3, args.img_size)
|
| 58 |
+
def _fn(img): return predict_torch(model, device, img, img_size)
|
| 59 |
+
title = "csPWS Segmentation — PyTorch"
|
| 60 |
+
else:
|
| 61 |
+
raise FileNotFoundError(
|
| 62 |
+
f"No weights found. Expected either:\n"
|
| 63 |
+
f" - PyTorch ckpt at {args.ckpt}\n"
|
| 64 |
+
f" - OR ONNX at {args.onnx or 'exports/model.onnx'}\n"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
demo = gr.Interface(
|
| 68 |
+
fn=_fn,
|
| 69 |
+
inputs=gr.Image(type="pil", label="Upload csPWS composite (Σ/BF/SW)"),
|
| 70 |
+
outputs=gr.Image(type="pil", label="Predicted Mask"),
|
| 71 |
+
title=title,
|
| 72 |
+
allow_flagging="never"
|
| 73 |
+
)
|
| 74 |
+
demo.launch(server_port=args.port, share=args.share)
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|