jahidularafat commited on
Commit
6006853
·
verified ·
1 Parent(s): b25451e

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. app.py +13 -14
  2. model.onnx +3 -0
  3. requirements.txt +5 -1
  4. runtime.txt +1 -0
  5. scripts/gradio_app.py +77 -0
app.py CHANGED
@@ -1,14 +1,13 @@
1
- import gradio as gr
2
-
3
- def greet(name):
4
- return f"Hello, {name}! 👋"
5
-
6
- with gr.Blocks() as demo:
7
- gr.Markdown("# cspws-space")
8
- name = gr.Textbox(label="Your name")
9
- out = gr.Textbox(label="Greeting")
10
- btn = gr.Button("Say hi")
11
- btn.click(fn=greet, inputs=name, outputs=out)
12
-
13
- if __name__ == "__main__":
14
- demo.launch()
 
"""Entry point for the Hugging Face Space.

Prefers the full ONNX segmentation app in scripts/gradio_app.py and runs it
in a child process; falls back to a minimal echo-image demo when the script
is missing.
"""
import subprocess
import sys
from pathlib import Path

if (Path("scripts") / "gradio_app.py").exists():
    # NOTE: do NOT import gradio_app here — it imports torch at module level,
    # which is not installed in this Space (requirements.txt is ONNX-only),
    # so the import itself would crash the app.
    if __name__ == "__main__":
        # argv list + sys.executable instead of os.system("python ..."):
        # no shell involved, and the same interpreter is guaranteed to exist.
        # Exit status is deliberately ignored, matching the os.system call
        # this replaces.
        subprocess.run(
            [
                sys.executable,
                "scripts/gradio_app.py",
                "--onnx", "model.onnx",
                "--img-size", "256",
                "--port", "7860",
            ]
        )
else:
    import gradio as gr

    def hello(x):
        """Fallback handler: echo the uploaded image back unchanged."""
        return x

    demo = gr.Interface(
        fn=hello,
        inputs="image",
        outputs="image",
        title="csPWS Fallback App",
    )
    if __name__ == "__main__":
        demo.launch(server_port=7860)
 
model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abe3d43a2d19a4b8eb46d7cbf591c63541b8085c00437cff30cde8797bcbc1d5
3
+ size 39925272
requirements.txt CHANGED
@@ -1 +1,5 @@
1
- gradio
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ onnxruntime>=1.18.0
3
+ onnx>=1.15.0
4
+ numpy
5
+ pillow
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.12
scripts/gradio_app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import argparse
import os

import numpy as np
from PIL import Image
import onnxruntime as ort

# torch/torchvision are only required for the PyTorch-checkpoint path.
# They are deliberately optional: ONNX-only deployments (e.g. the HF Space,
# whose requirements.txt installs no torch) must not crash at import time.
try:
    import torch
    import torchvision.transforms.functional as F
except ImportError:  # torch absent → only the ONNX path is usable
    torch = None
    F = None
