|
|
import gradio as gr |
|
|
import json |
|
|
import torch |
|
|
import decord |
|
|
from transformers import ( |
|
|
AutoProcessor, |
|
|
AutoModelForVideoClassification, |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# VideoMAE action classifier, loaded once at import time so every request
# reuses the same weights. Checkpoint is fine-tuned on Kinetics.
# ---------------------------------------------------------------------------
vm_model_id = "MCG-NJU/videomae-base-finetuned-kinetics"


# Processor handles per-frame resizing/normalization into model tensors.
vm_processor = AutoProcessor.from_pretrained(vm_model_id)


# NOTE(review): loaded on CPU in fp32; no .cuda()/.eval() here — confirm
# whether GPU inference was intended for this model as well.
vm_model = AutoModelForVideoClassification.from_pretrained(vm_model_id)
|
|
|
|
|
def run_videomae(video):
    """Classify *video* into an action class with VideoMAE.

    Args:
        video: Path to a video file (as delivered by the Gradio Video input).

    Returns:
        dict: On success — model name, ``status`` "ok", predicted class label
            and its softmax confidence. On any failure — ``status`` "failed"
            plus the error text, so callers can aggregate partial results
            instead of crashing.
    """
    try:
        vr = decord.VideoReader(video, num_threads=1)
        # VideoMAE expects a fixed-length clip of 16 frames. The previous
        # stride-based sampling (range with step len//16) could produce 17
        # frames for lengths not divisible by 16 and fewer than 16 for short
        # clips; sample exactly 16 evenly spaced indices instead.
        frames = [vr[i].asnumpy() for i in _sample_frame_indices(len(vr), 16)]
        inputs = vm_processor(images=frames, return_tensors="pt")
        with torch.no_grad():
            outputs = vm_model(**inputs)
        pred_id = outputs.logits.argmax(-1).item()
        return {
            "model": "VideoMAE",
            "status": "ok",
            "class": vm_model.config.id2label[pred_id],
            "confidence": float(torch.softmax(outputs.logits, -1)[0, pred_id].item()),
        }
    except Exception as e:
        # Best-effort by design: report the failure as data rather than
        # letting one broken model take down the whole analysis.
        return {"model": "VideoMAE", "status": "failed", "error": str(e)}


def _sample_frame_indices(total, num_frames):
    """Return exactly *num_frames* evenly spaced indices in ``[0, total)``.

    Clips shorter than *num_frames* repeat frames (indices are clamped), so
    the model always receives a fixed-length sequence. Returns ``[]`` when
    the clip has no frames.
    """
    if total <= 0:
        return []
    return [min(int(i * total / num_frames), total - 1) for i in range(num_frames)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
    llava_model_id = "weizhiwang/LLaVA-Video-Llama-3.1-8B"
    llava_tokenizer = AutoTokenizer.from_pretrained(llava_model_id, trust_remote_code=True)
    # Pick the device up front: fp16 on GPU, fp32 on CPU. The previous
    # unconditional .half().cuda() raised on CPU-only hosts and forced them
    # into the "LLaVA not available" fallback even though the model loads.
    _llava_device = "cuda" if torch.cuda.is_available() else "cpu"
    llava_model = AutoModelForCausalLM.from_pretrained(
        llava_model_id, trust_remote_code=True
    )
    if _llava_device == "cuda":
        llava_model = llava_model.half()
    llava_model = llava_model.to(_llava_device).eval()

    def run_llava(video, prompt):
        """Generate a text response from the LLaVA model for *prompt*.

        Args:
            video: Path to the uploaded video. NOTE(review): currently
                unused — only the text prompt is tokenized, so the model
                never sees the frames. Wiring video in requires this
                checkpoint's remote-code API; confirm before relying on
                video-grounded answers.
            prompt: Free-text question/instruction for the model.

        Returns:
            dict: Model name, status, and either the decoded generation
                ("ok") or the error text ("failed").
        """
        try:
            inputs = llava_tokenizer(prompt, return_tensors="pt").to(_llava_device)
            output = llava_model.generate(**inputs, max_new_tokens=256)
            return {
                "model": "LLaVA-Video-Llama-3.1-8B",
                "status": "ok",
                "output": llava_tokenizer.decode(output[0], skip_special_tokens=True),
            }
        except Exception as e:
            # Best-effort: surface the failure as structured data.
            return {"model": "LLaVA-Video-Llama-3.1-8B", "status": "failed", "error": str(e)}


except Exception as outer_error:
    # Loading can fail for many reasons (missing CUDA wheels, old
    # Transformers, gated/missing checkpoint). Remember why, and install a
    # stub so analyze_all still returns one entry per model.
    llava_load_error = str(outer_error)

    def run_llava(video, prompt):
        """Fallback stub used when the LLaVA model could not be loaded."""
        return {
            "model": "LLaVA-Video-Llama-3.1-8B",
            "status": "failed",
            "error": f"LLaVA not available (requires bleeding-edge Transformers). Details: {llava_load_error}",
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_all(video, prompt):
    """Run every analyzer on *video* and return a pretty-printed JSON report.

    Each runner reports its own failures as structured entries rather than
    raising, so the combined report always contains one entry per model.
    """
    report = [
        run_videomae(video),
        run_llava(video, prompt),
    ]
    return json.dumps(report, indent=2)
|
|
|
|
|
# Gradio UI: a video upload plus a free-text prompt in, the combined JSON
# report from analyze_all out.
demo = gr.Interface(

    fn=analyze_all,

    inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Prompt")],

    outputs="json",

)
|
|
|
|
|
if __name__ == "__main__":

    # Start the (blocking) Gradio server only when run as a script, so the
    # module can be imported without side effects beyond model loading.
    demo.launch()