Spaces:
Sleeping
Sleeping
import gc
import inspect
import os
import sys
from pathlib import Path

import gradio as gr
import torch

# Allow importing your models package
sys.path.insert(0, str(Path(__file__).parent))
from models import load_model
from models.base import BaseVideoModel
# ----------------------
# CONFIG
# ----------------------
DEVICE_MAP = "cuda:0"  # target device for model placement
VIDEO_DIR = str(Path(__file__).parent / "videos")  # directory holding the .mp4 clips
FPS = 1.0  # default frame-sampling rate (frames per second)
MAX_NEW_TOKENS = 512  # default cap on generated tokens
TEMPERATURE = 0.01  # default sampling temperature (near-greedy)

# ----------------------
# Model loading with quantization support
# ----------------------
# Module-level model state; mutated by load_model_with_quantization().
model: BaseVideoModel = None  # currently loaded model (None until first load completes)
current_model_name = "Qwen3-VL-4B-Instruct"  # UI name of the model currently loaded
current_quantization = "16-bit"  # quantization mode of the currently loaded model
def load_model_with_quantization(
    model_name: str,
    quantization: str
):
    """Load (or reload) the global video model with the requested quantization.

    Releases the previously loaded model first so its GPU memory is freed
    before the new weights are allocated.

    Args:
        model_name: UI name of the model (e.g. "Qwen3-VL-4B-Instruct").
        quantization: One of "16-bit", "8-bit", "4-bit".

    Returns:
        A human-readable status string for the Gradio status textbox.
    """
    global model, current_model_name, current_quantization

    # Free GPU memory if model already exists.
    if model is not None:
        print("Unloading existing model and freeing GPU memory...")
        # Rebind to None rather than `del model`: if load_model() below
        # raises, the global stays defined and later `model is not None`
        # checks keep working instead of raising NameError.
        model = None
        gc.collect()
        torch.cuda.empty_cache()
        print("GPU memory cleared.")

    # Map the UI quantization choice onto loader flags;
    # "16-bit" (normal) leaves both flags False.
    load_8bit = quantization == "8-bit"
    load_4bit = quantization == "4-bit"

    print(f"Loading {model_name} with {quantization} quantization...")

    # Resolve the UI name to a checkpoint path:
    # - LLaVA-Video-7B uses the HF-converted checkpoint for transformers v5
    #   compatibility (instead of the default lmms-lab version).
    # - Qwen models are loaded from the Qwen organization.
    model_path = model_name
    if model_name == "LLaVA-Video-7B-Qwen2":
        model_path = "Isotr0py/LLaVA-Video-7B-Qwen2-hf"
    elif model_name.startswith("Qwen"):
        model_path = f"Qwen/{model_name}"

    model = load_model(
        model_path,
        device_map=DEVICE_MAP,
        load_8bit=load_8bit,
        load_4bit=load_4bit,
    )
    current_model_name = model_name
    current_quantization = quantization
    print(f"{model_name} loaded with {quantization} quantization.")
    return f"β {model_name} loaded successfully with {quantization} quantization"
# Load model initially with 16-bit (normal)
load_model_with_quantization(current_model_name, current_quantization)

# ----------------------
# Collect video IDs
# ----------------------
# Basenames (extension stripped) of every .mp4 in VIDEO_DIR, sorted for the dropdown.
VIDEO_IDS = sorted([
    os.path.splitext(f)[0]
    for f in os.listdir(VIDEO_DIR)
    if f.endswith(".mp4")
])
# ----------------------
# Helpers
# ----------------------
def get_video_path(video_id: str):
    """Resolve a video id to the full path of `<video_id>.mp4` under VIDEO_DIR.

    Returns None both for a falsy id and for a file that does not exist.
    """
    if not video_id:
        return None
    candidate = os.path.join(VIDEO_DIR, f"{video_id}.mp4")
    if os.path.exists(candidate):
        return candidate
    return None
