# NOTE(review): the three lines below are residue from a Hugging Face web view
# (author avatar caption, commit message, commit hash) and are not Python —
# preserved here as comments so the module can actually be imported.
# ashleshp's picture
# Switch transformers
# b6192e4
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Optional

import cv2
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

from src.interfaces.base import PerceptionEngine
class Qwen2PerceptionEngine(PerceptionEngine):
    """
    Hugging Face Native implementation of Qwen2-VL.
    Optimized for HF Spaces (CPU/GPU) without requiring slow C++ builds.

    The model and processor are loaded lazily on first use (see `load_model`),
    so constructing this class is cheap.
    """

    def __init__(self):
        # 2B-parameter variant keeps the memory footprint small enough for Spaces.
        self.model_id = "Qwen/Qwen2-VL-2B-Instruct"
        self.model = None       # populated lazily by load_model()
        self.processor = None   # populated lazily by load_model()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_path: Optional[Path] = None) -> None:
        """Loads the model using Transformers.

        Idempotent — returns immediately if already loaded. `model_path` is
        accepted for interface compatibility but unused: weights always come
        from the Hub via `self.model_id`.
        """
        if self.model is not None:
            return
        print(f"Loading Qwen2-VL via Transformers on {self.device}...")
        # torch_dtype="auto" selects half precision on GPU / float32 on CPU;
        # device_map="auto" places the weights on whatever device is available.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_id,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        print("✅ Native Vision Model loaded.")

    def analyze_frame(self, frame_path: str, prompt: str) -> str:
        """Runs inference on a single image file using the native transformers pipeline.

        Args:
            frame_path: Path to an image file readable by the Qwen vision utils.
            prompt: Text instruction paired with the image.

        Returns:
            The decoded model response (assistant turn only, prompt trimmed).
        """
        if self.model is None:
            self.load_model()
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": frame_path},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Build the chat-formatted prompt and the pixel inputs expected by Qwen2-VL.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)
        # Generate, then strip the prompt tokens so only new tokens are decoded.
        generated_ids = self.model.generate(**inputs, max_new_tokens=256)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text

    def analyze_video_segment(self, video_path: Path, start_time: float, end_time: float, prompt: str) -> str:
        """Extracts the middle frame of [start_time, end_time] and analyzes it.

        Returns an "Error: ..." string (rather than raising) when the frame
        cannot be read, matching the original best-effort contract.
        """
        cap = cv2.VideoCapture(str(video_path))
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 0:
                # Some containers report 0 fps; fall back to a common rate so the
                # seek still lands near the segment midpoint instead of frame 0.
                fps = 30.0
            middle_time = (start_time + end_time) / 2
            frame_id = int(middle_time * fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
            ret, frame = cap.read()
        finally:
            # Release the capture even if read/seek raises.
            cap.release()
        if not ret:
            return "Error: Could not read frame."
        # Unique temp file (the previous fixed name raced under concurrent calls),
        # removed after inference so frames no longer accumulate on disk.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            cv2.imwrite(temp_path, frame)
            return self.analyze_frame(temp_path, prompt)
        finally:
            os.remove(temp_path)

    def generate_text(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Text-only generation (no chat template is applied to `prompt`).

        Args:
            prompt: Raw prompt string fed directly to the tokenizer.
            stop: Optional stop strings; the output is truncated at the first
                occurrence of any of them. (Applied post-hoc after decoding —
                previously this parameter was accepted but silently ignored.)
        """
        if self.model is None:
            self.load_model()
        inputs = self.processor(text=[prompt], return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=512)
        # Trim the input prompt tokens from the output before decoding.
        output_text = self.processor.batch_decode(
            generated_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )[0]
        if stop:
            # Cut at the earliest stop sequence found in the decoded text.
            for s in stop:
                idx = output_text.find(s)
                if idx != -1:
                    output_text = output_text[:idx]
        return output_text

    def chat(self, messages: List[Dict[str, str]]) -> str:
        """Simplified chat: flattens role/content turns into one plain-text prompt.

        NOTE(review): this bypasses the model's chat template; presumably
        acceptable for this engine's callers — confirm if quality matters.
        """
        prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages])
        return self.generate_text(prompt)