# EnSenas / app.py
# (Hugging Face Space by fabiosam — commit "Update app.py", 9696369, verified)
import os
import json
import numpy as np
import cv2
import gradio as gr
import mediapipe as mp
import tensorflow as tf
from tensorflow import keras
# ---------------------------------------------------------
# CONFIG
# ---------------------------------------------------------
MODELS_DIR = "models"
MAX_FRAMES = 20
N_FEATURES = 225 # 75 points * (x, y, z)
THRESHOLD = 0.6 # confidence threshold to decide whether the sign is "recognized"
# File names of the V2 model artifacts
MODEL_FILENAME = "sign_model_lstm_v2.keras"
LABELS_FILENAME = "label_names_v2.json"
FEATURE_MEAN_FILENAME = "feature_mean_v2.npy"
FEATURE_STD_FILENAME = "feature_std_v2.npy"
# MediaPipe Holistic solution (pose + hands landmarks), shared module-wide
mp_holistic = mp.solutions.holistic
# ---------------------------------------------------------
# LANDMARK EXTRACTION (SAME AS IN THE NOTEBOOK)
# ---------------------------------------------------------
def extract_landmarks_from_results(results):
    """Flatten MediaPipe Holistic results into a 1D vector of shape (225,).

    Concatenates pose (33 points), left hand (21) and right hand (21),
    each point as (x, y, z), giving 75 * 3 = 225 features. Any missing
    group is filled with zeros.
    """
    def _coords(group, expected):
        # Missing group -> all zeros; otherwise pad/trim to `expected` points.
        if group is None:
            return [[0.0, 0.0, 0.0] for _ in range(expected)]
        pts = [[p.x, p.y, p.z] for p in group.landmark]
        while len(pts) < expected:
            pts.append([0.0, 0.0, 0.0])
        return pts[:expected]

    combined = (
        _coords(results.pose_landmarks, 33)
        + _coords(results.left_hand_landmarks, 21)
        + _coords(results.right_hand_landmarks, 21)
    )
    return np.array(combined, dtype=np.float32).flatten()  # (225,)
