|
|
""" |
|
|
Custom Inference Handler for Hugging Face Inference Endpoints |
|
|
Combines Qwen2.5-VL embedding extraction + MLP classifiers |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
from typing import Dict, Any |
|
|
import av |
|
|
import tempfile |
|
|
|
|
|
|
|
|
class MLPClassifier(nn.Module):
    """MLP classifier matching training architecture.

    Two hidden layers (hidden_dim -> hidden_dim // 2) with ReLU + dropout,
    followed by a linear projection to ``num_classes`` raw logits.

    Args:
        input_dim: Dimensionality of the input embedding vector.
        hidden_dim: Width of the first hidden layer (second is half of it).
        num_classes: Number of output classes (logits).
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 512,
                 num_classes: int = 4, dropout: float = 0.3):
        # Python 3 zero-argument super(); same behavior as
        # super(MLPClassifier, self).__init__().
        super().__init__()

        # Layer order/naming must stay identical to training so that
        # checkpoints keyed as "classifier.N.*" load cleanly.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unsoftmaxed) class logits for a batch of embeddings."""
        return self.classifier(x)
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints.

    Loads the Qwen2.5-VL backbone once at startup plus one MLP classifier per
    emotion category, then serves image/video emotion-level predictions.
    """

    def __init__(self, path: str):
        """
        Initialize the handler.

        Args:
            path: Path to the model directory on HF Hub. Expected to contain
                a "classifiers" subdirectory with ``mlp_<Category>_best.pth``
                checkpoints.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print("Loading Qwen2.5-VL model...")
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        self.vision_model = AutoModelForVision2Seq.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            # Half precision only on GPU; CPU fp16 inference is poorly supported.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        self.vision_model.eval()

        self.categories = ["Boredom", "Engagement", "Confusion", "Frustration"]
        self.classifiers = {}

        classifiers_dir = Path(path) / "classifiers"

        print("Loading MLP classifiers...")
        for category in self.categories:
            checkpoint_path = classifiers_dir / f"mlp_{category}_best.pth"

            if checkpoint_path.exists():
                # NOTE(review): weights_only=False unpickles arbitrary objects —
                # only safe because these checkpoints ship with the model repo.
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

                # Infer the embedding dimensionality from the first Linear
                # layer's weight instead of hard-coding it.
                first_layer_weight = checkpoint['model_state_dict']['classifier.0.weight']
                input_dim = first_layer_weight.shape[1]

                model = MLPClassifier(input_dim=input_dim, num_classes=4)
                model.load_state_dict(checkpoint['model_state_dict'])
                model.to(self.device)
                model.eval()

                self.classifiers[category] = model
                print(f" ✓ Loaded {category} classifier")
            else:
                # Missing classifiers are skipped (best-effort) so the
                # remaining categories still work.
                print(f" ✗ Missing {category} classifier at {checkpoint_path}")

        # Frame sampling rate used when encoding videos with the processor.
        self.fps = 1

    def extract_image_embeddings(self, image_path: str) -> np.ndarray:
        """Extract embeddings from a single image using Qwen model.

        Returns a 1-D float array: the last hidden layer mean-pooled over the
        sequence dimension.
        """
        from PIL import Image

        image = Image.open(image_path).convert('RGB')

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Analyze this image."}
                ]
            }
        ]

        with torch.no_grad():
            text = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )

            inputs = self.processor(
                text=[text],
                images=[image],
                return_tensors="pt",
                padding=True,
            )

            # With device_map="auto" the model may be sharded; .device points
            # at the embedding device, which is where inputs must live.
            inputs = {k: v.to(self.vision_model.device) for k, v in inputs.items()}
            outputs = self.vision_model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]

            # Mean-pool over the sequence dimension -> single embedding vector.
            embeddings = hidden_states.mean(dim=1).squeeze(0)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def extract_video_embeddings(self, video_path: str) -> np.ndarray:
        """Extract embeddings from video using Qwen model.

        Frames are sampled at ``self.fps``; returns the mean-pooled last
        hidden state as a 1-D float array.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path)},
                    {"type": "text", "text": "Analyze this video."}
                ]
            }
        ]

        with torch.no_grad():
            # Video path: the processor handles frame extraction itself, so
            # tokenize=True/return_dict=True yields model-ready tensors.
            inputs = self.processor.apply_chat_template(
                messages,
                fps=self.fps,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )

            inputs = {k: v.to(self.vision_model.device) for k, v in inputs.items()}
            outputs = self.vision_model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]

            # Mean-pool over the sequence dimension -> single embedding vector.
            embeddings = hidden_states.mean(dim=1).squeeze(0)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle inference request.

        Args:
            data: Input data containing either:
                - "inputs": base64 encoded image or video file (optionally a
                  full data-URL, e.g. "data:video/avi;base64,...")
                - "video_url": URL to video file

        Returns:
            On success: ``{"success": True, "predictions": {...}}`` with one
            entry per loaded category (level, confidence, probabilities).
            On failure: ``{"success": False, "error": <message>}`` or an
            ``{"error": ...}`` dict when no input was supplied.
        """
        try:
            import base64
            from PIL import Image
            import io

            file_path = None
            is_image = False

            if "inputs" in data:
                input_data = data["inputs"]

                # Strip a data-URL prefix ("data:<mime>;base64,") if present.
                if ',' in input_data and input_data.startswith('data:'):
                    input_data = input_data.split(',', 1)[1]

                file_bytes = base64.b64decode(input_data)

                # Sniff the payload type: if PIL can open it, treat it as an
                # image; otherwise assume video.
                try:
                    Image.open(io.BytesIO(file_bytes))
                    is_image = True
                    suffix = '.png'
                except Exception:
                    # Fixed: was a bare `except:` that also caught
                    # KeyboardInterrupt/SystemExit.
                    is_image = False
                    suffix = '.avi'

                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                    tmp.write(file_bytes)
                    file_path = tmp.name

            elif "video_url" in data:
                import requests
                # Timeout so a stalled download cannot hang the endpoint;
                # raise_for_status() so an HTTP error page is not decoded
                # as video bytes.
                response = requests.get(data["video_url"], timeout=60)
                response.raise_for_status()

                with tempfile.NamedTemporaryFile(suffix='.avi', delete=False) as tmp:
                    tmp.write(response.content)
                    file_path = tmp.name
                is_image = False
            else:
                return {"error": "No input provided. Use 'inputs' (base64) or 'video_url'"}

            try:
                if is_image:
                    embeddings = self.extract_image_embeddings(file_path)
                else:
                    embeddings = self.extract_video_embeddings(file_path)

                embeddings_tensor = torch.FloatTensor(embeddings).unsqueeze(0).to(self.device)

                predictions = {}

                with torch.no_grad():
                    for category, model in self.classifiers.items():
                        outputs = model(embeddings_tensor)
                        probabilities = torch.softmax(outputs, dim=1)
                        predicted_level = outputs.argmax(dim=1).item()
                        confidence = probabilities[0][predicted_level].item()

                        predictions[category] = {
                            "level": int(predicted_level),
                            "confidence": float(confidence),
                            "probabilities": probabilities[0].cpu().numpy().tolist()
                        }
            finally:
                # Fixed: cleanup used to run only on the success path, leaking
                # the temp file whenever embedding/classification raised.
                if file_path:
                    Path(file_path).unlink(missing_ok=True)

            return {
                "success": True,
                "predictions": predictions
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
|
|
|