# qwen-mlp / handler.py
# Author: mapotofu40 — "Fix syntax error in handler" (commit 4845d75)
"""
Custom Inference Handler for Hugging Face Inference Endpoints
Combines Qwen2.5-VL embedding extraction + MLP classifiers
"""
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForVision2Seq
from pathlib import Path
import numpy as np
from typing import Dict, Any
import av
import tempfile
class MLPClassifier(nn.Module):
    """Feed-forward classification head matching the training architecture.

    Two hidden layers (hidden_dim, then hidden_dim // 2), each followed by
    ReLU and dropout, and a final linear projection to num_classes logits.
    The layers live in ``self.classifier`` as an ``nn.Sequential`` so the
    state-dict keys (``classifier.0.weight`` ...) match saved checkpoints.
    """

    def __init__(self, input_dim, hidden_dim=512, num_classes=4, dropout=0.3):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        ]
        self.classifier = nn.Sequential(*stages)

    def forward(self, x):
        """Return raw (unsoftmaxed) class logits for a batch of embeddings."""
        return self.classifier(x)
class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints.

    Loads a Qwen2.5-VL backbone for embedding extraction plus one MLP
    classifier per emotion category, then serves predictions via ``__call__``.
    """

    def __init__(self, path: str):
        """
        Initialize the handler.

        Args:
            path: Path to the model directory on HF Hub. Must contain a
                "classifiers" subdirectory with ``mlp_<Category>_best.pth``
                checkpoints (missing ones are skipped with a warning).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load Qwen2.5-VL model for embeddings.
        print("Loading Qwen2.5-VL model...")
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        self.vision_model = AutoModelForVision2Seq.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        self.vision_model.eval()

        # Load one MLP classifier per emotion category.
        self.categories = ["Boredom", "Engagement", "Confusion", "Frustration"]
        self.classifiers = {}
        classifiers_dir = Path(path) / "classifiers"
        print("Loading MLP classifiers...")
        for category in self.categories:
            checkpoint_path = classifiers_dir / f"mlp_{category}_best.pth"
            if not checkpoint_path.exists():
                print(f" ✗ Missing {category} classifier at {checkpoint_path}")
                continue
            # NOTE(security): weights_only=False unpickles arbitrary objects
            # and can execute code — only load checkpoints from a trusted
            # source. Kept because the checkpoint stores a nested dict, not
            # bare tensors.
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            # Infer the embedding dimension from the first Linear layer.
            first_layer_weight = checkpoint['model_state_dict']['classifier.0.weight']
            input_dim = first_layer_weight.shape[1]
            model = MLPClassifier(input_dim=input_dim, num_classes=4)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.to(self.device)
            model.eval()
            self.classifiers[category] = model
            print(f" ✓ Loaded {category} classifier")

        self.fps = 1  # Frame sampling rate used for video inputs

    def extract_image_embeddings(self, image_path: str) -> np.ndarray:
        """Extract a pooled embedding from a single image via the Qwen model.

        Args:
            image_path: Filesystem path to the image.

        Returns:
            1-D numpy array: the last hidden state mean-pooled over the
            sequence dimension.
        """
        from PIL import Image

        image = Image.open(image_path).convert('RGB')
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Analyze this image."}
                ]
            }
        ]
        with torch.no_grad():
            text = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            inputs = self.processor(
                text=[text],
                images=[image],
                return_tensors="pt",
                padding=True,
            )
            inputs = {k: v.to(self.vision_model.device) for k, v in inputs.items()}
            outputs = self.vision_model(**inputs, output_hidden_states=True)
            # Mean-pool the final layer's hidden states over the sequence dim.
            hidden_states = outputs.hidden_states[-1]
            embeddings = hidden_states.mean(dim=1).squeeze(0)
            embeddings = embeddings.cpu().numpy()
        return embeddings

    def extract_video_embeddings(self, video_path: str) -> np.ndarray:
        """Extract a pooled embedding from a video via the Qwen model.

        Args:
            video_path: Filesystem path to the video; frames are sampled at
                ``self.fps``.

        Returns:
            1-D numpy array: the last hidden state mean-pooled over the
            sequence dimension.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path)},
                    {"type": "text", "text": "Analyze this video."}
                ]
            }
        ]
        with torch.no_grad():
            inputs = self.processor.apply_chat_template(
                messages,
                fps=self.fps,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.vision_model.device) for k, v in inputs.items()}
            outputs = self.vision_model(**inputs, output_hidden_states=True)
            # Mean-pool the final layer's hidden states over the sequence dim.
            hidden_states = outputs.hidden_states[-1]
            embeddings = hidden_states.mean(dim=1).squeeze(0)
            embeddings = embeddings.cpu().numpy()
        return embeddings

    def _materialize_input(self, data: Dict[str, Any]):
        """Decode or download the request payload into a temporary file.

        Args:
            data: Request payload with either "inputs" (base64, optionally a
                data URL) or "video_url" (already validated by ``__call__``).

        Returns:
            (file_path, is_image): path of the temp file and whether the
            payload was detected as a still image.
        """
        import base64
        import io
        from PIL import Image

        if "inputs" in data:
            input_data = data["inputs"]
            # Strip a data-URL prefix if present ("data:image/png;base64,...").
            if ',' in input_data and input_data.startswith('data:'):
                input_data = input_data.split(',', 1)[1]
            file_bytes = base64.b64decode(input_data)
            # Detect image vs. video: anything PIL can open is an image.
            try:
                Image.open(io.BytesIO(file_bytes))
                is_image = True
                suffix = '.png'
            except Exception:
                is_image = False
                suffix = '.avi'
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(file_bytes)
            return tmp.name, is_image

        # "video_url": download with a bounded wait; fail fast on HTTP errors
        # so an error page is never processed as if it were a video.
        import requests
        response = requests.get(data["video_url"], timeout=60)
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(suffix='.avi', delete=False) as tmp:
            tmp.write(response.content)
        return tmp.name, False

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Args:
            data: Input payload containing either:
                - "inputs": base64-encoded image or video file
                - "video_url": URL of a video file to download
        Returns:
            On success: {"success": True, "predictions": {...}} mapping each
            category to level / confidence / per-class probabilities.
            On failure: {"success": False, "error": "..."}; a missing input
            yields {"error": "..."} for backward compatibility.
        """
        # Validate the payload before doing any heavy work.
        if "inputs" not in data and "video_url" not in data:
            return {"error": "No input provided. Use 'inputs' (base64) or 'video_url'"}

        file_path = None
        try:
            file_path, is_image = self._materialize_input(data)

            # Extract embeddings based on the detected input type.
            if is_image:
                embeddings = self.extract_image_embeddings(file_path)
            else:
                embeddings = self.extract_video_embeddings(file_path)
            embeddings_tensor = torch.FloatTensor(embeddings).unsqueeze(0).to(self.device)

            # Run every loaded classifier on the pooled embedding.
            predictions = {}
            with torch.no_grad():
                for category, model in self.classifiers.items():
                    outputs = model(embeddings_tensor)
                    probabilities = torch.softmax(outputs, dim=1)
                    predicted_level = outputs.argmax(dim=1).item()
                    confidence = probabilities[0][predicted_level].item()
                    predictions[category] = {
                        "level": int(predicted_level),
                        "confidence": float(confidence),
                        "probabilities": probabilities[0].cpu().numpy().tolist()
                    }
            return {
                "success": True,
                "predictions": predictions
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
        finally:
            # Always remove the temp file, even when inference fails.
            if file_path:
                Path(file_path).unlink(missing_ok=True)