# Author: jena-shreyas
# Last commit: "Update app.py" (079a5b2, verified)
import os
import sys
from pathlib import Path
import gc
import torch
import gradio as gr
# Allow importing your models package
sys.path.insert(0, str(Path(__file__).parent))
from models import load_model
from models.base import BaseVideoModel
# ----------------------
# CONFIG
# ----------------------
# Device placement handed to load_model(device_map=...).
DEVICE_MAP: str = "cuda:0"
# Folder next to this script that is scanned for *.mp4 files.
VIDEO_DIR: str = str(Path(__file__).parent / "videos")
# Initial values for the FPS, "Max New Tokens" and temperature sliders.
FPS: float = 1.0
MAX_NEW_TOKENS: int = 512
TEMPERATURE: float = 0.01
# ----------------------
# Model loading with quantization support
# ----------------------
# Currently loaded model plus the name/quantization it was loaded with. They
# start out as the start-up defaults and are rebound by
# load_model_with_quantization().
model: BaseVideoModel = None  # None until the first load completes
current_model_name = "Qwen3-VL-4B-Instruct"
current_quantization = "16-bit"
def load_model_with_quantization(
    model_name: str,
    quantization: str
):
    """Load or reload the model with the specified quantization.

    Releases the currently loaded model (if any) first, then loads
    ``model_name`` via ``load_model`` onto ``DEVICE_MAP`` and records the new
    state in the module-level ``model``, ``current_model_name`` and
    ``current_quantization``.

    Args:
        model_name: Model name as shown in the UI. "LLaVA-Video-7B-Qwen2" is
            mapped to "Isotr0py/LLaVA-Video-7B-Qwen2-hf"; names starting with
            "Qwen" are prefixed with "Qwen/"; anything else is passed to
            ``load_model`` unchanged.
        quantization: "8-bit" or "4-bit" enable the matching ``load_model``
            flag; any other value (normally "16-bit") loads without
            quantization.

    Returns:
        A status message for the "Model Status" textbox.

    Raises:
        gr.Error: if ``load_model`` fails. The previous model has already been
            released at that point, so ``model`` is left as ``None`` and the
            ``current_*`` names keep their previous values.
    """
    global model, current_model_name, current_quantization

    # Free GPU memory if a model is already loaded. Rebind the global to None
    # rather than ``del``-ing it: deleting the name would leave ``model``
    # undefined if the load below fails, turning every later access
    # (including this very check on the next reload) into a NameError.
    if model is not None:
        print("Unloading existing model and freeing GPU memory...")
        model = None
        gc.collect()
        torch.cuda.empty_cache()
        print("GPU memory cleared.")

    # 16-bit (normal) precision is the default: both flags stay False.
    load_8bit = quantization == "8-bit"
    load_4bit = quantization == "4-bit"

    print(f"Loading {model_name} with {quantization} quantization...")
    # Load the HF version of LLaVA-Video-7B instead of the default version, for
    # transformers v5 compatibility. Qwen models are loaded from the "Qwen/"
    # namespace.
    if model_name == "LLaVA-Video-7B-Qwen2":
        model_path = "Isotr0py/LLaVA-Video-7B-Qwen2-hf"
    elif model_name.startswith("Qwen"):
        model_path = f"Qwen/{model_name}"
    else:
        model_path = model_name

    try:
        model = load_model(
            model_path,
            device_map=DEVICE_MAP,
            load_8bit=load_8bit,
            load_4bit=load_4bit,
        )
    except Exception as e:
        message = f"Failed to load {model_name} with {quantization} quantization: {e}"
        print(message)
        # gr.Error is displayed to the user when raised from a Gradio event
        # handler; the original exception stays attached as __cause__.
        raise gr.Error(message) from e

    current_model_name = model_name
    current_quantization = quantization
    print(f"{model_name} loaded with {quantization} quantization.")
    return f"βœ… {model_name} loaded successfully with {quantization} quantization"
