File size: 6,684 Bytes
6a6a2f0
 
 
 
 
 
4ea68ce
 
 
 
 
 
 
 
6a6a2f0
 
 
 
224eae3
 
 
920c71d
224eae3
 
 
4ea68ce
224eae3
 
 
4ea68ce
4d88196
224eae3
 
4ea68ce
224eae3
6a6a2f0
920c71d
 
6a6a2f0
4ea68ce
6a6a2f0
 
4ea68ce
920c71d
6a6a2f0
 
4ea68ce
6a6a2f0
920c71d
6a6a2f0
 
 
 
 
 
 
 
 
 
 
920c71d
4ea68ce
 
6a6a2f0
 
 
 
 
 
4ea68ce
 
 
920c71d
6a6a2f0
 
 
920c71d
 
6a6a2f0
224eae3
4d88196
4ea68ce
 
4d88196
 
6a6a2f0
4d88196
6a6a2f0
 
4d88196
6a6a2f0
4d88196
 
 
6a6a2f0
4d88196
6a6a2f0
 
 
 
 
9b41a1c
6a6a2f0
 
 
4d88196
4ea68ce
 
 
 
6a6a2f0
 
920c71d
4ea68ce
6a6a2f0
920c71d
 
 
 
 
 
4ea68ce
920c71d
 
4ea68ce
920c71d
 
 
 
 
4ea68ce
920c71d
 
 
 
 
 
 
 
 
 
 
4ea68ce
 
 
 
 
920c71d
 
 
 
 
4ea68ce
920c71d
 
 
 
 
 
 
 
 
 
4ea68ce
920c71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ea68ce
920c71d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np
import logging

# ---------------- Logging Setup ----------------
# Configure the root logger once at import time: INFO and above, timestamped,
# emitted through a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()]
)

# Model identifier handed to both `from_pretrained` calls in load_model().
MID = "apple/FastVLM-7B"
# Sentinel token id that caption_frame() splices into the input ids at the
# position of the "<image>" placeholder.
# NOTE(review): presumably the value the model's remote code expects — confirm.
IMAGE_TOKEN_INDEX = -200

# Module-level cache for the tokenizer and model; both stay None until the
# first load_model() call fills them in.
tok = None
model = None

# ---------------- Load Model ----------------
def load_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module-level ``tok`` / ``model`` globals. The
    load only runs while at least one of them is still ``None``; later calls
    return the cached objects directly. The model is requested in float32 and
    mapped to the CPU.
    """
    global tok, model

    already_loaded = tok is not None and model is not None
    if already_loaded:
        return tok, model

    logging.info("Loading FastVLM model (CPU only)...")
    tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)

    # float32 weights on the CPU device, as the log messages advertise.
    model_options = {
        "torch_dtype": torch.float32,
        "device_map": "cpu",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(MID, **model_options)
    logging.info("✅ Model loaded successfully on CPU")
    return tok, model


# ---------------- Frame Extraction ----------------
def _select_frame_indices(total_frames, num_frames, sampling_method):
    """Pick which frame indices to decode for the given sampling strategy.

    Expects ``total_frames > 0`` and ``num_frames > 0``. Never returns more
    indices than the video has frames.
    """
    count = min(num_frames, total_frames)
    if sampling_method == "uniform":
        # Clamp the sample count so short videos do not yield duplicate indices.
        return np.linspace(0, total_frames - 1, count, dtype=int)
    if sampling_method == "first":
        return list(range(count))
    if sampling_method == "last":
        return list(range(total_frames - count, total_frames))
    # Any other value is treated as "middle": a contiguous run centred in the video.
    start = max(0, (total_frames - num_frames) // 2)
    return list(range(start, min(start + num_frames, total_frames)))


def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Decode up to ``num_frames`` frames from a video as RGB PIL images.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to return. A non-positive value
            yields an empty list.
        sampling_method: ``"uniform"`` (evenly spaced over the whole video),
            ``"first"``, ``"last"``; any other value selects a contiguous run
            from the middle of the video.

    Returns:
        A list of ``PIL.Image`` objects in index order. Frames that fail to
        decode are skipped with a warning, so the list may be shorter than
        requested. It is empty when the video reports no frames.
    """
    logging.info(f"Extracting up to {num_frames} frames using '{sampling_method}' sampling")

    # Handle a non-positive request uniformly for every sampling method
    # instead of letting np.linspace raise on a negative count.
    if num_frames <= 0:
        logging.warning("⚠️ Requested a non-positive number of frames")
        return []

    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        logging.info(f"Total frames in video: {total_frames}")

        # `<= 0` also covers a negative frame count, which would otherwise
        # produce negative seek indices.
        if total_frames <= 0:
            logging.warning("⚠️ No frames found in video")
            return []

        indices = _select_frame_indices(total_frames, num_frames, sampling_method)
        logging.info(f"Selected frame indices: {indices}")

        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                logging.warning(f"⚠️ Failed to extract frame {idx}")
                continue
            # OpenCV decodes to BGR; convert to RGB before wrapping in PIL.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
            logging.info(f"✅ Extracted frame {idx}")
        return frames
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()


