123ahmed committed on
Commit
80adaa7
·
verified ·
1 Parent(s): 1b6b488

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. alert.wav +3 -0
  3. app.py +164 -0
  4. best_vit_lstm.pt +3 -0
  5. requirements.txt +8 -3
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ alert.wav filter=lfs diff=lfs merge=lfs -text
alert.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75f14a2044af42630de43fea45ed720988fec1345eb7ef688a413eaf24db5a7b
3
+ size 3173020
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torch import nn
5
+ import timm
6
+ import cv2
7
+ import numpy as np
8
+ from playsound import playsound
9
+ import threading
10
+ import tempfile
11
+
12
# ================================
# Page setup
# ================================
st.set_page_config(page_title="Violence Detection System", layout="wide")

# Centered HTML banner: title, one-paragraph description, and a divider.
# It is raw HTML, so it has to be rendered with unsafe_allow_html=True.
_HEADER_HTML = """
<h1 style='text-align:center; color:#d32f2f;'>Violence Detection System</h1>
<p style='text-align:center; font-size:18px; color:#444;'>
Real-time violence detection using <b>ViT + LSTM</b> architecture deployed on HuggingFace Spaces.<br>
Supports camera input & uploaded videos.
</p>
<hr>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
28
+
29
# ================================
# Model loading
# ================================
# Files expected next to this script.
MODEL_PATH = "best_vit_lstm.pt"
ALERT_SOUND = "alert.wav"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
36
+
37
class ViT_LSTM_Classifier(nn.Module):
    """Clip classifier: per-frame ViT features fed to a bidirectional LSTM.

    Each frame of a clip is encoded independently by the ViT backbone, the
    resulting feature sequence is run through the LSTM, and the LSTM output
    at the last time step is mapped to class logits.

    Args:
        vit_name: timm model name used for the frame encoder.
        lstm_hidden: Hidden size of each LSTM direction.
        lstm_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        dropout: Dropout probability inside the classification head.
    """

    def __init__(self, vit_name="vit_tiny_patch16_224", lstm_hidden=256,
                 lstm_layers=1, num_classes=2, dropout=0.3):
        super().__init__()
        # num_classes=0 requests the backbone without a classification head,
        # so it is used purely as a per-frame feature extractor.
        self.vit = timm.create_model(vit_name, pretrained=False, num_classes=0)
        self.feat_dim = self.vit.num_features
        self.lstm = nn.LSTM(self.feat_dim, lstm_hidden, lstm_layers,
                            batch_first=True, bidirectional=True)
        # A bidirectional LSTM concatenates both directions, hence the
        # 2 * lstm_hidden input width of the head.
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of clips.

        Args:
            x: 5-D tensor laid out as (batch, time, channels, height, width).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized logits.
        """
        B, T, C, H, W = x.shape
        # Fold time into the batch axis so the ViT sees plain images.
        # reshape (rather than view) also accepts non-contiguous input, e.g.
        # a clip produced by permute()/transpose(), instead of raising.
        x = x.reshape(B * T, C, H, W)
        feats = self.vit(x)
        feats = feats.reshape(B, T, -1)
        out, _ = self.lstm(feats)
        # Summarize the clip with the LSTM output at the final time step.
        last = out[:, -1, :]
        return self.classifier(last)
60
+
61
@st.cache_resource
def _load_model():
    """Build the classifier, load its checkpoint, and return it in eval mode.

    Streamlit re-executes the whole script on every user interaction; caching
    the loader keeps the checkpoint from being re-read and the network from
    being rebuilt on each rerun.
    """
    import logging

    net = ViT_LSTM_Classifier().to(device)
    state = torch.load(MODEL_PATH, map_location=device)
    # strict=False tolerates key mismatches, but a mismatch means part of the
    # network keeps its random initialization, so make it visible in the logs.
    result = net.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        logging.getLogger(__name__).warning(
            "Checkpoint %s does not fully match the model "
            "(missing keys: %s, unexpected keys: %s)",
            MODEL_PATH, result.missing_keys, result.unexpected_keys,
        )
    net.eval()
    return net


