# Author: rahul7star
# Commit 4ea68ce (verified): "Update app.py"
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np
import logging
# ---------------- Logging Setup ----------------
# Configure the root logger once: INFO and above, timestamped, sent to a
# StreamHandler (stderr by default).
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)

# Hugging Face repository id of the vision-language model.
MID = "apple/FastVLM-7B"
# Sentinel token id spliced into input_ids where the image is inserted.
IMAGE_TOKEN_INDEX = -200

# Populated lazily by load_model(); both stay None until the first call.
tok = model = None
# ---------------- Load Model ----------------
def load_model():
    """Return the ``(tokenizer, model)`` pair, loading them on first use.

    Both objects are cached in the module-level ``tok`` and ``model`` globals,
    so only the first call pays the loading cost. If either global is still
    ``None``, both are (re)loaded. The model is loaded in float32 and pinned
    to the CPU.
    """
    global tok, model
    needs_loading = tok is None or model is None
    if needs_loading:
        logging.info("Loading FastVLM model (CPU only)...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            trust_remote_code=True,
            device_map="cpu",  # keep every weight on the CPU
            torch_dtype=torch.float32,  # full precision; no half-precision on CPU
        )
        logging.info("✅ Model loaded successfully on CPU")
    return tok, model
# ---------------- Frame Extraction ----------------
def _select_frame_indices(total_frames: int, num_frames: int, sampling_method: str) -> list[int]:
    """Return ascending, de-duplicated frame indices to sample from a video.

    Args:
        total_frames: Number of frames reported for the video.
        num_frames: Maximum number of indices to return.
        sampling_method: "uniform" (evenly spread), "first", "last"; any other
            value selects a contiguous window centred in the video.

    Returns:
        At most ``num_frames`` indices in ``[0, total_frames)``; empty when
        either count is not positive.
    """
    if total_frames <= 0 or num_frames <= 0:
        return []
    if sampling_method == "uniform":
        # linspace repeats indices when num_frames > total_frames; drop the
        # repeats so the same frame is not decoded and captioned twice.
        spread = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        return list(dict.fromkeys(int(i) for i in spread))
    if sampling_method == "first":
        start = 0
    elif sampling_method == "last":
        start = max(0, total_frames - num_frames)
    else:  # "middle" (and any unrecognised method)
        start = max(0, (total_frames - num_frames) // 2)
    return list(range(start, min(start + num_frames, total_frames)))


def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Decode up to ``num_frames`` frames from a video as RGB PIL images.

    Args:
        video_path: Path of the video file to open with OpenCV.
        num_frames: Maximum number of frames to return. Floats (e.g. from a
            Gradio slider) are truncated to int.
        sampling_method: "uniform", "first", "last", or anything else for a
            window in the middle of the video.

    Returns:
        A list of ``PIL.Image.Image`` in ascending frame order. Frames that
        fail to decode are skipped. An empty list is returned when the video
        reports no frames.
    """
    num_frames = int(num_frames)
    logging.info(f"Extracting up to {num_frames} frames using '{sampling_method}' sampling")
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        logging.info(f"Total frames in video: {total_frames}")
        # The reported count can be 0 (e.g. unopenable file) or, for some
        # streams, not a positive number at all; treat both as "no frames".
        if total_frames <= 0:
            logging.warning("⚠️ No frames found in video")
            return []

        indices = _select_frame_indices(total_frames, num_frames, sampling_method)
        logging.info(f"Selected frame indices: {indices}")

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            if not ok:
                logging.warning(f"⚠️ Failed to extract frame {idx}")
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            logging.info(f"✅ Extracted frame {idx}")
        return frames
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()
# ---------------- Caption Frame ----------------
def caption_frame(
    image: Image.Image,
    prompt: str,
    *,
    max_new_tokens: int = 15,
    temperature: float = 0.7,
    do_sample: bool = True,
) -> str:
    """Generate a caption for one image with the (lazily loaded) FastVLM model.

    The prompt is rendered through the tokenizer's chat template with an
    ``<image>`` placeholder. The placeholder text is then replaced by the
    single ``IMAGE_TOKEN_INDEX`` id, and ``generate`` receives those ids plus
    the vision tower's preprocessed pixel values.

    Args:
        image: Frame to describe; converted to RGB if it is in another mode.
        prompt: Instruction text placed after the image.
        max_new_tokens: Generation length cap. The default of 15 is short and
            may cut a sentence off mid-way.
        temperature: Sampling temperature, used only when ``do_sample`` is True.
        do_sample: Whether to sample. With the default (True), repeated calls
            can return different captions.

    Returns:
        The decoded text, with any echoed prompt removed and surrounding
        whitespace stripped.

    Raises:
        ValueError: If the rendered chat template has no ``<image>`` placeholder.
    """
    tok, model = load_model()
    logging.info(f"Captioning frame with prompt: {prompt!r}")

    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, sep, post = rendered.partition("<image>")
    if not sep:
        raise ValueError("Rendered chat template does not contain the '<image>' placeholder")

    # Tokenise the text around the placeholder separately and splice the
    # image sentinel id between them.
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids)

    if image.mode != "RGB":
        image = image.convert("RGB")
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
    # Match the model's placement/precision so this keeps working if the model
    # is ever loaded somewhere other than CPU float32.
    px = px.to(model.device, dtype=model.dtype)

    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            **gen_kwargs,
        )

    raw_output = tok.decode(out[0], skip_special_tokens=True)
    logging.info(f"Raw model output: {raw_output!r}")

    # The decoded text may echo the prompt; keep only what follows its last
    # occurrence. Guard on a non-empty prompt: str.split("") raises ValueError.
    caption = raw_output
    if prompt and prompt in caption:
        caption = caption.rsplit(prompt, 1)[-1]
    caption = caption.strip()
    logging.info(f"✅ Final cleaned caption: {caption!r}")
    return caption
# ---------------- Process Video ----------------
def process_video(video_path, num_frames, sampling_method, chat_history, progress=gr.Progress()):
    """Caption sampled frames of a video and post the results to the chat.

    Click handler for the "Analyze Video" button.

    Args:
        video_path: Path of the uploaded video, or a falsy value when nothing
            was uploaded.
        num_frames: Maximum number of frames to sample.
        sampling_method: "uniform", "first", "last", or anything else for a
            middle window (see ``extract_frames``).
        chat_history: Current chatbot value as ``[user_message, bot_message]``
            pairs; ``None`` is treated as an empty history.
        progress: Progress tracker. Declaring a ``gr.Progress()`` default is
            how Gradio injects one into the handler.

    Returns:
        ``(chat_history, frames)``: a new chat list with one assistant message
        appended, and the analysed PIL frames for the gallery. ``frames`` is
        ``None`` when nothing could be analysed.
    """
    # Work on a copy so the caller's list is never mutated in place.
    chat_history = list(chat_history or [])

    if not video_path:
        logging.warning("No video uploaded")
        # Tuple-format chat: the user slot is None so only the bot side renders.
        chat_history.append([None, "Please upload a video first."])
        return chat_history, None

    logging.info(f"Starting analysis of video: {video_path}")
    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)
    if not frames:
        logging.error("No frames extracted")
        chat_history.append([None, "Failed to extract frames."])
        return chat_history, None

    prompt = "Provide a brief one-sentence description of what's happening in this image."
    total = len(frames)
    captions = []
    for i, frame in enumerate(frames, start=1):
        captions.append(f"Frame {i}: {caption_frame(frame, prompt)}")
        progress(i / total, desc=f"Captioned frame {i}/{total}")
        logging.info(f"Progress: frame {i}/{total} analyzed")

    # This handler returns once (it is not a generator), so the chat receives
    # a single message containing every caption.
    final_summary = "\n".join(captions)
    chat_history.append([None, final_summary])

    logging.info("✅ Video analysis complete")
    logging.info(f"Final summary:\n{final_summary}")
    progress(1.0, desc="Analysis complete!")
    return chat_history, frames
# ---------------- Gradio UI ----------------
class AppleTheme(gr.themes.Base):
    """Base Gradio theme with a blue primary hue and gray secondary/neutral hues."""

    def __init__(self):
        palette = gr.themes.colors
        super().__init__(
            neutral_hue=palette.gray,
            secondary_hue=palette.gray,
            primary_hue=palette.blue,
        )
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning (CPU Only, with Logs)")

    # Main area: the video player.
    with gr.Row():
        with gr.Column(scale=7):
            video_display = gr.Video(label="Video Input", autoplay=True, loop=True)

    # Sidebar: chat output, trigger button and the gallery of analysed frames.
    with gr.Sidebar(width=400):
        # Tuple-format history is [user_message, bot_message]; None in the
        # user slot shows only the assistant's greeting.
        chatbot = gr.Chatbot(
            value=[[None, "Upload a video and I'll analyze it for you!"]],
            height=400,
        )
        process_btn = gr.Button("🎯 Analyze Video", variant="primary")
        with gr.Accordion("🖼️ Analyzed Frames", open=False):
            frame_gallery = gr.Gallery(columns=2, rows=4, height="auto")

    # Fixed analysis settings; held in State because they have no UI controls.
    num_frames = gr.State(value=4)
    sampling_method = gr.State(value="uniform")

    process_btn.click(
        fn=process_video,
        inputs=[video_display, num_frames, sampling_method, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress="full",  # documented values: "full" | "minimal" | "hidden"
    )
# ---------------- Launch ----------------
# Serve on all interfaces at port 7860, with no public share link, and show
# handler exceptions in the UI.
demo.launch(
    show_error=True,
    share=False,
    server_port=7860,
    server_name="0.0.0.0",
)