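"""
Azure ML online-endpoint scoring script for Qwen3-VL inference on CPU.

init() loads the model and processor from AZUREML_MODEL_DIR; run() accepts a
JSON payload with a base64-encoded "image" and an optional "prompt", and
returns a JSON string containing either "result" or "error".
"""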
import os
import json
import base64
import io
import torch
import logging
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from qwen_vl_utils import process_vision_info
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def init():
"""
Initialize the Qwen3-VL model and processor for CPU inference
"""
global model, processor, device
# Always use CPU
device = "cpu"
logger.info(f"Using device: {device}")
# Model path from AZUREML_MODEL_DIR
base_model_dir = os.environ.get("AZUREML_MODEL_DIR", ".")
model_path = os.path.join(base_model_dir, "Vision qwen3 Model")
logger.info(f"Loading Qwen3-VL model from {model_path}")
# Load processor with trust_remote_code=True
logger.info("Loading processor...")
processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=True
)
    # Load the model in fp32 on CPU; move it explicitly with .to(device)
    # rather than passing device_map
    logger.info("Loading model (fp32) on CPU...")
    model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        dtype=torch.float32,  # `dtype` replaces the deprecated `torch_dtype` argument
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    model.to(device)
model.eval()
torch.set_grad_enabled(False)
logger.info("Model and processor loaded successfully on CPU")
def run(raw_data):
"""
Process incoming requests
"""
try:
# Parse input
data = json.loads(raw_data)
# Extract image and prompt
        image_data = data.get("image")
        if not image_data:
            raise ValueError("Request JSON must include a base64-encoded 'image' field")
        prompt = data.get("prompt", "Describe this image.")
        # Decode base64 image and normalize to RGB
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# Prepare messages for the model
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt}
]
}
]
# Apply chat template
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
        # Extract and preprocess vision inputs (16-pixel patch size for Qwen3-VL)
image_inputs, video_inputs = process_vision_info(
messages,
image_patch_size=16
)
# Prepare inputs (do_resize=False since qwen-vl-utils already resized)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
do_resize=False
)
inputs = inputs.to(device)
# Generate response
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=512
)
# Trim and decode
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
return json.dumps({"result": output_text[0]})
    except Exception as e:
        logger.exception("Error during inference")
        return json.dumps({"error": str(e)})