# Author: Dhanushlevi
# Last commit: 5451a13 "Update app.py" (verified)
import cv2
import numpy as np
import torch
import torch.nn as nn
import streamlit as st
from collections import deque
from keras.models import load_model
from ultralytics import YOLO
import csv
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import timm
import base64
import pandas as pd
# Page metadata. Kept as the first Streamlit call because older Streamlit
# versions require set_page_config to precede every other st.* command.
st.set_page_config(page_title="Violence Detection and Analysis App", page_icon=":boom:")

# Stylesheet for the `.title`, `.header` and `.btn-download` classes used by
# the raw-HTML snippets the app renders with unsafe_allow_html=True.
_APP_CSS = """
<style>
.title {
color: #1f78b4;
text-align: center;
font-size: 36px;
margin-bottom: 30px;
}
.header {
color: #1f78b4;
font-size: 24px;
margin-top: 30px;
}
.btn-download {
background-color: #4CAF50;
border: none;
color: white;
padding: 15px 32px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 10px;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Model loading. Streamlit re-executes this whole script on every widget
# interaction, so the expensive loading is wrapped in st.cache_resource and
# runs once per server process instead of on every rerun.
MODEL = "yolov8_people.pt"


def _build_mobilevit_classifier(num_classes):
    """Build a MobileViTv2 network with the custom ``fc`` stack used by the checkpoints.

    The layout must match the saved state dicts exactly: ``load_state_dict`` is
    strict by default and rejects missing or unexpected keys.

    Args:
        num_classes: Number of output classes for both the timm head and the
            final Linear layer of the custom ``fc`` stack.

    Returns:
        The freshly constructed (untrained) ``torch.nn.Module``.
    """
    # pretrained=False: every parameter is overwritten by the strict
    # load_state_dict afterwards, so downloading ImageNet weights is wasted work.
    model = timm.create_model('mobilevitv2_075.cvnets_in1k', pretrained=False,
                              num_classes=num_classes, global_pool='catavgmax')
    in_features = model.get_classifier().in_features
    model.fc = nn.Sequential(
        nn.BatchNorm1d(in_features),
        nn.Linear(in_features=in_features, out_features=512, bias=False),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.4),
        nn.Linear(in_features=512, out_features=num_classes, bias=False),
    )
    return model


@st.cache_resource
def _load_models():
    """Load all models used by the app.

    Returns:
        (age_model, gender_model, violence_model, yolo_model)
    """
    cpu = torch.device('cpu')
    age = _build_mobilevit_classifier(num_classes=5)
    age.load_state_dict(torch.load('agetest_mobilevit_V_1.1.pt', map_location=cpu))
    gender = _build_mobilevit_classifier(num_classes=2)
    gender.load_state_dict(torch.load('gendertest_mobilevit_V_1.1.pt', map_location=cpu))
    # Inference only: use BatchNorm running statistics and disable Dropout.
    # In training mode predictions would be random and every forward pass would
    # overwrite the BatchNorm running statistics.
    age.eval()
    gender.eval()
    violence = load_model('modelnew (1).h5')
    yolo = YOLO(MODEL)
    return age, gender, violence, yolo


age_model, gender_model, violence_model, yolo_model = _load_models()
# Function to preprocess image for detection
def preprocess_image(image, target_size=(256, 256)):
    """Turn an HWC 3-channel image into a normalised (1, C, H, W) float32 tensor.

    Steps: bicubic resize to ``target_size`` (width, height), scale to [0, 1],
    normalise per channel with the standard ImageNet mean/std, move channels
    first and add a leading batch axis.

    NOTE(review): the mean/std values are the usual ImageNet RGB-order
    statistics, while frames read with OpenCV are BGR — confirm the age/gender
    checkpoints were trained on the same channel order they receive here.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    resized = cv2.resize(image, target_size, interpolation=cv2.INTER_CUBIC)
    unit_range = resized.astype(np.float32) / 255.0
    normalised = (unit_range - channel_mean) / channel_std
    channels_first = normalised.transpose(2, 0, 1)
    batched = channels_first[np.newaxis, ...]
    return torch.tensor(batched, dtype=torch.float32)
