# app.py — Hugging Face Space: multi-model video analysis (VideoMAE + LLaVA-Video)
# Uploaded by darisdzakwanhoesien — revision 3e5f43f (verified)
import gradio as gr
import json
import torch
import decord
from transformers import (
AutoProcessor,
AutoModelForVideoClassification,
AutoTokenizer,
AutoModelForCausalLM,
)
# ======================
# 1. VideoMAE
# ======================
# Video classifier fine-tuned on Kinetics-400. Loaded once at module import
# so the weights download during Space startup rather than on each request.
vm_model_id = "MCG-NJU/videomae-base-finetuned-kinetics"
vm_processor = AutoProcessor.from_pretrained(vm_model_id)
vm_model = AutoModelForVideoClassification.from_pretrained(vm_model_id)
def run_videomae(video):
    """Classify an uploaded video with VideoMAE (Kinetics-400 labels).

    Parameters
    ----------
    video : str | None
        Filesystem path to the uploaded clip, as supplied by ``gr.Video``.

    Returns
    -------
    dict
        On success: ``{"model", "status": "ok", "class", "confidence"}``.
        On any failure: ``{"model", "status": "failed", "error"}``.
    """
    try:
        # gr.Video passes None when no file was uploaded; fail gracefully.
        if not video:
            return {"model": "VideoMAE", "status": "failed", "error": "no video provided"}
        vr = decord.VideoReader(video, num_threads=1)
        total = len(vr)
        # VideoMAE expects exactly 16 frames. The previous stride-based
        # sampling (range(0, total, total // 16)) could produce 17 frames
        # (e.g. total=100 -> step 6 -> 17 indices) and break the forward
        # pass. Sample 16 evenly spaced indices instead; clips shorter
        # than 16 frames simply repeat frames.
        num_frames = 16
        indices = [min(int(i * total / num_frames), total - 1) for i in range(num_frames)]
        frames = [vr[i].asnumpy() for i in indices]
        inputs = vm_processor(images=frames, return_tensors="pt")
        with torch.no_grad():
            outputs = vm_model(**inputs)
        # Compute softmax once and reuse it for both argmax and confidence.
        probs = torch.softmax(outputs.logits, dim=-1)
        pred_id = int(probs.argmax(-1).item())
        return {
            "model": "VideoMAE",
            "status": "ok",
            "class": vm_model.config.id2label[pred_id],
            "confidence": float(probs[0, pred_id].item()),
        }
    except Exception as e:
        # Boundary handler: report failures to the UI as JSON, never crash.
        return {"model": "VideoMAE", "status": "failed", "error": str(e)}
# ======================
# 2. LLaVA-Video-Llama-3.1-8B
# ======================
try:
    llava_model_id = "weizhiwang/LLaVA-Video-Llama-3.1-8B"
    llava_tokenizer = AutoTokenizer.from_pretrained(llava_model_id, trust_remote_code=True)
    # Pick the device at load time. The original unconditionally called
    # .cuda(), which raises on CPU-only Spaces; use fp16 only when a GPU
    # is actually available (fp16 on CPU is unsupported/slow).
    _llava_device = "cuda" if torch.cuda.is_available() else "cpu"
    llava_model = AutoModelForCausalLM.from_pretrained(
        llava_model_id, trust_remote_code=True
    )
    if _llava_device == "cuda":
        llava_model = llava_model.half()
    llava_model = llava_model.to(_llava_device).eval()

    def run_llava(video, prompt):
        """Generate a free-text response from the LLaVA model.

        NOTE(review): the ``video`` argument is currently unused — only the
        text prompt is tokenized, so the model never sees any frames.
        Feeding frames requires this model's own video/chat template;
        confirm against the model card before relying on visual answers.
        """
        try:
            # Send inputs to wherever the model actually lives instead of
            # hard-coding "cuda".
            inputs = llava_tokenizer(prompt, return_tensors="pt").to(llava_model.device)
            output = llava_model.generate(**inputs, max_new_tokens=256)
            return {
                "model": "LLaVA-Video-Llama-3.1-8B",
                "status": "ok",
                "output": llava_tokenizer.decode(output[0], skip_special_tokens=True),
            }
        except Exception as e:
            return {"model": "LLaVA-Video-Llama-3.1-8B", "status": "failed", "error": str(e)}
except Exception as outer_error:
    # Loading can fail (custom code requires a bleeding-edge Transformers,
    # weights unavailable, etc.). Install a stub with the same signature so
    # the rest of the app keeps working end to end.
    llava_load_error = str(outer_error)

    def run_llava(video, prompt):
        return {
            "model": "LLaVA-Video-Llama-3.1-8B",
            "status": "failed",
            "error": f"LLaVA not available (requires bleeding-edge Transformers). Details: {llava_load_error}",
        }
# ======================
# Unified App
# ======================
def analyze_all(video, prompt):
    """Run every configured model on the uploaded clip.

    Returns a pretty-printed JSON array containing one result dict per
    model (VideoMAE classification, then LLaVA generation).
    """
    reports = [
        run_videomae(video),
        run_llava(video, prompt),
    ]
    return json.dumps(reports, indent=2)
# Gradio UI: one video upload plus one free-text prompt, shared by both models.
demo = gr.Interface(
    fn=analyze_all,
    inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Prompt")],
    # NOTE(review): analyze_all returns a JSON *string*; the "json" output
    # component will display it as a quoted string — confirm intended display.
    outputs="json",
)
# Standard script guard: launch the server only when run directly.
if __name__ == "__main__":
    demo.launch()