# ---------------------------------------------------------
# PAD / TRUNCATE (SAME AS IN TRAINING)
# ---------------------------------------------------------
def pad_or_truncate(seq, max_frames=MAX_FRAMES):
    """Force a sequence to exactly `max_frames` frames.

    Longer sequences are center-cropped; shorter ones are zero-padded
    at the end. seq: (T, 225) -> returns float32 (max_frames, 225).
    """
    length = seq.shape[0]
    if length > max_frames:
        # Center crop: drop roughly the same number of frames at both ends.
        offset = max(0, (length - max_frames) // 2)
        seq = seq[offset:offset + max_frames]
    elif length < max_frames:
        filler = np.zeros((max_frames - length, seq.shape[1]), dtype=np.float32)
        seq = np.concatenate([seq, filler], axis=0)
    return seq.astype(np.float32)
# ---------------------------------------------------------
# SLIDING WINDOWS OVER THE VIDEO
# ---------------------------------------------------------
def make_windows_from_frames(frames_feats, max_frames=MAX_FRAMES, step=5):
    """Group per-frame feature vectors into fixed-length windows.

    frames_feats: list of (225,) vectors, one per frame.
    Returns a float32 array (N_windows, max_frames, 225):
      - a short video yields a single padded/truncated window;
      - a long video is scanned with a sliding window of size
        `max_frames` advancing `step` frames at a time.
    """
    sequence = np.stack(frames_feats, axis=0)  # (T, 225)
    total = sequence.shape[0]
    if total <= max_frames:
        windows = [pad_or_truncate(sequence, max_frames=max_frames)]
    else:
        windows = [
            sequence[start:start + max_frames].astype(np.float32)
            for start in range(0, total - max_frames + 1, step)
        ]
    return np.stack(windows, axis=0).astype(np.float32)  # (N, max_frames, 225)
# ---------------------------------------------------------
# LOAD MODEL + LABELS + NORMALIZATION STATS (V2)
# ---------------------------------------------------------
def load_model():
    """Load the trained LSTM model, label names and feature
    normalization statistics from MODELS_DIR.

    Returns:
        tuple: (model, label_names, feature_mean, feature_std)
    """
    model_path = os.path.join(MODELS_DIR, MODEL_FILENAME)
    labels_path = os.path.join(MODELS_DIR, LABELS_FILENAME)
    mean_path = os.path.join(MODELS_DIR, FEATURE_MEAN_FILENAME)
    std_path = os.path.join(MODELS_DIR, FEATURE_STD_FILENAME)

    model = keras.models.load_model(model_path)
    # Explicit encoding: label names may contain non-ASCII characters
    # (Spanish sign names) and must not depend on the platform locale.
    with open(labels_path, "r", encoding="utf-8") as f:
        label_names = json.load(f)
    feature_mean = np.load(mean_path)
    feature_std = np.load(std_path)
    return model, label_names, feature_mean, feature_std

# Loaded once at import time so every Gradio request reuses the same model.
model, label_names, feature_mean, feature_std = load_model()
# ---------------------------------------------------------
# PROCESS VIDEO -> NORMALIZED SEQUENCES
# ---------------------------------------------------------
def process_video_to_sequences(video_file):
    """Read a video, extract landmarks frame by frame, build
    windows of shape (N, MAX_FRAMES, 225) and apply the same
    normalization used during training.

    Args:
        video_file: path to the uploaded video (as given by gr.Video).
    Returns:
        np.ndarray: normalized windows, shape (N, MAX_FRAMES, 225).
    """
    cap = cv2.VideoCapture(video_file)
    frames_feats = []
    try:
        with mp_holistic.Holistic(
            static_image_mode=False,
            model_complexity=1,
            enable_segmentation=False,
            refine_face_landmarks=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        ) as holis:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # MediaPipe expects RGB; OpenCV decodes frames as BGR.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = holis.process(rgb)
                frames_feats.append(extract_landmarks_from_results(results))  # (225,)
    finally:
        # Release the capture even if MediaPipe raises mid-video
        # (the original code leaked the handle on exceptions).
        cap.release()
    if not frames_feats:
        # Nothing decoded -> a single all-zeros window.
        windows = np.zeros((1, MAX_FRAMES, N_FEATURES), dtype=np.float32)
    else:
        windows = make_windows_from_frames(frames_feats, max_frames=MAX_FRAMES, step=5)
    # Same normalization as in training.
    windows_norm = (windows - feature_mean) / feature_std  # (N, T, 225)
    return windows_norm
# ---------------------------------------------------------
# PREDICTION, AVERAGED OVER WINDOWS
# ---------------------------------------------------------
def predict(video):
    """Gradio handler: split the video into windows, average the
    per-window class probabilities and apply the THRESHOLD to decide
    whether the sign is recognized.
    """
    sequences = process_video_to_sequences(video)        # (N, T, 225)
    window_probs = model.predict(sequences, verbose=0)   # (N, num_classes)
    mean_probs = window_probs.mean(axis=0)               # (num_classes,)
    best = int(np.argmax(mean_probs))
    label = label_names[best]
    conf = float(mean_probs[best])
    # Mapping consumed by the gr.Label component.
    probs_dict = {name: float(p) for name, p in zip(label_names, mean_probs)}
    if conf < THRESHOLD:
        text = f"No estoy seguro, señal no reconocida.\nMejor candidata: {label} (confianza {conf:.2f})"
    else:
        text = f"Predicción: {label} (confianza {conf:.2f})"
    return text, probs_dict
# ---------------------------------------------------------
# GRADIO UI
# ---------------------------------------------------------
video_input = gr.Video(label="Sube un video haciendo la seña")
result_output = gr.Textbox(label="Resultado")
probs_output = gr.Label(label="Probabilidades por clase")

demo = gr.Interface(
    fn=predict,
    inputs=video_input,
    outputs=[result_output, probs_output],
    title="Traductor de Señas LSTM",
)

if __name__ == "__main__":
    demo.launch()