Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,738 Bytes
b5bcf5a 76e1435 b5bcf5a 76e1435 b5bcf5a 6a31985 b5bcf5a 6a31985 45a53c4 76e1435 67d411a eab0adb 45a53c4 76e1435 45a53c4 67d411a 76e1435 45a53c4 76e1435 45a53c4 76e1435 45a53c4 76e1435 e4b23f9 76e1435 e4b23f9 ea5eb99 e4b23f9 ea5eb99 e4b23f9 ea5eb99 e4b23f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch
from huggingface_hub import login
from collections.abc import Iterator
from transformers import (
Gemma3ForConditionalGeneration,
TextIteratorStreamer,
Gemma3Processor,
)
import spaces
import tempfile
from threading import Thread
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv
import cv2
from loguru import logger
from PIL import Image
# Load environment variables (e.g. MODEL_ID, HF token) from a .env file if present.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
# Model repo to load; defaults to the 4B instruction-tuned Gemma 3 checkpoint.
model_id = os.getenv("MODEL_ID", "google/gemma-3-4b-it")
# Processor handles tokenization plus image preprocessing for multimodal input.
input_processor = Gemma3Processor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # halves memory vs fp32; needs bf16-capable hardware
    device_map="auto",  # let accelerate place weights on available devices
    attn_implementation="eager",  # NOTE(review): eager attention — presumably for compatibility; confirm
)
def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
    """Sample up to ``max_images`` evenly spaced frames from a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        max_images: Maximum number of frames to extract.

    Returns:
        List of ``(frame, timestamp)`` pairs, where ``frame`` is an RGB
        PIL image and ``timestamp`` is the frame time in seconds, rounded
        to two decimals.

    Raises:
        ValueError: If the video file cannot be opened.
    """
    frames: list[tuple[Image.Image, float]] = []
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) fps; fall back to a nominal 30
        # so the timestamp computation below cannot divide by zero.
        if not fps or fps <= 0:
            fps = 30.0
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Step size so at most max_images evenly spaced frames are read.
        frame_interval = max(total_frames // max_images, 1)
        max_position = min(total_frames, max_images * frame_interval)
        i = 0
        while i < max_position and len(frames) < max_images:
            capture.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = capture.read()
            if success:
                # OpenCV decodes as BGR; PIL expects RGB.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                timestamp = round(i / fps, 2)
                frames.append((pil_image, timestamp))
            i += frame_interval
    finally:
        # Always release the capture, even if decoding raises (fixes a
        # resource leak in the original, which released only on success).
        capture.release()
    return frames
def process_video(video_path: str, max_images: int) -> list[dict]:
    """Expand a video into interleaved text/image content entries.

    Each sampled frame is written to a temporary PNG file and emitted as an
    ``{"type": "image", "url": ...}`` entry, preceded by a text entry that
    carries the frame's timestamp.
    """
    # TODO: Change max_image to slider
    frames = get_frames(video_path, max_images)
    result_content: list[dict] = []
    for pil_image, timestamp in frames:
        # delete=False: the file must outlive this scope so the model
        # processor can read it later.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            pil_image.save(temp_file.name)
        result_content.append({"type": "text", "text": f"Frame {timestamp}:"})
        result_content.append({"type": "image", "url": temp_file.name})
    logger.debug(
        f"Processed {len(frames)} frames from video {video_path} with frames {result_content}"
    )
    return result_content
def process_user_input(message: dict, max_images: int) -> list[dict]:
    """Convert a Gradio multimodal chat message into model content entries.

    Args:
        message: Gradio message dict with "text" (str) and "files"
            (list of file paths) keys.
        max_images: Maximum number of frames to sample when a video is attached.

    Returns:
        A list of content dicts: the text entry first, followed by image
        entries. A video attachment is expanded into its sampled frames.
    """
    if not message["files"]:
        return [{"type": "text", "text": message["text"]}]
    # Case-insensitive extension check so uploads like "clip.MP4" are
    # treated as video rather than passed through as still images.
    if message["files"][0].lower().endswith(".mp4"):
        return [
            {"type": "text", "text": message["text"]},
            *process_video(message["files"][0], max_images),
        ]
    return [
        {"type": "text", "text": message["text"]},
        *[{"type": "image", "url": path} for path in message["files"]],
    ]
def process_history(history: list[dict]) -> list[dict]:
    """Re-group a flat Gradio chat history into model-format messages.

    Consecutive user items (plain text strings, or file entries whose
    content is a tuple/list of paths) are merged into one user message
    whose content list precedes the following assistant reply.
    """
    messages: list[dict] = []
    pending_user: list[dict] = []
    for entry in history:
        role = entry["role"]
        content = entry["content"]
        if role != "assistant":
            # A string is plain text; otherwise content[0] is a file path.
            if isinstance(content, str):
                pending_user.append({"type": "text", "text": content})
            else:
                pending_user.append({"type": "image", "url": content[0]})
            continue
        # Flush any buffered user content before recording the assistant turn.
        if pending_user:
            messages.append({"role": "user", "content": pending_user})
            pending_user = []
        messages.append(
            {"role": "assistant", "content": [{"type": "text", "text": content}]}
        )
    # Trailing user content with no assistant reply yet.
    if pending_user:
        messages.append({"role": "user", "content": pending_user})
    return messages
|