| import gradio as gr |
| import torch |
| from src.predict import predict_from_video |
| from src.islr.islr_model import DummyISLRModel |
| from huggingface_hub import hf_hub_download |
| import torch |
| import os |
| from dotenv import load_dotenv |
| import os |
|
|
| |
# --- Runtime configuration --------------------------------------------------

# Pull variables from a local .env file into the process environment.
load_dotenv()

# Token used to authenticate against the Hugging Face Hub (may be None when
# the variable is unset; public repos still work in that case).
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

# Local cache directory for downloaded checkpoints.
os.makedirs("models", exist_ok=True)

# Prefer GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
| |
# Checkpoint file and classifier width for each supported sign language.
# Both languages currently share the same 100-class demo checkpoint.
dataset_models = {
    "PERU": {
        "path": "models/demo_model.pt",
        "num_classes": 100,
    },
    "WLASL": {
        "path": "models/demo_model.pt",
        "num_classes": 100,
    },
}
|
|
| |
# Example gallery: four (markdown label, clip path) pairs per language.
# NOTE(review): every entry currently points at the same placeholder clip —
# swap in one real video per gloss when they become available.
_PLACEHOLDER_CLIP = "videos/wlasl/book.mp4"

dataset_examples = {
    "PERU": [
        {"label": "📘 **Glosa: `libro`**", "path": _PLACEHOLDER_CLIP},
        {"label": "🏠 **Glosa: `casa`**", "path": _PLACEHOLDER_CLIP},
        {"label": "📘 **Glosa: `libro2`**", "path": _PLACEHOLDER_CLIP},
        {"label": "🏠 **Glosa: `casa2`**", "path": _PLACEHOLDER_CLIP},
    ],
    "WLASL": [
        {"label": "📙 **Glosa: `read`**", "path": _PLACEHOLDER_CLIP},
        {"label": "🏫 **Glosa: `school`**", "path": _PLACEHOLDER_CLIP},
        {"label": "📙 **Glosa: `read2`**", "path": _PLACEHOLDER_CLIP},
        {"label": "🏫 **Glosa: `school2`**", "path": _PLACEHOLDER_CLIP},
    ],
}
|
|
| |
def load_model_and_examples(dataset):
    """Download and load the checkpoint for *dataset* and build the UI updates.

    Triggered by ``dataset_selector.change``. Downloads the configured
    checkpoint from the Hugging Face Hub (cached under ``models/``), loads it
    into a ``DummyISLRModel`` on ``device``, and returns the values for every
    component wired as an output of the change event.

    Parameters
    ----------
    dataset : str
        Key into ``dataset_models`` / ``dataset_examples`` (e.g. "PERU").

    Returns
    -------
    tuple
        ``(model, examples-column visibility update,`` then for each of the
        four example slots ``video update, raw path, label update,`` and
        finally ``an update enabling the classify button)``. The order must
        match the ``outputs`` list of ``dataset_selector.change``.

    Raises
    ------
    KeyError
        If *dataset* has no entry in ``dataset_models``. (The original
        ``dataset_models.get(dataset)['path']`` raised an opaque
        ``TypeError`` on ``None`` instead.)
    """
    try:
        config = dataset_models[dataset]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(dataset_models)}"
        ) from None

    print("Downloading..")
    model_path = hf_hub_download(
        repo_id="CristianLazoQuispe/SignERT",
        filename=config["path"],
        cache_dir="models",
        token=hf_token,
    )
    print("Downloaded!")

    model = DummyISLRModel(num_classes=config["num_classes"])
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"Model {dataset} Loaded!")

    # Fall back to FOUR empty slots so the per-slot loop below never runs out.
    # (The previous fallback held only two entries, which guaranteed an
    # IndexError whenever the dataset had no configured examples.)
    placeholder = {"label": "", "path": ""}
    examples = dataset_examples.get(dataset, [placeholder] * 4)

    updates = [model, gr.update(visible=True)]
    for example in examples[:4]:
        updates.append(gr.update(value=example["path"]))   # preview video
        updates.append(example["path"])                    # gr.State with raw path
        updates.append(gr.update(value=example["label"]))  # markdown gloss label
    updates.append(gr.update(interactive=True))            # enable classify button
    return tuple(updates)