# Load the start-up default model (16-bit, i.e. unquantized) once at import
# time so the UI can run inference immediately after launch.
load_model_with_quantization(
    model_name=current_model_name,
    quantization=current_quantization,
)
# ----------------------
# Collect video IDs
# ----------------------
# Each id is the name of a ``.mp4`` file in VIDEO_DIR without its extension;
# get_video_path() maps an id back to the file. A missing folder yields an
# empty list (the UI handles having no videos) instead of crashing start-up,
# and directories that merely end in ".mp4" are skipped.
if os.path.isdir(VIDEO_DIR):
    VIDEO_IDS = sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(VIDEO_DIR)
        if name.endswith(".mp4") and os.path.isfile(os.path.join(VIDEO_DIR, name))
    )
else:
    print(f"Warning: video directory not found: {VIDEO_DIR}")
    VIDEO_IDS = []
# ----------------------
# Helpers
# ----------------------
def get_video_path(video_id: str):
    """Return the path of ``<VIDEO_DIR>/<video_id>.mp4``, or None if unavailable.

    Returns None when ``video_id`` is empty/None, when it is not a bare file
    name (absolute, or containing a directory component), or when no regular
    file exists at the resulting path.
    """
    if not video_id:
        return None
    # Valid ids are bare file stems (see how VIDEO_IDS is built from
    # os.listdir). Anything with a directory part could escape VIDEO_DIR:
    # os.path.join() discards VIDEO_DIR for an absolute id, and "../" walks
    # out of it. The id can come from an API client, not just the dropdown.
    if os.path.basename(video_id) != video_id:
        return None
    path = os.path.join(VIDEO_DIR, video_id + ".mp4")
    return path if os.path.isfile(path) else None