# ----------------------
# Inference function
# ----------------------
def video_qa(
    video_id: str,
    prompt: str,
    video_mode: str,
    fps: float,
    num_frames: int,
    max_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Run one round of video question answering with the loaded model.

    Args:
        video_id: Basename (no extension) of a video in VIDEO_DIR.
        prompt: User question about the video.
        video_mode: "video" (FPS-based sampling) or "frames" (fixed count);
            only forwarded when the model's chat() accepts it.
        fps: Frame-sampling rate for video mode.
        num_frames: Fixed frame count for frames mode.
        max_tokens / temperature / top_k / top_p: Generation parameters.

    Returns:
        The model's answer, or a "β"-prefixed error message string.
    """
    if not video_id:
        return "β Please select a video ID."
    # Guard against None as well as whitespace-only prompts.
    if not prompt or not prompt.strip():
        return "β Please enter a prompt."
    video_path = get_video_path(video_id)
    if video_path is None:
        return f"β Video not found: {video_id}.mp4"
    try:
        kwargs = {
            "prompt": prompt,
            "video_path": video_path,
            "fps": fps,
            "num_frames": num_frames,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        }
        # Forward video_mode only when chat() can accept it (Qwen models).
        # Previously this was done by calling with video_mode and retrying
        # on TypeError, which silently re-ran the whole inference whenever
        # chat() raised a TypeError internally; introspecting the signature
        # avoids masking such errors and the duplicated GPU work.
        try:
            params = inspect.signature(model.chat).parameters
            accepts_video_mode = "video_mode" in params or any(
                p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
        except (TypeError, ValueError):
            # chat() is not introspectable (e.g. a builtin); be conservative.
            accepts_video_mode = False
        if accepts_video_mode:
            response = model.chat(**kwargs, video_mode=video_mode)
        else:
            response = model.chat(**kwargs)
        return response
    except Exception as e:
        return f"β Error during inference: {str(e)}"
# ----------------------
# Gradio UI
# ----------------------
# Declarative Blocks layout: component nesting follows the `with` context
# managers; event wiring is registered at the end, inside the Blocks scope.
with gr.Blocks(title="Video Inference Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## π₯ Video Inference")
    with gr.Row():
        # LEFT COLUMN: video picker, model selection, and sampling/generation knobs.
        with gr.Column(scale=1):
            gr.Markdown("### π Video Selection")
            video_id = gr.Dropdown(
                choices=VIDEO_IDS,
                label="Video ID",
                filterable=True,
                interactive=True,
                value=VIDEO_IDS[0] if VIDEO_IDS else None
            )
            video_player = gr.Video(
                label="Selected Video",
                autoplay=False,
                height=300
            )
            gr.Markdown("### π€ Model Name")
            model_name_radio = gr.Radio(
                choices=[
                    "Qwen3-VL-4B-Instruct",
                    "Qwen3-VL-8B-Instruct",
                    "Qwen3-VL-2B-Thinking",
                    "Qwen3-VL-4B-Thinking",
                    "LLaVA-Video-7B-Qwen2"
                ],
                value="Qwen3-VL-4B-Instruct",
                label="π€ Model Name",
                info="Select the model to use for inference"
            )
            gr.Markdown("### βοΈ Model Parameters")
            quantization_radio = gr.Radio(
                choices=["16-bit", "8-bit", "4-bit"],
                value="16-bit",
                label="π§ Model Quantization",
                info="16-bit: Default precision, 8-bit/4-bit: Reduced memory usage"
            )
            # Swapping model/quantization requires a full reload (see
            # load_model_with_quantization); this button triggers it.
            reload_button = gr.Button("π Reload Model", variant="secondary")
            reload_status = gr.Textbox(
                label="Model Status",
                value=f"{current_model_name} loaded with {current_quantization} quantization",
                interactive=False,
                lines=1
            )
            fps_slider = gr.Slider(
                minimum=0.5,
                maximum=10.0,
                step=0.5,
                value=FPS,
                label="ποΈ Frames Per Second (FPS)",
                info="Sample rate for video frames"
            )
            video_mode_radio = gr.Radio(
                choices=["video", "frames"],
                value="video",
                label="πΉ Video Mode",
                info="'video' for FPS-based, 'frames' for fixed count"
            )
            num_frames_slider = gr.Slider(
                minimum=1,
                maximum=30,
                step=1,
                value=8,
                label="πΌοΈ Number of Frames",
                info="Fixed frame count (used when video_mode='frames')"
            )
            # Generation parameters, collapsed by default.
            with gr.Accordion("π§ Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    step=128,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    info="Maximum length of generated response"
                )
                temperature_slider = gr.Slider(
                    minimum=0.01,
                    maximum=2.0,
                    step=0.01,
                    value=TEMPERATURE,
                    label="π‘οΈ Temperature",
                    info="Higher = more creative, lower = more focused"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="π Top-K",
                    info="Sample from top K tokens"
                )
                top_p_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.95,
                    label="π― Top-P (Nucleus)",
                    info="Cumulative probability threshold"
                )
        # RIGHT COLUMN: prompt, answer, and run controls.
        with gr.Column(scale=2):
            gr.Markdown("### π¬ Question & Answer")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask a question about the selected video...",
                lines=4,
                value="Describe what is happening in this video."
            )
            answer = gr.Textbox(
                label="Model Answer",
                lines=20,
                interactive=False
            )
            run = gr.Button("π Run Inference", variant="primary", size="lg")
            gr.Markdown("""
            ---
            **βΉοΈ Tips:**
            - **Quantization:** 16-bit (full precision), 8-bit (2x memory savings), 4-bit (4x memory savings with slight quality loss)
            - Adjust FPS to control video sampling rate (higher = more frames, slower inference)
            - Use video_mode='frames' for fixed frame count (useful for very long videos)
            - Temperature: Lower (0.01-0.5) for factual, higher (0.7-1.5) for creative responses
            - Top-K and Top-P control output diversity
            """)

    # Update video player when dropdown changes
    video_id.change(
        fn=get_video_path,
        inputs=video_id,
        outputs=video_player
    )
    # Reload model with new quantization
    reload_button.click(
        fn=load_model_with_quantization,
        inputs=[
            model_name_radio,
            quantization_radio,
        ],
        outputs=reload_status
    )
    # Run inference
    run.click(
        fn=video_qa,
        inputs=[
            video_id,
            prompt,
            video_mode_radio,
            fps_slider,
            num_frames_slider,
            max_tokens_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
        ],
        outputs=answer
    )

# Serve on all interfaces; share=True also creates a public Gradio tunnel.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True
)
| # #--------------- | |
| # #--------------- | |
| # #--------------- | |
| # # Feb 5, 2026 | |
| # #--------------- | |
| # import os | |
| # import sys | |
| # import json | |
| # from pathlib import Path | |
| # import gradio as gr | |
| # # Allow importing your models package | |
| # sys.path.insert(0, str(Path(__file__).parent)) | |
| # from models import load_model | |
| # from models.base import BaseVideoModel | |
| # # ---------------------- | |
| # # CONFIG | |
| # # ---------------------- | |
| # QWEN_MODEL_PATH = "Qwen/Qwen3-VL-4B-Instruct" | |
| # LLAVA_MODEL_PATH = "lmms-lab/LLaVA-Video-7B-Qwen2" | |
| # DEVICE_MAP_QWEN = "cuda:0" | |
| # DEVICE_MAP_LLAVA = "cuda:0" # Both models on same GPU | |
| # VIDEO_DIR = "/home/raman/Gradio_Qwen3vl4bInstruct/videos" | |
| # LABELS_JSON = "/home/raman/Gradio_Qwen3vl4bInstruct/SSv2_prepost_sampled.json" | |
| # DEFAULT_FPS = 1.0 | |
| # MAX_NEW_TOKENS = 512 | |
| # TEMPERATURE = 0.01 | |
| # # ---------------------- | |
| # # Load video labels | |
| # # ---------------------- | |
| # print("Loading video labels...") | |
| # video_labels = {} | |
| # try: | |
| # with open(LABELS_JSON, 'r') as f: | |
| # labels_data = json.load(f) | |
| # for item in labels_data: | |
| # video_labels[item['id']] = { | |
| # 'label': item['label'], | |
| # 'template': item.get('template', ''), | |
| # 'action_group': item.get('action_group', '') | |
| # } | |
| # print(f"Loaded {len(video_labels)} video labels.") | |
| # except Exception as e: | |
| # print(f"Warning: Could not load labels JSON: {e}") | |
| # # ---------------------- | |
| # # Load models | |
| # # ---------------------- | |
| # print("Loading Qwen3-VL-4B-Instruct...") | |
| # qwen_model: BaseVideoModel = load_model( | |
| # QWEN_MODEL_PATH, | |
| # device_map=DEVICE_MAP_QWEN, | |
| # ) | |
| # print("Qwen model loaded.") | |
| # print("Loading LLaVA-Video-7B...") | |
| # llava_model: BaseVideoModel = load_model( | |
| # LLAVA_MODEL_PATH, | |
| # device_map=DEVICE_MAP_LLAVA, | |
| # ) | |
| # print("LLaVA model loaded.") | |
| # # ---------------------- | |
| # # Collect video IDs | |
| # # ---------------------- | |
| # VIDEO_IDS = sorted([ | |
| # os.path.splitext(f)[0] | |
| # for f in os.listdir(VIDEO_DIR) | |
| # if f.endswith(".mp4") | |
| # ]) | |
| # print(f"Found {len(VIDEO_IDS)} videos.") | |
| # # ---------------------- | |
| # # Helpers | |
| # # ---------------------- | |
| # def get_video_path(video_id: str): | |
| # if not video_id: | |
| # return None | |
| # path = os.path.join(VIDEO_DIR, video_id + ".mp4") | |
| # return path if os.path.exists(path) else None | |
| # def get_video_label(video_id: str): | |
| # if not video_id: | |
| # return "" | |
| # info = video_labels.get(video_id, {}) | |
| # label = info.get('label', 'No label available') | |
| # action_group = info.get('action_group', '') | |
| # if action_group: | |
| # return f"**Label:** {label}\n\n**Action Group:** {action_group}" | |
| # return f"**Label:** {label}" | |
| # def update_video_info(video_id: str): | |
| # """Returns video path and label when video is selected""" | |
| # video_path = get_video_path(video_id) | |
| # label = get_video_label(video_id) | |
| # return video_path, label | |
| # # ---------------------- | |
| # # Inference functions | |
| # # ---------------------- | |
| # def qwen_inference(video_id: str, prompt: str, fps: float) -> str: | |
| # if not video_id: | |
| # return "β Please select a video ID." | |
| # if not prompt.strip(): | |
| # return "β Please enter a prompt." | |
| # video_path = get_video_path(video_id) | |
| # if video_path is None: | |
| # return f"β Video not found: {video_id}.mp4" | |
| # try: | |
| # response = qwen_model.chat( | |
| # prompt=prompt, | |
| # video_path=video_path, | |
| # fps=fps, | |
| # max_new_tokens=MAX_NEW_TOKENS, | |
| # temperature=TEMPERATURE, | |
| # ) | |
| # return response | |
| # except Exception as e: | |
| # return f"β Error during Qwen inference: {str(e)}" | |
| # def llava_inference(video_id: str, prompt: str, fps: float) -> str: | |
| # if not video_id: | |
| # return "β Please select a video ID." | |
| # if not prompt.strip(): | |
| # return "β Please enter a prompt." | |
| # video_path = get_video_path(video_id) | |
| # if video_path is None: | |
| # return f"β Video not found: {video_id}.mp4" | |
| # try: | |
| # response = llava_model.chat( | |
| # prompt=prompt, | |
| # video_path=video_path, | |
| # fps=fps, | |
| # max_new_tokens=MAX_NEW_TOKENS, | |
| # temperature=TEMPERATURE, | |
| # ) | |
| # return response | |
| # except Exception as e: | |
| # return f"β Error during LLaVA inference: {str(e)}" | |
| # # ---------------------- | |
| # # Gradio UI | |
| # # ---------------------- | |
| # with gr.Blocks(title="Video QA β Qwen3-VL & LLaVA-Video", theme=gr.themes.Soft()) as demo: | |
| # gr.Markdown("# π₯ Video Question Answering Demo") | |
| # gr.Markdown("Compare **Qwen3-VL-4B-Instruct** and **LLaVA-Video-7B-Qwen2** on the same videos") | |
| # # TOP SECTION: Video Selection and Display | |
| # with gr.Row(): | |
| # with gr.Column(scale=1): | |
| # video_id = gr.Dropdown( | |
| # choices=VIDEO_IDS, | |
| # label="π Select Video ID", | |
| # filterable=True, | |
| # interactive=True, | |
| # value=VIDEO_IDS[0] if VIDEO_IDS else None | |
| # ) | |
| # video_label = gr.Markdown( | |
| # value=get_video_label(VIDEO_IDS[0]) if VIDEO_IDS else "", | |
| # label="Video Information" | |
| # ) | |
| # fps_slider = gr.Slider( | |
| # minimum=0.5, | |
| # maximum=5.0, | |
| # step=0.5, | |
| # value=DEFAULT_FPS, | |
| # label="ποΈ Frames Per Second (FPS)", | |
| # info="Higher FPS = more frames analyzed (slower but more detailed)" | |
| # ) | |
| # with gr.Column(scale=2): | |
| # video_player = gr.Video( | |
| # label="Selected Video", | |
| # autoplay=False, | |
| # height=360, | |
| # value=get_video_path(VIDEO_IDS[0]) if VIDEO_IDS else None | |
| # ) | |
| # gr.Markdown("---") | |
| # # BOTTOM SECTION: Two Models Side by Side | |
| # with gr.Row(): | |
| # # QWEN COLUMN | |
| # with gr.Column(scale=1): | |
| # gr.Markdown("### π€ Qwen3-VL-4B-Instruct") | |
| # qwen_prompt = gr.Textbox( | |
| # label="Prompt", | |
| # placeholder="Ask a question about the video...", | |
| # lines=4, | |
| # value="Describe what is happening in this video." | |
| # ) | |
| # qwen_answer = gr.Textbox( | |
| # label="Qwen Answer", | |
| # lines=10, | |
| # interactive=False | |
| # ) | |
| # qwen_run = gr.Button("π Run Qwen Inference", variant="primary") | |
| # # LLAVA COLUMN | |
| # with gr.Column(scale=1): | |
| # gr.Markdown("### π¬ LLaVA-Video-7B-Qwen2") | |
| # llava_prompt = gr.Textbox( | |
| # label="Prompt", | |
| # placeholder="Ask a question about the video...", | |
| # lines=4, | |
| # value="Describe what is happening in this video." | |
| # ) | |
| # llava_answer = gr.Textbox( | |
| # label="LLaVA Answer", | |
| # lines=10, | |
| # interactive=False | |
| # ) | |
| # llava_run = gr.Button("π Run LLaVA Inference", variant="primary") | |
| # # Model info footer | |
| # gr.Markdown(""" | |
| # --- | |
| # **Model Information:** | |
| # - **Qwen3-VL-4B-Instruct**: 4B parameter vision-language model | |
| # - **LLaVA-Video-7B-Qwen2**: 7B parameter video understanding model | |
| # **Settings:** Max Tokens={}, Temperature={} | |
| # """.format(MAX_NEW_TOKENS, TEMPERATURE)) | |
| # # ---------------------- | |
| # # Event Handlers | |
| # # ---------------------- | |
| # # Update video player and label when dropdown changes | |
| # video_id.change( | |
| # fn=update_video_info, | |
| # inputs=video_id, | |
| # outputs=[video_player, video_label] | |
| # ) | |
| # # Run Qwen inference | |
| # qwen_run.click( | |
| # fn=qwen_inference, | |
| # inputs=[video_id, qwen_prompt, fps_slider], | |
| # outputs=qwen_answer | |
| # ) | |
| # # Run LLaVA inference | |
| # llava_run.click( | |
| # fn=llava_inference, | |
| # inputs=[video_id, llava_prompt, fps_slider], | |
| # outputs=llava_answer | |
| # ) | |
| # # Launch | |
| # demo.launch( | |
| # server_name="0.0.0.0", | |
| # server_port=7860, | |
| # share=True | |
| # ) |