# Function to detect violence in a frame
def detect_violence(frame, history=None):
    """Classify a frame as violent / non-violent and draw the verdict on a copy.

    Args:
        frame: BGR image (as returned by ``cv2.VideoCapture.read``). It is not
            modified.
        history: Optional ``collections.deque`` of previous per-frame
            predictions. When given, the new prediction is appended and the
            verdict uses the mean over the whole buffer, so passing the same
            deque for every frame of a video smooths the result over time.
            When omitted, a fresh buffer is used and the verdict depends on this
            frame alone (the previous behaviour; the old per-call deque never
            held more than one entry).

    Returns:
        (annotated_frame, label): a copy of ``frame`` with the verdict text
        drawn on it, and ``True`` when the averaged score exceeds 0.5.
    """
    if history is None:
        history = deque(maxlen=128)

    annotated = frame.copy()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    model_input = cv2.resize(rgb, (128, 128)).astype("float32") / 255
    # verbose=0: otherwise Keras prints a progress bar for every single frame.
    preds = violence_model.predict(np.expand_dims(model_input, axis=0), verbose=0)[0]
    history.append(preds)
    label = bool((np.array(history).mean(axis=0) > 0.50)[0])

    # OpenCV colours are (B, G, R): red for violence, green otherwise.
    text_color = (0, 0, 255) if label else (0, 255, 0)
    text = "Violence Detected" if label else "No Violence Detected"
    cv2.putText(annotated, text, (35, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.25, text_color, 3)
    return annotated, label
# Function to detect people, their age, and gender in a frame
def detect_people_age_gender(frame):
    """Detect people in a frame and label each one with a predicted age band and gender.

    Args:
        frame: HWC image (BGR when read with OpenCV). Rectangles and labels are
            drawn onto it in place.

    Returns:
        (frame, detections, male_count, female_count):
        ``frame`` is the same, now annotated, array that was passed in.
        ``detections`` is a list of ``(x1, y1, x2, y2, age, gender)`` tuples.
        The counts tally the gender predictions. Boxes whose clamped crop is
        empty are skipped.
    """
    age_class_names = ['0-15', '15-30', '30-45', '45-60', '60+']
    # Crop from an unannotated copy so rectangles/labels drawn for earlier
    # people do not leak into the crops of overlapping people processed later.
    clean = frame.copy()
    height, width = frame.shape[:2]

    detections = []
    male_count = 0
    female_count = 0
    results = yolo_model(frame)
    # Inference only: no autograd graph is needed for the crops.
    with torch.no_grad():
        for result in results:
            for box in result.boxes.xyxy:
                x1, y1, x2, y2 = map(int, box)
                # Clamp to the image: a negative start index would wrap around in
                # NumPy slicing and yield an empty or wrong crop.
                person_crop = clean[max(y1, 0):min(y2, height), max(x1, 0):min(x2, width)]
                if person_crop.size == 0:
                    # Degenerate box; cv2.resize raises on an empty image.
                    continue

                crop_tensor = preprocess_image(person_crop, target_size=(224, 224))
                age_index = int(np.argmax(age_model(crop_tensor).numpy()))
                age = age_class_names[age_index]
                gender_index = int(np.argmax(gender_model(crop_tensor).numpy()))
                gender = "Male" if gender_index == 0 else "Female"

                detections.append((x1, y1, x2, y2, age, gender))
                if gender == 'Male':
                    male_count += 1
                else:
                    female_count += 1

                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, f"Age: {age}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                cv2.putText(frame, f"Gender: {gender}", (x1, y1 - 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    return frame, detections, male_count, female_count
# AES-CBC file encryption helpers. Files are stored as IV (16 bytes) || ciphertext.
# NOTE(review): CBC without a MAC gives no integrity protection; tampering with
# the ciphertext is not detected. Consider an authenticated mode (e.g. GCM).
_AES_BLOCK_SIZE = 16


def pkcs7_pad(data, block_size=_AES_BLOCK_SIZE):
    """Return ``data`` with PKCS#7 padding appended.

    Between 1 and ``block_size`` bytes are always added, even when ``data`` is
    already block-aligned, so the padding can be removed unambiguously.
    """
    padding_length = block_size - len(data) % block_size
    return data + bytes([padding_length]) * padding_length


def pkcs7_unpad(data, block_size=_AES_BLOCK_SIZE):
    """Strip PKCS#7 padding from ``data``.

    Raises:
        ValueError: if the length is not a positive multiple of ``block_size``
            or the trailing bytes are not valid PKCS#7 padding.
    """
    if not data or len(data) % block_size:
        raise ValueError("Padded data length is not a positive multiple of the block size")
    padding_length = data[-1]
    if not 1 <= padding_length <= block_size:
        raise ValueError("Invalid PKCS#7 padding length")
    if data[-padding_length:] != bytes([padding_length]) * padding_length:
        raise ValueError("Invalid PKCS#7 padding bytes")
    return data[:-padding_length]


def encrypt_file(input_file, output_file, key):
    """Encrypt ``input_file`` with AES-CBC and write ``IV || ciphertext`` to ``output_file``.

    Args:
        input_file: Path of the plaintext file to read.
        output_file: Path the IV followed by the ciphertext is written to.
        key: AES key (16, 24 or 32 bytes).
    """
    cipher = AES.new(key, AES.MODE_CBC)  # PyCryptodome generates a random IV
    with open(input_file, 'rb') as infile:
        data = infile.read()
    # Always pad, even when block-aligned: decrypt_file always strips padding,
    # so skipping it would silently truncate the plaintext on decryption.
    ciphertext = cipher.encrypt(pkcs7_pad(data))
    with open(output_file, 'wb') as outfile:
        outfile.write(cipher.iv)
        outfile.write(ciphertext)


def decrypt_file(input_file, output_file, key):
    """Decrypt a file produced by ``encrypt_file`` and write the plaintext.

    Args:
        input_file: Path of the ``IV || ciphertext`` file.
        output_file: Path the recovered plaintext is written to.
        key: The AES key used for encryption.

    Raises:
        ValueError: if the padding is invalid (wrong key, corrupted data, or a
            block-aligned file written by the old encrypt_file, which omitted
            padding in that case).
    """
    with open(input_file, 'rb') as infile:
        iv = infile.read(_AES_BLOCK_SIZE)
        ciphertext = infile.read()
    cipher = AES.new(key, AES.MODE_CBC, iv)
    plaintext = pkcs7_unpad(cipher.decrypt(ciphertext))
    with open(output_file, 'wb') as outfile:
        outfile.write(plaintext)
# Streamlit UI
def _annotate_video(video_path, output_path, csv_path):
    """Run both detectors on every frame; write the annotated video and a per-frame CSV.

    Returns:
        False if ``video_path`` cannot be opened, True otherwise.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return False

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the source frame rate so the output plays at the original speed.
    # Some containers report 0 or NaN; fall back to 30 fps then.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not 0 < fps < 1000:
        fps = 30
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
    try:
        with open(csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Frame', 'Violence', 'People Count', 'Male Count', 'Female Count'])
            frame_number = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_violence, violence_label = detect_violence(frame)
                frame_result, detections, male_count, female_count = detect_people_age_gender(frame_violence)
                writer.writerow([frame_number, violence_label, len(detections), male_count, female_count])
                out.write(frame_result)
                frame_number += 1
    finally:
        # Release capture/writer even if a detector raises mid-video.
        cap.release()
        out.release()
    return True


def main():
    """Streamlit entry point: upload an MP4, annotate it frame by frame, and offer the results.

    The annotated video is offered as a download. The per-frame statistics CSV
    is encrypted and offered as a download. It is then decrypted again and
    shown as a table.
    """
    import os
    import tempfile

    st.markdown('<h1 class="title">Violence Detection and Analysis App</h1>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload Video", type=["mp4"])
    if uploaded_file is None:
        return

    # Private scratch directory per run: fixed names in the working directory
    # would be overwritten by other sessions processing a video concurrently.
    with tempfile.TemporaryDirectory() as work_dir:
        video_path = os.path.join(work_dir, "temp_video.mp4")
        output_path = os.path.join(work_dir, "temp_output_video.mp4")
        csv_filename = os.path.join(work_dir, "temp_detection_results.csv")
        encrypted_file = os.path.join(work_dir, "temp_encrypted.csv")
        decrypted_file = os.path.join(work_dir, "temp_decrypted.csv")

        video_data = uploaded_file.getvalue()
        with open(video_path, "wb") as f:
            f.write(video_data)
        st.video(video_data)

        if not _annotate_video(video_path, output_path, csv_filename):
            st.error("Could not open the uploaded video.")
            return

        # SECURITY(review): a key hard-coded in source gives no real
        # confidentiality; anyone with this file can decrypt the download.
        # Load it from a secret store or environment variable instead.
        key = b'ThisIsASecretKey'
        encrypt_file(csv_filename, encrypted_file, key)
        decrypt_file(encrypted_file, decrypted_file, key)

        # Download processed video
        with open(output_path, "rb") as f:
            video_b64 = base64.b64encode(f.read()).decode('utf-8')
        href = f'<a class="btn-download" href="data:video/mp4;base64,{video_b64}" download="processed_video.mp4">Download Processed Video</a>'
        st.markdown(href, unsafe_allow_html=True)

        # Download encrypted CSV
        with open(encrypted_file, "rb") as f:
            encrypted_data = f.read()
        href = f'<a class="btn-download" href="data:file/csv;base64,{base64.b64encode(encrypted_data).decode()}" download="encrypted_csv.csv">Download Encrypted CSV</a>'
        st.markdown(href, unsafe_allow_html=True)

        # Show decrypted CSV
        decrypted_data = pd.read_csv(decrypted_file)

    st.markdown('<h2 class="header">Decrypted CSV:</h2>', unsafe_allow_html=True)
    st.dataframe(decrypted_data)


if __name__ == "__main__":
    main()