Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import streamlit as st | |
| from collections import deque | |
| from keras.models import load_model | |
| from ultralytics import YOLO | |
| import csv | |
| from Crypto.Cipher import AES | |
| from Crypto.Random import get_random_bytes | |
| import timm | |
| import base64 | |
| import pandas as pd | |
# Page metadata must be configured before any other Streamlit call.
st.set_page_config(page_title="Violence Detection and Analysis App", page_icon=":boom:")

# Stylesheet for the custom HTML emitted by the app:
#   .title        -> main <h1> heading
#   .header       -> section <h2> headings
#   .btn-download -> anchor tags rendered as download buttons
_APP_CSS = """
<style>
.title {
color: #1f78b4;
text-align: center;
font-size: 36px;
margin-bottom: 30px;
}
.header {
color: #1f78b4;
font-size: 24px;
margin-top: 30px;
}
.btn-download {
background-color: #4CAF50;
border: none;
color: white;
padding: 15px 32px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 10px;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Load the pre-trained models for age and gender prediction
def _build_mobilevit_classifier(num_classes, weights_path):
    """Build a MobileViTv2 classifier, attach the custom head, load weights, set eval mode.

    Args:
        num_classes: number of output classes of the backbone classifier.
        weights_path: path to a ``state_dict`` saved with exactly this layout.

    Returns:
        The model in evaluation mode, with weights mapped to the CPU.
    """
    # pretrained=False: the strict load_state_dict below overwrites every
    # parameter anyway, so downloading ImageNet weights first is wasted work.
    model = timm.create_model('mobilevitv2_075.cvnets_in1k', pretrained=False,
                              num_classes=num_classes, global_pool='catavgmax')
    in_features = model.get_classifier().in_features
    # NOTE(review): if get_classifier() for this architecture lives under a
    # different attribute than `.fc`, this assignment adds a module that the
    # forward pass never calls. It is kept unchanged because the checkpoints
    # are loaded with strict key matching against this exact module layout.
    model.fc = nn.Sequential(
        nn.BatchNorm1d(in_features),
        nn.Linear(in_features=in_features, out_features=512, bias=False),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.4),
        nn.Linear(in_features=512, out_features=num_classes, bias=False),
    )
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    # Inference only: make BatchNorm use its stored running statistics (and
    # stop updating them) and disable Dropout.
    model.eval()
    return model


age_model = _build_mobilevit_classifier(5, 'agetest_mobilevit_V_1.1.pt')
gender_model = _build_mobilevit_classifier(2, 'gendertest_mobilevit_V_1.1.pt')

# Load the violence detection model
violence_model = load_model('modelnew (1).h5')

# Person detector (YOLOv8 weights file).
MODEL = "yolov8_people.pt"
yolo_model = YOLO(MODEL)
# Per-channel normalisation constants (the widely used ImageNet statistics).
_NORM_MEAN = np.array([0.485, 0.456, 0.406])
_NORM_STD = np.array([0.229, 0.224, 0.225])


def preprocess_image(image, target_size=(256, 256)):
    """Turn an HxWx3 image into a normalised 1x3xHxW float32 tensor.

    Args:
        image: 3-channel uint8 image array (as produced by OpenCV).
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``torch.float32`` tensor with a leading batch dimension of 1.

    NOTE(review): frames come from OpenCV (BGR channel order), while the mean/std
    constants are listed in RGB order. Confirm this matches how the models were trained.
    """
    resized = cv2.resize(image, target_size, interpolation=cv2.INTER_CUBIC)
    scaled = resized.astype(np.float32) / 255.0
    normalised = (scaled - _NORM_MEAN) / _NORM_STD
    channels_first = normalised.transpose(2, 0, 1)  # HWC -> CHW
    batch = channels_first[np.newaxis, ...]          # add batch axis
    return torch.tensor(batch, dtype=torch.float32)
# Function to detect violence in a frame
def detect_violence(frame, history=None):
    """Classify one BGR frame as violent / non-violent and return an annotated copy.

    Args:
        frame: image as returned by ``cv2.VideoCapture.read`` (BGR). Not modified.
        history: optional ``collections.deque`` reused across calls. When given,
            this frame's prediction is appended and the decision uses the mean
            of all predictions it holds (temporal smoothing; its ``maxlen`` sets
            the window). When omitted, the decision uses this frame alone, which
            is what the previous per-call ``deque`` effectively did.

    Returns:
        ``(annotated_copy, label)``; ``label`` is ``True`` when the (smoothed)
        first model output exceeds 0.5.
    """
    annotated = frame.copy()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    model_input = cv2.resize(rgb, (128, 128)).astype("float32") / 255
    # verbose=0 suppresses Keras' per-call progress output (called once per frame).
    scores = violence_model.predict(np.expand_dims(model_input, axis=0), verbose=0)[0]
    if history is not None:
        history.append(scores)
        scores = np.array(history).mean(axis=0)
    label = bool(scores[0] > 0.50)
    # Colours are BGR because the frame comes from OpenCV: red for violence, green otherwise.
    text_color = (0, 0, 255) if label else (0, 255, 0)
    text = "Violence Detected" if label else "No Violence Detected"
    cv2.putText(annotated, text, (35, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.25, text_color, 3)
    return annotated, label
# Age buckets, indexed by the argmax of the age model's output.
_AGE_CLASS_NAMES = ['0-15', '15-30', '30-45', '45-60', '60+']


# Function to detect people, their age, and gender in a frame
def detect_people_age_gender(frame):
    """Detect people in *frame*, predict each one's age bucket and gender, and annotate.

    Boxes and labels are drawn onto *frame* in place.

    Args:
        frame: HxWx3 image array.

    Returns:
        ``(frame, detections, male_count, female_count)``. ``detections`` is a list
        of ``(x1, y1, x2, y2, age, gender)`` tuples. Boxes that collapse to an
        empty crop after clipping to the image are skipped.
    """
    results = yolo_model(frame, verbose=False)
    # Crop from an untouched copy: annotations are drawn onto `frame` as we go,
    # and would otherwise leak into the crops of later, overlapping people.
    clean = frame.copy()
    height, width = frame.shape[:2]
    detections = []
    male_count = 0
    female_count = 0
    with torch.no_grad():  # inference only, no autograd graph needed
        for result in results:
            for box in result.boxes.xyxy:
                x1, y1, x2, y2 = map(int, box)
                # Clip to the image: negative indices would wrap around in slicing.
                x1, y1 = max(x1, 0), max(y1, 0)
                x2, y2 = min(x2, width), min(y2, height)
                if x2 <= x1 or y2 <= y1:
                    continue  # empty crop: cv2.resize would raise
                model_input = preprocess_image(clean[y1:y2, x1:x2], target_size=(224, 224))
                age = _AGE_CLASS_NAMES[int(age_model(model_input).argmax())]
                gender = "Male" if int(gender_model(model_input).argmax()) == 0 else "Female"
                detections.append((x1, y1, x2, y2, age, gender))
                if gender == 'Male':
                    male_count += 1
                else:
                    female_count += 1
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, f"Age: {age}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                cv2.putText(frame, f"Gender: {gender}", (x1, y1 - 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    return frame, detections, male_count, female_count
def pkcs7_pad(data, block_size=16):
    """Return *data* with PKCS#7 padding appended.

    Always adds between 1 and *block_size* bytes, each equal to the pad length.
    Already-aligned input therefore gets a full extra block, so the padding can
    always be removed unambiguously.
    """
    pad_len = block_size - len(data) % block_size
    return data + bytes([pad_len]) * pad_len


# Function to encrypt a file
def encrypt_file(input_file, output_file, key):
    """Encrypt *input_file* with AES-CBC and write ``IV || ciphertext`` to *output_file*.

    Args:
        input_file: path of the plaintext file.
        output_file: path of the encrypted file to create/overwrite.
        key: AES key bytes (16, 24 or 32 bytes).

    NOTE(review): CBC without a MAC gives confidentiality only, not integrity.
    """
    with open(input_file, 'rb') as infile:
        data = infile.read()
    # No IV is passed, so the library chooses one; it is stored as the file prefix.
    cipher = AES.new(key, AES.MODE_CBC)
    ciphertext = cipher.encrypt(pkcs7_pad(data))
    with open(output_file, 'wb') as outfile:
        outfile.write(cipher.iv)
        outfile.write(ciphertext)
def pkcs7_unpad(data, block_size=16):
    """Strip PKCS#7 padding from *data* if it is present and well-formed.

    Data whose tail is not valid PKCS#7 padding (including empty data, or a
    zero-valued last byte) is returned unchanged instead of being truncated. This
    keeps unpadded block-aligned plaintexts readable. Caveat: such data that
    happens to end in a valid-looking pad sequence (e.g. a final b"\\x01") is
    still trimmed.
    """
    if not data:
        return data
    pad_len = data[-1]
    if 1 <= pad_len <= min(block_size, len(data)) and data[-pad_len:] == bytes([pad_len]) * pad_len:
        return data[:-pad_len]
    return data


# Function to decrypt a file
def decrypt_file(input_file, output_file, key):
    """Decrypt an ``IV || ciphertext`` AES-CBC file into *output_file*.

    Args:
        input_file: file whose first 16 bytes are the IV, followed by the ciphertext.
        output_file: path of the plaintext file to create/overwrite.
        key: the AES key used for encryption.
    """
    with open(input_file, 'rb') as infile:
        iv = infile.read(16)
        ciphertext = infile.read()
    cipher = AES.new(key, AES.MODE_CBC, iv)
    plaintext = pkcs7_unpad(cipher.decrypt(ciphertext))
    with open(output_file, 'wb') as outfile:
        outfile.write(plaintext)
# Streamlit UI
def _download_link(data, mime, filename, label):
    """Return an HTML anchor, styled by ``.btn-download``, that downloads *data* as *filename*."""
    encoded = base64.b64encode(data).decode('utf-8')
    return f'<a class="btn-download" href="data:{mime};base64,{encoded}" download="{filename}">{label}</a>'


def _process_video(video_path, output_path, csv_path):
    """Run violence and person/age/gender detection on every frame of a video.

    Writes the annotated frames to *output_path* (mp4v) and one CSV row per frame
    to *csv_path*. Returns ``False`` if the video cannot be opened, else ``True``.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return False
    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the source frame rate; fall back to 30 when it is not reported.
        fps = cap.get(cv2.CAP_PROP_FPS)
        fps = fps if fps > 0 else 30
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
        with open(csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Frame', 'Violence', 'People Count', 'Male Count', 'Female Count'])
            frame_number = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_violence, violence_label = detect_violence(frame)
                frame_result, detections, male_count, female_count = detect_people_age_gender(frame_violence)
                writer.writerow([frame_number, violence_label, len(detections), male_count, female_count])
                out.write(frame_result)
                frame_number += 1
    finally:
        # Release native handles even if a model call raises mid-video.
        cap.release()
        if out is not None:
            out.release()
    return True


def main():
    """Render the app: upload an MP4, analyse it frame by frame, offer the results."""
    import os
    import tempfile

    st.markdown('<h1 class="title">Violence Detection and Analysis App</h1>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload Video", type=["mp4"])
    if uploaded_file is None:
        return

    uploaded_bytes = uploaded_file.getvalue()
    st.video(uploaded_bytes)

    # Per-run scratch directory: fixed names in the working directory would be
    # shared (and overwritten) by concurrent sessions of the app.
    with tempfile.TemporaryDirectory() as workdir:
        video_path = os.path.join(workdir, "temp_video.mp4")
        output_path = os.path.join(workdir, "temp_output_video.mp4")
        csv_filename = os.path.join(workdir, "temp_detection_results.csv")
        encrypted_file = os.path.join(workdir, "temp_encrypted.csv")
        decrypted_file = os.path.join(workdir, "temp_decrypted.csv")

        with open(video_path, "wb") as f:
            f.write(uploaded_bytes)

        if not _process_video(video_path, output_path, csv_filename):
            st.error("Could not open the uploaded video.")
            return

        # NOTE(review): hard-coded key, so anyone with this source can decrypt the
        # download. A per-run key (e.g. from get_random_bytes) would need to be
        # handed to the user for the download to remain decryptable.
        key = b'ThisIsASecretKey'
        encrypt_file(csv_filename, encrypted_file, key)
        decrypt_file(encrypted_file, decrypted_file, key)

        # Download processed video
        with open(output_path, "rb") as f:
            st.markdown(_download_link(f.read(), "video/mp4", "processed_video.mp4",
                                       "Download Processed Video"), unsafe_allow_html=True)

        # Download encrypted CSV (binary ciphertext, hence octet-stream)
        with open(encrypted_file, "rb") as f:
            st.markdown(_download_link(f.read(), "application/octet-stream", "encrypted_csv.csv",
                                       "Download Encrypted CSV"), unsafe_allow_html=True)

        # Read before the temporary directory is removed.
        decrypted_data = pd.read_csv(decrypted_file)

    st.markdown('<h2 class="header">Decrypted CSV:</h2>', unsafe_allow_html=True)
    st.dataframe(decrypted_data)
# Script entry point: build the UI when this file is executed as the main module.
if __name__ == "__main__":
    main()