# ----------------------
# Inference function
# ----------------------
def video_qa(
    video_id: str,
    prompt: str,
    video_mode: str,
    fps: float,
    num_frames: int,
    max_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Ask the currently loaded model ``prompt`` about the selected video.

    Args:
        video_id: Stem of an ``.mp4`` file in ``VIDEO_DIR`` (dropdown value).
        prompt: Question or instruction for the model.
        video_mode: "video" or "frames"; passed to ``model.chat`` and dropped
            only if ``chat`` rejects that keyword.
        fps: Frame sampling rate forwarded to ``model.chat``.
        num_frames: Frame count forwarded to ``model.chat``.
        max_tokens: Forwarded as ``max_new_tokens``.
        temperature, top_k, top_p: Sampling parameters forwarded to
            ``model.chat``.

    Returns:
        The model's response, or a user-facing message when input validation
        fails. Exceptions raised while preparing or running inference are
        caught and returned as an error message.
    """
    if not video_id:
        return "❌ Please select a video ID."
    # Guard against None as well as blank text.
    if not prompt or not prompt.strip():
        return "❌ Please enter a prompt."
    video_path = get_video_path(video_id)
    if video_path is None:
        return f"❌ Video not found: {video_id}.mp4"
    try:
        # `model` is None when no model is currently loaded.
        if model is None:
            return "❌ No model is loaded. Please reload a model first."
        # Slider values may arrive as floats; generation expects integers for
        # token / frame counts and top-k.
        generation_config = {
            "max_new_tokens": int(max_tokens),
            "temperature": temperature,
            "top_k": int(top_k),
            "top_p": top_p,
        }
        kwargs = {
            "prompt": prompt,
            "video_path": video_path,
            "fps": fps,
            "num_frames": int(num_frames),
            **generation_config,
        }
        try:
            # video_mode is only supported by some models (e.g. the Qwen ones).
            response = model.chat(**kwargs, video_mode=video_mode)
        except TypeError as e:
            # Retry without video_mode only when the call was rejected because
            # of that keyword. Any other TypeError is a genuine failure and
            # must not trigger a second, expensive inference run.
            if "video_mode" not in str(e):
                raise
            response = model.chat(**kwargs)
        return response
    except Exception as e:
        import traceback

        # Keep the full traceback in the server log; the UI gets the summary.
        traceback.print_exc()
        return f"❌ Error during inference: {str(e)}"
# ----------------------
# Gradio UI
# ----------------------
# Page layout: the left column holds video selection, model selection and
# sampling controls; the right column holds the prompt, answer and run button.
# The event handlers at the bottom wire the controls to get_video_path,
# load_model_with_quantization and video_qa.
with gr.Blocks(title="Video Inference Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## πŸŽ₯ Video Inference")
    with gr.Row():
        # LEFT COLUMN
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“ Video Selection")
            video_id = gr.Dropdown(
                choices=VIDEO_IDS,
                label="Video ID",
                filterable=True,
                interactive=True,
                value=VIDEO_IDS[0] if VIDEO_IDS else None
            )
            video_player = gr.Video(
                label="Selected Video",
                autoplay=False,
                height=300,
                # Show the pre-selected video straight away; without this the
                # player stays empty until the dropdown selection changes.
                value=get_video_path(VIDEO_IDS[0]) if VIDEO_IDS else None,
            )
            gr.Markdown("### πŸ€– Model Name")
            model_name_radio = gr.Radio(
                choices=[
                    "Qwen3-VL-4B-Instruct",
                    "Qwen3-VL-8B-Instruct",
                    "Qwen3-VL-2B-Thinking",
                    "Qwen3-VL-4B-Thinking",
                    "LLaVA-Video-7B-Qwen2"
                ],
                # Reflect the model that was actually loaded at start-up.
                value=current_model_name,
                label="πŸ€– Model Name",
                info="Select the model to use for inference"
            )
            gr.Markdown("### βš™οΈ Model Parameters")
            quantization_radio = gr.Radio(
                choices=["16-bit", "8-bit", "4-bit"],
                # Reflect the quantization the start-up model was loaded with.
                value=current_quantization,
                label="πŸ”§ Model Quantization",
                info="16-bit: Default precision, 8-bit/4-bit: Reduced memory usage"
            )
            reload_button = gr.Button("πŸ”„ Reload Model", variant="secondary")
            reload_status = gr.Textbox(
                label="Model Status",
                value=f"{current_model_name} loaded with {current_quantization} quantization",
                interactive=False,
                lines=1
            )
            fps_slider = gr.Slider(
                minimum=0.5,
                maximum=10.0,
                step=0.5,
                value=FPS,
                label="🎞️ Frames Per Second (FPS)",
                info="Sample rate for video frames"
            )
            video_mode_radio = gr.Radio(
                choices=["video", "frames"],
                value="video",
                label="πŸ“Ή Video Mode",
                info="'video' for FPS-based, 'frames' for fixed count"
            )
            num_frames_slider = gr.Slider(
                minimum=1,
                maximum=30,
                step=1,
                value=8,
                label="πŸ–ΌοΈ Number of Frames",
                info="Fixed frame count (used when video_mode='frames')"
            )
            with gr.Accordion("πŸ”§ Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=128,
                    maximum=2048,
                    step=128,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    info="Maximum length of generated response"
                )
                temperature_slider = gr.Slider(
                    minimum=0.01,
                    maximum=2.0,
                    step=0.01,
                    value=TEMPERATURE,
                    label="🌑️ Temperature",
                    info="Higher = more creative, lower = more focused"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="πŸ” Top-K",
                    info="Sample from top K tokens"
                )
                top_p_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.95,
                    label="🎯 Top-P (Nucleus)",
                    info="Cumulative probability threshold"
                )
        # RIGHT COLUMN
        with gr.Column(scale=2):
            gr.Markdown("### πŸ’¬ Question & Answer")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask a question about the selected video...",
                lines=4,
                value="Describe what is happening in this video."
            )
            answer = gr.Textbox(
                label="Model Answer",
                lines=20,
                interactive=False
            )
            run = gr.Button("πŸš€ Run Inference", variant="primary", size="lg")
    gr.Markdown("""
---
**ℹ️ Tips:**
- **Quantization:** 16-bit (full precision), 8-bit (2x memory savings), 4-bit (4x memory savings with slight quality loss)
- Adjust FPS to control video sampling rate (higher = more frames, slower inference)
- Use video_mode='frames' for fixed frame count (useful for very long videos)
- Temperature: Lower (0.01-0.5) for factual, higher (0.7-1.5) for creative responses
- Top-K and Top-P control output diversity
""")
    # Update video player when dropdown changes
    video_id.change(
        fn=get_video_path,
        inputs=video_id,
        outputs=video_player
    )
    # Reload model with the selected name and quantization
    reload_button.click(
        fn=load_model_with_quantization,
        inputs=[
            model_name_radio,
            quantization_radio,
        ],
        outputs=reload_status
    )
    # Run inference; input order must match video_qa's parameter order.
    run.click(
        fn=video_qa,
        inputs=[
            video_id,
            prompt,
            video_mode_radio,
            fps_slider,
            num_frames_slider,
            max_tokens_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
        ],
        outputs=answer
    )
# Start the web server: listen on every network interface on port 7860 and
# also request a public Gradio share link.
demo.launch(
    share=True,
    server_port=7860,
    server_name="0.0.0.0",
)
# #---------------
# #---------------
# #---------------
# # Feb 5, 2026
# #---------------
# import os
# import sys
# import json
# from pathlib import Path
# import gradio as gr
# # Allow importing your models package
# sys.path.insert(0, str(Path(__file__).parent))
# from models import load_model
# from models.base import BaseVideoModel
# # ----------------------
# # CONFIG
# # ----------------------
# QWEN_MODEL_PATH = "Qwen/Qwen3-VL-4B-Instruct"
# LLAVA_MODEL_PATH = "lmms-lab/LLaVA-Video-7B-Qwen2"
# DEVICE_MAP_QWEN = "cuda:0"
# DEVICE_MAP_LLAVA = "cuda:0" # Both models on same GPU
# VIDEO_DIR = "/home/raman/Gradio_Qwen3vl4bInstruct/videos"
# LABELS_JSON = "/home/raman/Gradio_Qwen3vl4bInstruct/SSv2_prepost_sampled.json"
# DEFAULT_FPS = 1.0
# MAX_NEW_TOKENS = 512
# TEMPERATURE = 0.01
# # ----------------------
# # Load video labels
# # ----------------------
# print("Loading video labels...")
# video_labels = {}
# try:
# with open(LABELS_JSON, 'r') as f:
# labels_data = json.load(f)
# for item in labels_data:
# video_labels[item['id']] = {
# 'label': item['label'],
# 'template': item.get('template', ''),
# 'action_group': item.get('action_group', '')
# }
# print(f"Loaded {len(video_labels)} video labels.")
# except Exception as e:
# print(f"Warning: Could not load labels JSON: {e}")
# # ----------------------
# # Load models
# # ----------------------
# print("Loading Qwen3-VL-4B-Instruct...")
# qwen_model: BaseVideoModel = load_model(
# QWEN_MODEL_PATH,
# device_map=DEVICE_MAP_QWEN,
# )
# print("Qwen model loaded.")
# print("Loading LLaVA-Video-7B...")
# llava_model: BaseVideoModel = load_model(
# LLAVA_MODEL_PATH,
# device_map=DEVICE_MAP_LLAVA,
# )
# print("LLaVA model loaded.")
# # ----------------------
# # Collect video IDs
# # ----------------------
# VIDEO_IDS = sorted([
# os.path.splitext(f)[0]
# for f in os.listdir(VIDEO_DIR)
# if f.endswith(".mp4")
# ])
# print(f"Found {len(VIDEO_IDS)} videos.")
# # ----------------------
# # Helpers
# # ----------------------
# def get_video_path(video_id: str):
# if not video_id:
# return None
# path = os.path.join(VIDEO_DIR, video_id + ".mp4")
# return path if os.path.exists(path) else None
# def get_video_label(video_id: str):
# if not video_id:
# return ""
# info = video_labels.get(video_id, {})
# label = info.get('label', 'No label available')
# action_group = info.get('action_group', '')
# if action_group:
# return f"**Label:** {label}\n\n**Action Group:** {action_group}"
# return f"**Label:** {label}"
# def update_video_info(video_id: str):
# """Returns video path and label when video is selected"""
# video_path = get_video_path(video_id)
# label = get_video_label(video_id)
# return video_path, label
# # ----------------------
# # Inference functions
# # ----------------------
# def qwen_inference(video_id: str, prompt: str, fps: float) -> str:
# if not video_id:
# return "❌ Please select a video ID."
# if not prompt.strip():
# return "❌ Please enter a prompt."
# video_path = get_video_path(video_id)
# if video_path is None:
# return f"❌ Video not found: {video_id}.mp4"
# try:
# response = qwen_model.chat(
# prompt=prompt,
# video_path=video_path,
# fps=fps,
# max_new_tokens=MAX_NEW_TOKENS,
# temperature=TEMPERATURE,
# )
# return response
# except Exception as e:
# return f"❌ Error during Qwen inference: {str(e)}"
# def llava_inference(video_id: str, prompt: str, fps: float) -> str:
# if not video_id:
# return "❌ Please select a video ID."
# if not prompt.strip():
# return "❌ Please enter a prompt."
# video_path = get_video_path(video_id)
# if video_path is None:
# return f"❌ Video not found: {video_id}.mp4"
# try:
# response = llava_model.chat(
# prompt=prompt,
# video_path=video_path,
# fps=fps,
# max_new_tokens=MAX_NEW_TOKENS,
# temperature=TEMPERATURE,
# )
# return response
# except Exception as e:
# return f"❌ Error during LLaVA inference: {str(e)}"
# # ----------------------
# # Gradio UI
# # ----------------------
# with gr.Blocks(title="Video QA – Qwen3-VL & LLaVA-Video", theme=gr.themes.Soft()) as demo:
# gr.Markdown("# πŸŽ₯ Video Question Answering Demo")
# gr.Markdown("Compare **Qwen3-VL-4B-Instruct** and **LLaVA-Video-7B-Qwen2** on the same videos")
# # TOP SECTION: Video Selection and Display
# with gr.Row():
# with gr.Column(scale=1):
# video_id = gr.Dropdown(
# choices=VIDEO_IDS,
# label="πŸ“ Select Video ID",
# filterable=True,
# interactive=True,
# value=VIDEO_IDS[0] if VIDEO_IDS else None
# )
# video_label = gr.Markdown(
# value=get_video_label(VIDEO_IDS[0]) if VIDEO_IDS else "",
# label="Video Information"
# )
# fps_slider = gr.Slider(
# minimum=0.5,
# maximum=5.0,
# step=0.5,
# value=DEFAULT_FPS,
# label="🎞️ Frames Per Second (FPS)",
# info="Higher FPS = more frames analyzed (slower but more detailed)"
# )
# with gr.Column(scale=2):
# video_player = gr.Video(
# label="Selected Video",
# autoplay=False,
# height=360,
# value=get_video_path(VIDEO_IDS[0]) if VIDEO_IDS else None
# )
# gr.Markdown("---")
# # BOTTOM SECTION: Two Models Side by Side
# with gr.Row():
# # QWEN COLUMN
# with gr.Column(scale=1):
# gr.Markdown("### πŸ€– Qwen3-VL-4B-Instruct")
# qwen_prompt = gr.Textbox(
# label="Prompt",
# placeholder="Ask a question about the video...",
# lines=4,
# value="Describe what is happening in this video."
# )
# qwen_answer = gr.Textbox(
# label="Qwen Answer",
# lines=10,
# interactive=False
# )
# qwen_run = gr.Button("πŸš€ Run Qwen Inference", variant="primary")
# # LLAVA COLUMN
# with gr.Column(scale=1):
# gr.Markdown("### 🎬 LLaVA-Video-7B-Qwen2")
# llava_prompt = gr.Textbox(
# label="Prompt",
# placeholder="Ask a question about the video...",
# lines=4,
# value="Describe what is happening in this video."
# )
# llava_answer = gr.Textbox(
# label="LLaVA Answer",
# lines=10,
# interactive=False
# )
# llava_run = gr.Button("πŸš€ Run LLaVA Inference", variant="primary")
# # Model info footer
# gr.Markdown("""
# ---
# **Model Information:**
# - **Qwen3-VL-4B-Instruct**: 4B parameter vision-language model
# - **LLaVA-Video-7B-Qwen2**: 7B parameter video understanding model
# **Settings:** Max Tokens={}, Temperature={}
# """.format(MAX_NEW_TOKENS, TEMPERATURE))
# # ----------------------
# # Event Handlers
# # ----------------------
# # Update video player and label when dropdown changes
# video_id.change(
# fn=update_video_info,
# inputs=video_id,
# outputs=[video_player, video_label]
# )
# # Run Qwen inference
# qwen_run.click(
# fn=qwen_inference,
# inputs=[video_id, qwen_prompt, fps_slider],
# outputs=qwen_answer
# )
# # Run LLaVA inference
# llava_run.click(
# fn=llava_inference,
# inputs=[video_id, llava_prompt, fps_slider],
# outputs=llava_answer
# )
# # Launch
# demo.launch(
# server_name="0.0.0.0",
# server_port=7860,
# share=True
# )