import gradio as gr
import numpy as np
import torch
import decord
from transformers import (
    AutoProcessor,
    AutoModelForVideoClassification,
    AutoTokenizer,
    AutoModelForCausalLM,
)

# ======================
# 1. VideoMAE
# ======================
vm_model_id = "MCG-NJU/videomae-base-finetuned-kinetics"
vm_processor = AutoProcessor.from_pretrained(vm_model_id)
vm_model = AutoModelForVideoClassification.from_pretrained(vm_model_id)


def run_videomae(video):
    try:
        vr = decord.VideoReader(video, num_threads=1)
        # This checkpoint expects exactly 16 frames; sample them evenly across
        # the clip (a simple stride can yield 17 frames and break the model).
        indices = np.linspace(0, len(vr) - 1, num=16).astype(int)
        frames = [vr[int(i)].asnumpy() for i in indices]
        inputs = vm_processor(images=frames, return_tensors="pt")
        with torch.no_grad():
            outputs = vm_model(**inputs)
        pred_id = outputs.logits.argmax(-1).item()
        return {
            "model": "VideoMAE",
            "status": "ok",
            "class": vm_model.config.id2label[pred_id],
            "confidence": float(torch.softmax(outputs.logits, -1)[0, pred_id].item()),
        }
    except Exception as e:
        return {"model": "VideoMAE", "status": "failed", "error": str(e)}


# ======================
# 2. LLaVA-Video-Llama-3.1-8B
# ======================
# Loading this checkpoint needs trust_remote_code and a CUDA GPU; if either is
# missing, the except branch below installs a stub that reports the failure.
try:
    llava_model_id = "weizhiwang/LLaVA-Video-Llama-3.1-8B"
    llava_tokenizer = AutoTokenizer.from_pretrained(llava_model_id, trust_remote_code=True)
    llava_model = (
        AutoModelForCausalLM.from_pretrained(llava_model_id, trust_remote_code=True)
        .half()
        .cuda()
        .eval()
    )

    def run_llava(video, prompt):
        try:
            # Note: only the text prompt is fed to the model in this simplified
            # path; the uploaded video frames are not encoded here.
            inputs = llava_tokenizer(prompt, return_tensors="pt").to("cuda")
            output = llava_model.generate(**inputs, max_new_tokens=256)
            return {
                "model": "LLaVA-Video-Llama-3.1-8B",
                "status": "ok",
                "output": llava_tokenizer.decode(output[0], skip_special_tokens=True),
            }
        except Exception as e:
            return {"model": "LLaVA-Video-Llama-3.1-8B", "status": "failed", "error": str(e)}

except Exception as outer_error:
    llava_load_error = str(outer_error)

    def run_llava(video, prompt):
        return {
            "model": "LLaVA-Video-Llama-3.1-8B",
            "status": "failed",
            "error": (
                "LLaVA not available (requires bleeding-edge Transformers). "
                f"Details: {llava_load_error}"
            ),
        }


# ======================
# Unified App
# ======================
def analyze_all(video, prompt):
    # Run both models and collect their per-model result dicts; gr.JSON
    # serializes Python lists/dicts directly, so no json.dumps is needed.
    return [
        run_videomae(video),
        run_llava(video, prompt),
    ]


demo = gr.Interface(
    fn=analyze_all,
    inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Prompt")],
    outputs="json",
)

if __name__ == "__main__":
    demo.launch()