6
+
def pick_providers():
    """Choose ONNX Runtime execution providers for this machine.

    Prefers CoreML (with a CPU fallback) when onnxruntime exposes it;
    otherwise returns CPU only.
    """
    if "CoreMLExecutionProvider" in ort.get_available_providers():
        return ["CoreMLExecutionProvider", "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
11
+
def load_torch_model(ckpt_path, channels=3, img_size=256):
    """Build a UNetAttn model and load weights from *ckpt_path*.

    Returns (model, device, img_size) with the model in eval mode on the
    best available device (CUDA, then Apple MPS, then CPU).
    """
    from cspws_seg.models.unet_attn import UNetAttn

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    net = UNetAttn(in_ch=channels, out_ch=1).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    # Checkpoints may wrap weights under a "model" key or be a bare state dict.
    net.load_state_dict(ckpt.get("model", ckpt))
    net.eval()
    return net, device, img_size
19
+
def predict_torch(model, device, img: Image.Image, img_size=256):
    """Segment a PIL image with the PyTorch model.

    Resizes to img_size×img_size, runs the model, thresholds the sigmoid
    of the logits at 0.5, and returns a binary 0/255 mask as a PIL image.
    """
    rgb = F.resize(img.convert("RGB"), [img_size, img_size], antialias=True)
    batch = F.to_tensor(rgb).unsqueeze(0).to(device)
    with torch.inference_mode():
        logits = model(batch)[0, 0].detach().cpu().numpy()
    prob = 1.0 / (1.0 + np.exp(-logits))
    mask = np.where(prob > 0.5, 255, 0).astype(np.uint8)
    return Image.fromarray(mask)
28
+
def load_onnx_session(onnx_path):
    """Open an ONNX Runtime session for *onnx_path*.

    Returns (session, input_name) where input_name is the name of the
    model's first input tensor.
    """
    sess = ort.InferenceSession(onnx_path, providers=pick_providers())
    first_input = sess.get_inputs()[0]
    return sess, first_input.name
34
+
def predict_onnx(sess, input_name, img: Image.Image, img_size=256):
    """Segment a PIL image with the ONNX session.

    Resizes to img_size×img_size, feeds a normalized NCHW float32 batch of
    one, thresholds the sigmoid of the logits at 0.5, and returns a binary
    0/255 mask as a PIL image.
    """
    rgb = img.convert("RGB").resize((img_size, img_size))
    arr = np.asarray(rgb, dtype=np.float32) / 255.0   # HWC, values in [0, 1]
    x = arr.transpose(2, 0, 1)[np.newaxis, ...]       # -> NCHW batch of one
    logits = sess.run(None, {input_name: x})[0][0, 0]
    prob = 1.0 / (1.0 + np.exp(-logits))
    mask = np.where(prob > 0.5, 255, 0).astype(np.uint8)
    return Image.fromarray(mask)
41
+
def main():
    """CLI entry point: build a Gradio demo around whichever weights exist.

    Prefers an ONNX model (--onnx) and falls back to a PyTorch checkpoint
    (--ckpt). Raises FileNotFoundError when neither file is present.
    """
    import gradio as gr

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="runs/cspws_unet_attn/best.pt")
    ap.add_argument("--onnx", default=None)
    # --channels was missing although app.py passes "--channels 3", which made
    # argparse abort with "unrecognized arguments". It is forwarded to
    # load_torch_model (previously hard-coded to 3; 3 stays the default).
    ap.add_argument("--channels", type=int, default=3)
    ap.add_argument("--img-size", type=int, default=256)
    ap.add_argument("--port", type=int, default=7860)
    ap.add_argument("--share", action="store_true")
    args = ap.parse_args()

    if args.onnx and os.path.exists(args.onnx):
        sess, in_name = load_onnx_session(args.onnx)

        def _fn(img):
            return predict_onnx(sess, in_name, img, args.img_size)

        title = "csPWS Segmentation — ONNX (CoreML/CPU auto)"
    elif os.path.exists(args.ckpt):
        model, device, img_size = load_torch_model(
            args.ckpt, args.channels, args.img_size
        )

        def _fn(img):
            return predict_torch(model, device, img, img_size)

        title = "csPWS Segmentation — PyTorch"
    else:
        raise FileNotFoundError(
            "No weights found. Expected either:\n"
            f" - PyTorch ckpt at {args.ckpt}\n"
            f" - OR ONNX at {args.onnx or 'exports/model.onnx'}\n"
        )

    demo = gr.Interface(
        fn=_fn,
        inputs=gr.Image(type="pil", label="Upload csPWS composite (Σ/BF/SW)"),
        outputs=gr.Image(type="pil", label="Predicted Mask"),
        title=title,
        # NOTE(review): allow_flagging is deprecated in gradio 4.x and removed
        # in 5.x (replaced by flagging_mode); requirements.txt allows >=4.0.0,
        # so either pin gradio<5 or migrate this kwarg.
        allow_flagging="never",
    )
    demo.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()