# ---------------- Caption Frame ----------------
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate a short caption for a single frame.

    The prompt is rendered through the tokenizer's chat template with an
    ``<image>`` placeholder. The text on either side of the placeholder is
    tokenized separately and joined around ``IMAGE_TOKEN_INDEX``, and the
    image is passed to ``generate`` as pixel values.

    Args:
        image: Frame to describe.
        prompt: Instruction text placed after the image placeholder.

    Returns:
        The decoded model output. If it contains ``prompt``, only the
        stripped text after the prompt is returned.
    """
    tokenizer, vlm = load_model()

    logging.info(f"Captioning frame with prompt: {prompt!r}")

    # Render the chat as plain text so it can be split at the placeholder.
    chat = [{"role": "user", "content": f"<image>\n{prompt}"}]
    templated = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    before_image, after_image = templated.split("<image>", 1)

    def encode(text):
        # The template already carries its special tokens; do not add more.
        return tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids

    prefix_ids = encode(before_image)
    suffix_ids = encode(after_image)

    # A single sentinel id marks where the image goes in the sequence.
    image_marker = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=prefix_ids.dtype)
    input_ids = torch.cat([prefix_ids, image_marker, suffix_ids], dim=1)
    attention_mask = torch.ones_like(input_ids)

    image_processor = vlm.get_vision_tower().image_processor
    pixel_values = image_processor(images=image, return_tensors="pt")["pixel_values"]

    with torch.no_grad():
        generated = vlm.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=pixel_values,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )

    raw_output = tokenizer.decode(generated[0], skip_special_tokens=True)
    logging.info(f"Raw model output: {raw_output!r}")

    # Drop an echoed prompt, keeping only what follows it.
    caption = raw_output.split(prompt)[-1].strip() if prompt in raw_output else raw_output

    logging.info(f"✅ Final cleaned caption: {caption!r}")
    return caption


# ---------------- Process Video ----------------
def process_video(video_path, num_frames, sampling_method, chat_history, progress=gr.Progress()):
    """Caption sampled frames of a video and report the result in the chat.

    Args:
        video_path: Path of the uploaded video; a falsy value means none.
        num_frames: Number of frames to sample.
        sampling_method: Strategy name forwarded to ``extract_frames``.
        chat_history: Chat message list. It is mutated in place and returned.
        progress: Gradio progress tracker, called with a fraction in [0, 1].

    Returns:
        ``(chat_history, frames)``. ``frames`` is ``None`` when no video was
        supplied or no frame could be extracted.
    """
    speaker = "Assistant"

    # Guard: nothing to analyse.
    if not video_path:
        chat_history.append([speaker, "Please upload a video first."])
        logging.warning("No video uploaded")
        return chat_history, None

    logging.info(f"Starting analysis of video: {video_path}")
    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    # Guard: the video yielded no frames.
    if not frames:
        chat_history.append([speaker, "Failed to extract frames."])
        logging.error("No frames extracted")
        return chat_history, None

    prompt = "Provide a brief one-sentence description of what's happening in this image."
    captions = []
    total = len(frames)

    # The placeholder entry is overwritten with the running caption list.
    chat_history.append([speaker, "Analyzing frames..."])
    for position, frame in enumerate(frames, start=1):
        text = caption_frame(frame, prompt)
        captions.append(f"Frame {position}: {text}")
        chat_history[-1] = [speaker, "\n".join(captions)]
        progress(position / total)
        logging.info(f"Progress: frame {position}/{total} analyzed")

    final_summary = "\n".join(captions)
    logging.info("✅ Video analysis complete")
    logging.info(f"Final summary:\n{final_summary}")

    progress(1.0, desc="Analysis complete!")
    return chat_history, frames


# ---------------- Gradio UI ----------------
class AppleTheme(gr.themes.Base):
    """Gradio theme with a blue primary hue and gray secondary/neutral hues."""

    def __init__(self):
        palette = gr.themes.colors
        hues = {
            "primary_hue": palette.blue,
            "secondary_hue": palette.gray,
            "neutral_hue": palette.gray,
        }
        super().__init__(**hues)


# Build the UI: a video input column, plus a sidebar holding the chat log,
# the analyze button and a collapsible gallery of the analyzed frames.
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning (CPU Only, with Logs)")

    with gr.Row():
        with gr.Column(scale=7):
            video_display = gr.Video(label="Video Input", autoplay=True, loop=True)

        with gr.Sidebar(width=400):
            # The chat starts with a single greeting entry.
            chatbot = gr.Chatbot(
                value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
                height=400
            )
            process_btn = gr.Button("🎯 Analyze Video", variant="primary")

            # Collapsed by default; filled from process_video's second output.
            with gr.Accordion("🖼️ Analyzed Frames", open=False):
                frame_gallery = gr.Gallery(columns=2, rows=4, height="auto")

    # No input widgets are defined for these two settings in this layout;
    # they are passed to process_video as fixed State values.
    num_frames = gr.State(value=4)
    sampling_method = gr.State(value="uniform")

    # Clicking the button runs process_video; its (chat_history, frames)
    # return value updates the chatbot and the gallery respectively.
    process_btn.click(
        fn=process_video,
        inputs=[video_display, num_frames, sampling_method, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress=True
    )


# ---------------- Launch ----------------
# Start the Gradio server at import/run time: listen on all interfaces
# (0.0.0.0) at port 7860, with the share option turned off.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    show_error=True
)