| import sys |
| sys.path.append('.') |
|
|
| import torch |
| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoProcessor |
| import argparse |
| import os |
|
|
class SimpleVideoLLaMA3Interface:
    """Gradio chat front end around a VideoLLaMA3 checkpoint.

    The class loads a causal LM and its processor once, converts Gradio
    "messages"-style chat history into the conversation structure passed to
    the processor, and exposes a Blocks UI with upload widgets and
    generation settings.
    """

    def __init__(self, model_path):
        """Load the model and processor from *model_path*.

        Args:
            model_path: Value passed to both ``from_pretrained`` calls.
        """
        print(f"Loading model from {model_path}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("Model loaded successfully!")

        # Lower-case extensions (without the dot) used to classify uploaded files.
        self.image_formats = ("png", "jpg", "jpeg", "bmp", "gif", "webp")
        self.video_formats = ("mp4", "avi", "mov", "mkv", "webm", "m4v", "3gp", "flv")

    def _file_item(self, file_path, fps, max_frames):
        """Return the conversation content item for an uploaded file.

        A missing file or an unrecognised extension yields a ``text`` item
        describing the problem, so the failure is visible in the prompt
        instead of being dropped silently.
        """
        if not os.path.exists(file_path):
            print(f"ERROR: File does not exist: {file_path}")
            return {"type": "text", "text": f"Error: Could not find file {file_path}"}

        # Compare the real extension rather than the raw string suffix, so a
        # name that merely ends in e.g. "mp4" is not mistaken for a video.
        extension = os.path.splitext(file_path)[1].lower().lstrip(".")
        if extension in self.video_formats:
            print(f"✅ DETECTED VIDEO: Adding video with fps={fps}, max_frames={max_frames}")
            return {"type": "video", "video": {"video_path": file_path, "fps": fps, "max_frames": max_frames}}
        if extension in self.image_formats:
            print(f"✅ DETECTED IMAGE: Adding image: {file_path}")
            return {"type": "image", "image": {"image_path": file_path}}

        print(f"❌ UNKNOWN FILE TYPE: {file_path}")
        return {"type": "text", "text": f"Unsupported file type: {file_path}"}

    def _content_item(self, content, fps, max_frames):
        """Convert one history ``content`` value into a content item.

        Handles plain strings, non-empty tuples whose first element is a file
        path, and dicts carrying a ``"path"`` key. Returns ``None`` for any
        other shape, which the caller skips.
        """
        if isinstance(content, str):
            print(f"DEBUG: Adding text: {content}")
            return {"type": "text", "text": content}
        if isinstance(content, tuple) and len(content) > 0:
            file_path = content[0]
            print(f"Processing file from tuple: {file_path}")
            return self._file_item(file_path, fps, max_frames)
        if isinstance(content, dict) and "path" in content:
            file_path = content["path"]
            print(f"Processing file from dict: {file_path}")
            return self._file_item(file_path, fps, max_frames)
        return None

    def _build_conversation(self, messages, fps, max_frames):
        """Translate chat history into the list of turns given to the processor.

        Consecutive user messages are merged into a single user turn whose
        content is a list of text/image/video items. Assistant messages are
        copied through. Messages with any other role are skipped.
        """
        conversation = []
        i = 0
        while i < len(messages):
            role = messages[i]["role"]
            if role == "user":
                user_content = []
                # Merge the whole run of adjacent user messages into one turn.
                while i < len(messages) and messages[i]["role"] == "user":
                    msg = messages[i]
                    print(f"DEBUG: Processing user message {i}: {msg}")
                    print(f"DEBUG: Content type: {type(msg['content'])}")
                    print(f"DEBUG: Content value: {msg['content']}")
                    item = self._content_item(msg["content"], fps, max_frames)
                    if item is not None:
                        user_content.append(item)
                    i += 1

                if user_content:
                    conversation.append({"role": "user", "content": user_content})
                    print(f"📝 Added user turn with {len(user_content)} items: {[item.get('type', 'unknown') for item in user_content]}")
            elif role == "assistant":
                content = messages[i]["content"]
                conversation.append({"role": "assistant", "content": content})
                # str() keeps the preview safe when content is not a string.
                print(f"🤖 Added assistant turn: {str(content)[:50]}...")
                i += 1
            else:
                # Always advance: leaving the index untouched here would make
                # the loop spin forever on an unrecognised role.
                print(f"Skipping message {i} with unsupported role: {role}")
                i += 1
        return conversation

    @staticmethod
    def _clean_response(response):
        """Strip a leading role marker and colon from the decoded text.

        NOTE(review): splitting on the first matching marker also truncates
        answers that legitimately contain the word "assistant" - confirm
        whether the decoded output ever includes the prompt before relying
        on this heuristic.
        """
        for indicator in ("assistant", "Assistant", "ASSISTANT"):
            if indicator in response:
                response = response.split(indicator)[-1].strip()
                break
        return response.lstrip(":").lstrip()

    @torch.inference_mode()
    def predict(self, messages, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=4096, fps=10, max_frames=256):
        """Generate an assistant reply for a chat history and append it.

        Args:
            messages: List of ``{"role", "content"}`` dicts. ``content`` may
                be a string, a tuple whose first element is a file path, or
                a dict with a ``"path"`` key.
            do_sample: Whether to sample instead of decoding greedily.
            temperature: Sampling temperature. A non-positive value falls
                back to greedy decoding, because ``generate`` rejects it
                when sampling.
            top_p: Nucleus sampling threshold, used only when sampling.
            max_new_tokens: Passed to ``generate`` as ``max_new_tokens``.
            fps: Stored in each video content item.
            max_frames: Stored in each video content item.

        Returns:
            The same ``messages`` list. It is returned unchanged when it is
            empty or yields no usable turns; otherwise an assistant message
            holding either the reply or an ``"Error: ..."`` description has
            been appended in place.
        """
        if not messages:
            return messages

        conversation = self._build_conversation(messages, fps, max_frames)
        if not conversation:
            return messages

        try:
            print(f"Conversation structure: {len(conversation)} turns")
            for turn_index, turn in enumerate(conversation):
                role = turn["role"]
                if role == "user":
                    content_types = [item.get("type", "unknown") for item in turn["content"] if isinstance(item, dict)]
                    print(f"Turn {turn_index}: {role} - {content_types}")
                else:
                    print(f"Turn {turn_index}: {role} - text response")

            inputs = self.processor(
                conversation=conversation,
                add_system_prompt=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )
            inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            if "pixel_values" in inputs:
                # Match the dtype the model was loaded with.
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

            # The UI slider allows temperature 0.0, which is invalid for
            # sampling; treat it as a request for greedy decoding.
            sampling = bool(do_sample) and (temperature is None or temperature > 0)
            generation_kwargs = {"do_sample": sampling, "max_new_tokens": max_new_tokens}
            if sampling:
                generation_kwargs["temperature"] = temperature
                generation_kwargs["top_p"] = top_p

            output_ids = self.model.generate(**inputs, **generation_kwargs)
            response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            response = self._clean_response(response)

            messages.append({"role": "assistant", "content": response})
            return messages

        except Exception as e:
            # UI boundary: report the failure in the chat instead of raising.
            error_msg = f"Error: {str(e)}"
            print(f"Error in prediction: {error_msg}")
            messages.append({"role": "assistant", "content": error_msg})
            return messages

    def create_interface(self):
        """Build and return the Gradio Blocks UI (not yet launched).

        Uploading a video or image appends a file message to the chat
        history. Submitting text, via Enter or the Send button, appends the
        text and then calls ``predict`` with the current settings.
        """
        with gr.Blocks(title="VideoLLaMA3 AI Curator") as interface:
            gr.Markdown("# 🎬 VideoLLaMA3 AI Curator\nUpload images or videos and ask questions!")

            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(type="messages", height=600)

                with gr.Column(scale=1):
                    with gr.Tab("Input"):
                        video_input = gr.Video(sources=["upload"], label="Upload Video")
                        image_input = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
                        text_input = gr.Textbox(label="Your Message", placeholder="Ask about the image/video or chat...")
                        submit_btn = gr.Button("Send", variant="primary")

                    with gr.Tab("Settings"):
                        do_sample = gr.Checkbox(value=True, label="Do Sample")
                        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
                        max_tokens = gr.Slider(256, 8192, value=4096, step=64, label="Max Tokens")
                        fps = gr.Slider(0.5, 15.0, value=10.0, label="Video FPS")
                        max_frames = gr.Slider(32, 512, value=256, step=8, label="Max Frames")

            def add_file(history, file):
                """Append an uploaded file to the history and clear the widget."""
                history = history or []
                if file:
                    print(f"DEBUG: Gradio file input: {file}")
                    print(f"DEBUG: File type: {type(file)}")
                    history.append({"role": "user", "content": {"path": file}})
                return history, None

            def add_text(history, text):
                """Append non-blank text to the history and clear the textbox."""
                history = history or []
                if text and text.strip():
                    history.append({"role": "user", "content": text})
                return history, ""

            def respond(history, do_sample, temperature, top_p, max_tokens, fps, max_frames):
                """Run the model only when the last message is from the user."""
                if history and history[-1]["role"] == "user":
                    return self.predict(history, do_sample, temperature, top_p, max_tokens, fps, max_frames)
                return history

            video_input.change(add_file, [chatbot, video_input], [chatbot, video_input])
            image_input.change(add_file, [chatbot, image_input], [chatbot, image_input])

            # Enter in the textbox and the Send button share one pipeline.
            respond_inputs = [chatbot, do_sample, temperature, top_p, max_tokens, fps, max_frames]
            for trigger in (text_input.submit, submit_btn.click):
                trigger(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                    respond, respond_inputs, [chatbot]
                )

        return interface
|
|
if __name__ == "__main__":
    # Command-line options: checkpoint to load, port to serve on, and the
    # flag forwarded to `launch` as `share`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoLLaMA3-7B")
    cli.add_argument("--port", type=int, default=7860)
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()

    # Load the model, build the UI around it, and serve on all interfaces.
    demo = SimpleVideoLLaMA3Interface(options.model_path).create_interface()
    demo.launch(server_name="0.0.0.0", server_port=options.port, share=options.share)