|
|
| |
def classify_video_with_model(video, model):
    """Classify *video* with *model* and format the result for the UI.

    Returns a ``(text, scores)`` pair: a "Top-1: ..." string for the text box
    and the top-5 predictions for the label component.
    """
    top1_label, top5_scores = predict_from_video(video, model=model)
    top1_text = f"Top-1: {top1_label}"
    return top1_text, top5_scores
|
|
with gr.Blocks() as demo:
    # Page header (user-facing copy is intentionally in Spanish).
    gr.Markdown("# 🧠 ISLR Demo con Mediapipe y 100 Clases")
    gr.Markdown("Sube un video o usa la webcam. El modelo clasificará la seña y mostrará las 5 clases más probables.")

    # Language picker; value=None so nothing loads until the user chooses.
    gr.Markdown("## 📁 Filtrar por Language")
    dataset_selector = gr.Dropdown(choices=list(dataset_examples.keys()), value=None, label="Selecciona el lenguaje")

    # Per-session hidden state: the loaded model plus the file path behind
    # each of the four example slots (read back by the "Usar" buttons below).
    current_model = gr.State()
    video_path_1 = gr.State()
    video_path_2 = gr.State()
    video_path_3 = gr.State()
    video_path_4 = gr.State()

    # Input video on the left; prediction outputs on the right.
    with gr.Row():
        video_input = gr.Video(sources=["upload", "webcam"], label="🎥 Video de entrada", width=300, height=400)
        with gr.Column():
            output_text = gr.Text(label="Predicción Top-1")
            output_table = gr.Label(num_top_classes=5)
            # Disabled until load_model_and_examples re-enables it, so the
            # user cannot classify before a model is loaded.
            button_classify = gr.Button("🔍 Clasificar",interactive=False)

    button_classify.click(
        fn=classify_video_with_model,
        inputs=[video_input, current_model],
        outputs=[output_text, output_table]
    )

    # Example gallery: four columns, each holding a gloss label, a read-only
    # preview video, and a "Usar" button.
    examples_output = gr.Column(visible=True)

    with examples_output:
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                m1 = gr.Markdown("📘 **Glosa: **")
                v1 = gr.Video(interactive=False, width=160, height=120)
                b1 = gr.Button("Usar", scale=0)
            with gr.Column(scale=1, min_width=100):
                m2 = gr.Markdown("🏠 **Glosa: **")
                v2 = gr.Video(interactive=False, width=160, height=120)
                b2 = gr.Button("Usar", scale=0)
            with gr.Column(scale=1, min_width=100):
                m3 = gr.Markdown("🏠 **Glosa: **")
                v3 = gr.Video(interactive=False, width=160, height=120)
                b3 = gr.Button("Usar", scale=0)
            with gr.Column(scale=1, min_width=100):
                m4 = gr.Markdown("🏠 **Glosa: **")
                v4 = gr.Video(interactive=False, width=160, height=120)
                b4 = gr.Button("Usar", scale=0)

    # "Usar" copies the stored example path into the main video input.
    b1.click(fn=lambda path: path, inputs=video_path_1, outputs=video_input)
    b2.click(fn=lambda path: path, inputs=video_path_2, outputs=video_input)
    b3.click(fn=lambda path: path, inputs=video_path_3, outputs=video_input)
    b4.click(fn=lambda path: path, inputs=video_path_4, outputs=video_input)

    # NOTE(review): this section heading renders BELOW the gallery it titles —
    # confirm whether it was meant to sit above `examples_output`.
    gr.Markdown("## 📁 Ejemplos de videos")

    # Selecting a language downloads/loads its model and populates the four
    # example slots; this outputs order must match the tuple returned by
    # load_model_and_examples.
    dataset_selector.change(
        fn=load_model_and_examples,
        inputs=dataset_selector,
        outputs=[current_model, examples_output, v1,video_path_1,m1, v2, video_path_2, m2, v3, video_path_3, m3, v4, video_path_4, m4,
            button_classify
        ]
    )
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|