import os

# Point the YOLO config directory at /tmp (the original notes this is required
# on HuggingFace). Kept ahead of the ultralytics import, as in the original.
os.environ["YOLO_CONFIG_DIR"] = "/tmp"

import gradio as gr
import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import albumentations as A
from ultralytics import YOLO
from datetime import datetime
import pandas as pd
from collections import deque
from pathlib import Path

# ============================================================
# 1. MODEL CONFIGURATION
# ============================================================
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = "best_model_efficientnet_lstm_v2.pth"


class EfficientNetLSTM(nn.Module):
    """EfficientNetV2-S frame encoder + bidirectional LSTM + MLP head.

    Each frame of a clip is encoded by the EfficientNet backbone (its
    classifier is replaced by an identity), the per-frame features are fed to
    a bidirectional LSTM, and the LSTM output at the last time step goes
    through a small fully-connected head producing one logit per clip.

    Args:
        hidden_size: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used between LSTM layers and in the head.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone is initialised with the ImageNet weights. Pass False
            when a full checkpoint is loaded right after construction, so no
            weights have to be fetched.
    """

    def __init__(self, hidden_size=256, num_layers=2, dropout=0.5,
                 pretrained=True):
        super().__init__()
        weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        self.efficientnet = efficientnet_v2_s(weights=weights)
        num_features = self.efficientnet.classifier[1].in_features
        # Drop the ImageNet classifier; the backbone now returns raw features.
        self.efficientnet.classifier = nn.Identity()
        self.lstm = nn.LSTM(
            input_size=num_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.fc = nn.Sequential(
            # The bidirectional LSTM emits 2 * hidden_size features. This was
            # hard-coded to 256 * 2, which only worked for the default size.
            nn.Linear(hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return one logit per clip for ``x`` of shape (batch, frames, c, h, w).

        All singleton dimensions of the result are squeezed, so a batch of
        one clip yields a 0-d tensor.
        """
        batch_size, num_frames, c, h, w = x.shape
        # Fold the time axis into the batch axis so the CNN sees single frames.
        x = x.view(batch_size * num_frames, c, h, w)
        features = self.efficientnet(x)
        features = features.view(batch_size, num_frames, -1)
        lstm_out, _ = self.lstm(features)
        # Classify from the LSTM output at the last time step.
        output = self.fc(lstm_out[:, -1, :])
        return output.squeeze()


# Load models globally. On any failure both handles are set to None so the
# rest of the app can detect that the models are unavailable.
print("⏳ Đang tải models...")
try:
    # The checkpoint is loaded with strict key matching and therefore
    # overwrites every backbone weight; skip the ImageNet initialisation.
    model = EfficientNetLSTM(pretrained=False).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    yolo_model = YOLO("yolov8n.pt")
    print("✅ Đã tải xong models!")
except Exception as e:
    print(f"❌ Lỗi: {e}")
    model = None
    yolo_model = None

# Transform
transform = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    A.ToTensorV2(),
])

# ============================================================
# 2. SYSTEM CLASS (STATE MANAGEMENT)
# ============================================================


class FallDetectionSystem:
    """Holds all mutable state of the app: frame buffers, the recording
    state, the realtime log and the history of analysed video files."""

    def __init__(self):
        # Config
        self.num_frames = 32
        self.conf_thres = 0.5
        self.output_dir = Path("fall_videos")
        self.output_dir.mkdir(exist_ok=True)

        # Realtime buffers
        self.buffer = deque(maxlen=self.num_frames)  # frames fed to the model
        self.pre_buffer = deque(maxlen=30)  # last 30 raw frames before a fall
        self.no_detect_count = 0

        # Recording state
        self.is_recording = False
        self.video_writer = None
        self.current_video_path = None
        self.fall_start_time = None
        self.fall_frame_count = 0

        # Logging & history
        self.log_history = []  # realtime log entries, newest first
        self.saved_videos = []  # paths of saved fall clips, newest first
        self.analysis_history = pd.DataFrame(
            columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"]
        )

    def reset_realtime_state(self):
        """Reset the realtime state when the camera is turned on again."""
        self.buffer.clear()
        self.pre_buffer.clear()
        self.no_detect_count = 0
        self.is_recording = False
        if self.video_writer:
            self.video_writer.release()
            self.video_writer = None

    # --- TAB 1 LOGIC: VIDEO FILE ANALYSIS ---
    def analyze_video(self, video_path):
        """Classify a video file as fall / no fall.

        Samples ``self.num_frames`` evenly spaced frames (or every frame when
        the video is shorter, padding by repeating the last one), runs the
        model once and prepends the outcome to ``self.analysis_history``.

        Args:
            video_path: Path of the video file to analyse.

        Returns:
            A ``(message, history_dataframe)`` tuple. On failure the message
            describes the error and the history is returned unchanged.
        """
        if model is None:
            return "Error loading model", self.analysis_history
        if not video_path:
            return "Error: no video provided", self.analysis_history

        wanted_count = self.num_frames
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames >= wanted_count:
                picked = np.linspace(0, total_frames - 1, wanted_count, dtype=int)
                wanted = set(picked.tolist())
            else:
                wanted = set(range(total_frames))

            for i in range(total_frames):
                ret, frame = cap.read()
                if not ret:
                    break
                if i in wanted:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(transform(image=frame_rgb)['image'])
        finally:
            # Release the capture even if decoding or the transform raises.
            cap.release()

        # Nothing could be decoded (unreadable file, zero reported frames):
        # padding below would otherwise fail on frames[-1].
        if not frames:
            return "Error: cannot read video", self.analysis_history

        # Pad with the last frame when the video is too short.
        while len(frames) < wanted_count:
            frames.append(frames[-1])

        # Predict
        video_tensor = torch.stack(frames).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            prob = torch.sigmoid(model(video_tensor)).item()

        is_fall = prob > 0.5
        result_text = "⚠️ PHÁT HIỆN NGÃ" if is_fall else "✅ AN TOÀN"
        timestamp = datetime.now().strftime("%d/%m/%Y %H:%M")
        filename = os.path.basename(video_path)

        # Prepend the new result so the newest analysis is listed first.
        new_row = pd.DataFrame({
            "Thời gian": [timestamp],
            "Video": [filename],
            "Kết quả": [result_text],
            "Độ tin cậy": [f"{prob*100:.2f}%"],
        })
        self.analysis_history = pd.concat(
            [new_row, self.analysis_history], ignore_index=True
        )

        return f"{result_text} ({prob*100:.2f}%)", self.analysis_history

    # --- TAB 2 LOGIC: REALTIME PROCESSING ---
    def process_frame(self, image):
        """Process one webcam frame.

        Args:
            image: RGB frame from Gradio, or None.

        Returns:
            ``(annotated_rgb_frame, status_html, log_html, saved_videos)``.
        """
        if image is None:
            return image, "", "", []
        # Models failed to load at startup: pass the frame through untouched
        # instead of crashing on a None model.
        if yolo_model is None or model is None:
            return image, "", "", self.saved_videos

        # 1. Prepare data
        frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # OpenCV uses BGR
        current_time = datetime.now().strftime('%H:%M:%S')

        # Keep the frame in the pre-buffer so a recording can reach back in time.
        self.pre_buffer.append(frame_bgr)

        # 2. Detect a person (YOLO); the first detection of class 0 is used.
        results = yolo_model(frame_bgr, verbose=False, conf=self.conf_thres)
        boxes = results[0].boxes.data.cpu().numpy()
        person_box = None
        for bx1, by1, bx2, by2, _conf, cls in boxes:
            if int(cls) == 0:  # Person
                person_box = (int(bx1), int(by1), int(bx2), int(by2))
                break

        # UI display variables.
        # NOTE(review): in the copy of the file under review, everything from
        # this assignment down to the "fall detected" log line was lost (the
        # HTML inside the string literals was stripped together with the code
        # around it). The HTML markup, the buffering/prediction steps, the
        # 0.5 threshold and the fall label below are a reconstruction based
        # on analyze_video and the surviving branches — confirm against the
        # original file. The handling of self.no_detect_count could not be
        # recovered and is not reproduced here.
        status_html = "<div class='status-box'>✅ AN TOÀN</div>"
        log_entry = ""

        if person_box is not None:
            x1, y1, x2, y2 = person_box
            # Same preprocessing as analyze_video: the full RGB frame.
            self.buffer.append(transform(image=image)['image'])

            if len(self.buffer) == self.num_frames:
                clip = torch.stack(list(self.buffer)).unsqueeze(0).to(DEVICE)
                with torch.no_grad():
                    prob = torch.sigmoid(model(clip)).item()

                if prob > 0.5:
                    # --- FALL DETECTED ---
                    label = f"NGA ({prob*100:.0f}%)"
                    cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 0, 255), 2)
                    cv2.putText(frame_bgr, label, (x1, y1-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                    status_html = "<div class='status-box'>⚠️ PHÁT HIỆN NGÃ</div>"
                    log_entry = (
                        f"<div>🔴 {current_time}: "
                        f"Phát hiện ngã ({prob*100:.0f}%)</div>"
                    )
                    # (end of reconstructed span)

                    # Start recording if not already recording.
                    if not self.is_recording:
                        self._start_recording(frame_bgr)
                    # Write the current frame.
                    if self.video_writer:
                        self.video_writer.write(frame_bgr)
                        self.fall_frame_count += 1
                else:
                    # --- NORMAL ---
                    label = f"An toan ({prob*100:.0f}%)"
                    cv2.putText(frame_bgr, label, (x1, y1-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                    # Safe frames are deliberately not logged, to avoid
                    # flooding the log.
                    # Stop recording if one is in progress.
                    self._stop_recording_if_active(save=True)
            else:
                cv2.putText(frame_bgr, f"Buffering: {len(self.buffer)}/32",
                            (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (255, 255, 0), 2)

        # Update the log (newest first, capped at 50 entries).
        if log_entry:
            self.log_history.insert(0, log_entry)
            if len(self.log_history) > 50:
                self.log_history.pop()
        log_html_output = "".join(self.log_history)

        # Convert back to RGB for Gradio display.
        frame_rgb_out = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        return frame_rgb_out, status_html, log_html_output, self.saved_videos

    # --- HELPER METHODS FOR RECORDING ---
    def _start_recording(self, frame_sample):
        """Open a new clip file sized like ``frame_sample`` and flush the
        pre-buffer into it. Leaves the recording state untouched if the
        writer cannot be opened."""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"fall_detect_{timestamp}.mp4"
        filepath = self.output_dir / filename

        h, w = frame_sample.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(str(filepath), fourcc, 20.0, (w, h))
        if not writer.isOpened():
            # Codec or path problem: do not pretend that we are recording.
            writer.release()
            return

        self.video_writer = writer
        self.is_recording = True
        self.current_video_path = str(filepath)
        self.fall_frame_count = 0

        # Write the buffered past frames (up to 30 frames before the fall).
        for past_frame in self.pre_buffer:
            self.video_writer.write(past_frame)

    def _stop_recording_if_active(self, save=True):
        """Close the active recording, if any.

        The clip is added to ``self.saved_videos`` only when ``save`` is true
        and more than 10 fall frames were written; otherwise the file is
        deleted.
        """
        if not (self.is_recording and self.video_writer):
            return

        self.video_writer.release()
        self.video_writer = None
        self.is_recording = False

        if save and self.fall_frame_count > 10:  # keep only long-enough clips
            self.saved_videos.insert(0, self.current_video_path)
        else:
            # Best-effort removal of a too-short clip; a missing or locked
            # file must not break frame processing.
            try:
                os.remove(self.current_video_path)
            except OSError:
                pass


# Initialise the system
system = FallDetectionSystem()

# ============================================================
# 3.
GRADIO UI # ============================================================ # Custom CSS css = """ .status-box { text-align: center; font-size: 1.2em; font-weight: bold; margin-bottom: 10px; } .log-container { height: 300px; overflow-y: auto; background: #222; padding: 10px; border-radius: 8px; border: 1px solid #444; } """ with gr.Blocks(title="Hệ thống Dự đoán Fall", css=css, theme=gr.themes.Soft()) as demo: gr.Markdown("# 🎈 Hệ thống Phát hiện Té ngã (AI Powered)") with gr.Tab("📹 Dự đoán Realtime"): with gr.Row(): with gr.Column(scale=2): # Input Webcam input_cam = gr.Image(sources=["webcam"], type="numpy", label="Camera Input") # Output đã vẽ box output_cam = gr.Image(label="Kết quả Xử lý") with gr.Column(scale=1): # Trạng thái An toàn/Nguy hiểm status_html = gr.HTML(value="