model = _load_model()
65
+
66
# ================================
# Transforms
# ================================
# Per-frame preprocessing applied to the RGB arrays read with OpenCV:
# array -> PIL image -> 224x224 -> float tensor -> normalized with
# mean 0.5 / std 0.5 per channel.
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)
75
+
76
def play_alert():
    """Play the alert sound on a best-effort basis.

    Playback failures (missing audio device, missing file, backend errors)
    must never break the app, so they are logged and otherwise ignored.
    Returns None in all cases.
    """
    import logging

    try:
        playsound(ALERT_SOUND)
    except Exception:
        # Narrower than a bare except: SystemExit / KeyboardInterrupt still
        # propagate, and the failure leaves a trace instead of vanishing.
        logging.getLogger(__name__).warning(
            "Could not play alert sound", exc_info=True
        )
81
+
82
def predict_frames(frames, seq_len=8):
    """Classify the most recent ``seq_len`` frames of a clip.

    Args:
        frames: Sequence of preprocessed frame tensors of identical shape
            (they are stacked into a single clip tensor).
        seq_len: Number of trailing frames fed to the model. Defaults to 8,
            the value previously hard-coded here.

    Returns:
        The index of the highest-scoring class as an int. Returns 0 without
        running the model when fewer than ``seq_len`` frames are supplied.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")
    if len(frames) < seq_len:
        return 0

    # Only the tail of the sequence is classified; add a batch axis of 1.
    clip = torch.stack(frames[-seq_len:]).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(clip)
    return int(torch.argmax(logits, dim=1).item())
94
+
95
# ================================
# User interface
# ================================
# Frames per prediction. predict_frames returns 0 without running the model
# when it receives fewer frames than its default seq_len (8), so callers must
# supply at least this many frames to get a real prediction.
_CLIP_LEN = 8


def _read_last_frames(video_bytes, max_frames=_CLIP_LEN):
    """Decode a video and return its last ``max_frames`` frames as RGB arrays.

    Only the trailing frames are kept while decoding, so memory use does not
    grow with the length of the video. The temporary file needed by OpenCV is
    always removed, even if decoding fails.
    """
    import os
    from collections import deque

    # OpenCV opens videos by path, so the upload is spilled to disk first.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(video_bytes)
        video_path = tmp.name

    window = deque(maxlen=max_frames)
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                window.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()
    finally:
        os.unlink(video_path)
    return list(window)


def _report(pred, violent_msg, normal_msg):
    """Show the verdict for ``pred`` and trigger the alert sound if violent."""
    if pred == 1:
        st.error(violent_msg)
        # Playback blocks, so run it off the script thread.
        threading.Thread(target=play_alert, daemon=True).start()
    else:
        st.success(normal_msg)


st.sidebar.header("Mode Selection")
mode = st.sidebar.radio("Choose Input Mode", ["Open Camera", "Upload Video"])

if mode == "Open Camera":
    # st.camera_input delivers a single still image, not a video, so the
    # label says so and the snapshot is decoded as an image.
    picture = st.camera_input("Take a snapshot with your camera")

    if picture:
        encoded = np.frombuffer(picture.getvalue(), dtype=np.uint8)
        bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)

        if bgr is None:
            st.warning("Could not decode the camera snapshot.")
        else:
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            st.image(rgb)

            # The model consumes a clip, so the still frame is repeated to
            # fill one; otherwise the prediction would always default to 0.
            tensor = transform(rgb)
            pred = predict_frames([tensor] * _CLIP_LEN)
            _report(pred, "⚠️ Violent Behavior Detected!", "✔️ Normal Activity")

elif mode == "Upload Video":
    uploaded = st.file_uploader("Upload MP4 Video", type=["mp4"])

    if uploaded:
        st.info("Processing video...")
        raw_frames = _read_last_frames(uploaded.getvalue())

        if len(raw_frames) < _CLIP_LEN:
            # Do not present "no prediction" as a non-violent verdict.
            st.warning(
                "Video is too short or could not be decoded; "
                f"at least {_CLIP_LEN} frames are required."
            )
        else:
            # Preprocess only the frames that are actually classified.
            frames = [transform(rgb) for rgb in raw_frames]
            pred = predict_frames(frames)
            _report(pred, "⚠️ Violence Detected!", "✔️ Non-Violent")
best_vit_lstm.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bafb8ce0f006cab77c6f24f8823d6ad755afb017d27ec23f085df491d21f5b31
3
+ size 26374863
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ timm
5
+ opencv-python-headless
6
+ numpy
7
+ playsound==1.2.2
8
+ Pillow