Annanya2306 commited on
Commit
4589763
·
verified ·
1 Parent(s): 34ce190

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — Streamlit + WebRTC (Hugging Face Spaces ready)
2
+ import io, numpy as np, torch, torchvision.transforms as T
3
+ from torchvision import models
4
+ from PIL import Image
5
+ import streamlit as st
6
+ import mediapipe as mp
7
+ import cv2
8
+
9
+ from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
10
+ import av # needs ffmpeg + pkg-config via packages.txt
11
+
12
+ st.set_page_config(page_title="Mask Detection (Webcam)", layout="wide")
13
+ st.title("😷 Face Mask Detection — Webcam + Image (HF Spaces)")
14
+
15
+ LABELS = ["mask", "no_mask"]
16
+ IMG_SIZE = 224
17
+ MEAN = [0.485, 0.456, 0.406]
18
+ STD = [0.229, 0.224, 0.225]
19
+
20
+ @st.cache_resource
21
+ def load_model(weights_path: str):
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model = models.mobilenet_v2(weights=None)
24
+ model.classifier[1] = torch.nn.Linear(model.last_channel, len(LABELS))
25
+ state = torch.load(weights_path, map_location="cpu")
26
+ model.load_state_dict(state, strict=True)
27
+ model = model.to(device).eval()
28
+ return model, device
29
+
30
+ @st.cache_resource
31
+ def get_tf():
32
+ return T.Compose([
33
+ T.Resize((IMG_SIZE, IMG_SIZE)),
34
+ T.ToTensor(),
35
+ T.Normalize(MEAN, STD),
36
+ ])
37
+
38
+ def predict_pil(pil_img, model, device):
39
+ x = get_tf()(pil_img.convert("RGB")).unsqueeze(0).to(device)
40
+ with torch.no_grad():
41
+ probs = torch.softmax(model(x), dim=1)[0].cpu().numpy()
42
+ i = int(np.argmax(probs))
43
+ return LABELS[i], float(probs[i]), probs
44
+
45
+ mp_fd = mp.solutions.face_detection
46
+ @st.cache_resource
47
+ def get_detector():
48
+ return mp_fd.FaceDetection(model_selection=0, min_detection_confidence=0.5)
49
+
50
+ def expand_box(x, y, w, h, scale, W, H):
51
+ cx, cy = x + w/2, y + h/2
52
+ nw, nh = w*scale, h*scale
53
+ x1 = int(max(0, cx - nw/2)); y1 = int(max(0, cy - nh/2))
54
+ x2 = int(min(W, cx + nw/2)); y2 = int(min(H, cy + nh/2))
55
+ return x1, y1, x2, y2
56
+
57
+ def annotate_bgr(img_bgr, model, device, conf_thresh=0.6, per_face=True):
58
+ H, W = img_bgr.shape[:2]
59
+ out = img_bgr.copy()
60
+ results = []
61
+ if not per_face:
62
+ label, conf, _ = predict_pil(Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)), model, device)
63
+ color = (0,200,0) if label=="mask" else (0,0,255)
64
+ cv2.putText(out, f"{label.upper()} {conf:.2f}", (20,60),
65
+ cv2.FONT_HERSHEY_SIMPLEX, 1.1, color, 3, cv2.LINE_AA)
66
+ results.append({"bbox":[0,0,W,H],"label":label,"conf":conf})
67
+ return out, results
68
+
69
+ detector = get_detector()
70
+ rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
71
+ det = detector.process(rgb)
72
+ if not det.detections:
73
+ return out, results
74
+
75
+ for d in det.detections:
76
+ bb = d.location_data.relative_bounding_box
77
+ x, y, w, h = int(bb.xmin*W), int(bb.ymin*H), int(bb.width*W), int(bb.height*H)
78
+ x1, y1, x2, y2 = expand_box(x, y, w, h, 1.25, W, H)
79
+ crop = img_bgr[max(0,y1):min(H,y2), max(0,x1):min(W,x2)]
80
+ if crop.size == 0:
81
+ continue
82
+ label, conf, _ = predict_pil(Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)), model, device)
83
+ if conf < conf_thresh:
84
+ continue
85
+ color = (0,200,0) if label=="mask" else (0,0,255)
86
+ cv2.rectangle(out, (x1,y1), (x2,y2), color, 2)
87
+ cv2.putText(out, f"{label.upper()} {conf:.2f}", (x1, max(20,y1-8)),
88
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)
89
+ results.append({"bbox":[x1,y1,x2,y2], "label":label, "conf":conf})
90
+ return out, results
91
+
92
+ def bgr_to_png_bytes(img_bgr):
93
+ pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
94
+ buf = io.BytesIO(); pil.save(buf, format="PNG"); buf.seek(0); return buf
95
+
96
+ # Sidebar
97
+ st.sidebar.header("Settings")
98
+ weights_path = st.sidebar.text_input("Model weights (.pt)", value="mask_cls_best.pt")
99
+ conf_thresh = st.sidebar.slider("Confidence threshold", 0.10, 0.99, 0.60, 0.01)
100
+ per_face = st.sidebar.toggle("Per-face boxes (MediaPipe)", value=True)
101
+
102
+ # Load model
103
+ try:
104
+ model, device = load_model(weights_path)
105
+ st.sidebar.success(f"Loaded on {'GPU' if device=='cuda' else 'CPU'}")
106
+ except Exception as e:
107
+ st.sidebar.error(f"Failed to load weights: {e}")
108
+ st.stop()
109
+
110
+ tab1, tab2 = st.tabs(["📷 Image", "🎥 Webcam"])
111
+
112
+ # Image tab
113
+ with tab1:
114
+ st.subheader("Image Inference")
115
+ file = st.file_uploader("Upload an image", type=["jpg","jpeg","png"])
116
+ if file:
117
+ pil = Image.open(file).convert("RGB")
118
+ bgr = cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR)
119
+ out, dets = annotate_bgr(bgr, model, device, conf_thresh=conf_thresh, per_face=per_face)
120
+ st.image(out, caption="Detections", use_container_width=True)
121
+ st.download_button("⬇️ Download annotated image", data=bgr_to_png_bytes(out),
122
+ file_name="mask_detection.png", mime="image/png")
123
+
124
+ # Webcam tab (browser camera)
125
+ class FaceMaskTransformer(VideoTransformerBase):
126
+ def __init__(self):
127
+ self.model, self.device = model, device
128
+ def recv(self, frame):
129
+ img_bgr = frame.to_ndarray(format="bgr24")
130
+ out, _ = annotate_bgr(img_bgr, self.model, self.device,
131
+ conf_thresh=conf_thresh, per_face=per_face)
132
+ return av.VideoFrame.from_ndarray(out, format="bgr24")
133
+
134
+ with tab2:
135
+ st.subheader("Webcam (browser)")
136
+ st.info("Allow camera access in your browser. If video doesn't appear, open the Space over HTTPS and try Chrome.")
137
+ webrtc_streamer(
138
+ key="mask-webrtc",
139
+ video_transformer_factory=FaceMaskTransformer,
140
+ media_stream_constraints={"video": True, "audio": False},
141
+ async_processing=True,
142
+ )