import argparse
import json
import os
import re

import torch
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class VLMModel:
    def __init__(self, model_path, device='auto'):
        print(f"Loading model weights from: {model_path}")
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map=device,
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def infer(self, image_path, problem):
        if not os.path.exists(image_path):
            return "IMAGE_FILE_NOT_FOUND"
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": problem},
            ],
        }]
        # Build the chat-formatted prompt and collect the vision inputs.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=256, use_cache=True)
        # Drop the prompt tokens so only the newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        # Extract the final answer between <answer>...</answer> tags.
        answer_match = re.search(r'<answer>(.*?)</answer>', output_text, re.DOTALL)
        return answer_match.group(1).strip() if answer_match else "NO_ANSWER_FOUND"


class Evaluator:
    def __init__(self, model, test_data_path):
        self.model = model
        self.test_data_path = test_data_path

    def load_data(self):
        with open(self.test_data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def evaluate(self):
        test_data = self.load_data()
        results = []
        for item in tqdm(test_data):
            prediction = self.model.infer(item["image"], item["problem"])
            results.append({
                "image": item["image"],
                "problem": item["problem"],
                "prediction": prediction,
                "ground_truth": item["solution"],
            })
        return self.calculate_metrics(results)

    @staticmethod
    def calculate_metrics(results):
        # Overall accuracy plus per-class recall for defect / no-defect samples.
        total = correct = 0
        defect_total = defect_correct = 0
        no_defect_total = no_defect_correct = 0
        for res in results:
            pred, true = res["prediction"], res["ground_truth"]
            total += 1
            correct += (pred == true)
            if true == "No Defect":
                no_defect_total += 1
                no_defect_correct += (pred == true)
            else:
                defect_total += 1
                defect_correct += (pred == true)
        metrics = {
            "accuracy": correct / total if total else 0,
            "defect_recall": defect_correct / defect_total if defect_total else 0,
            "no_defect_recall": no_defect_correct / no_defect_total if no_defect_total else 0,
            "total_samples": total,
        }
        print("\nEvaluation Results:")
        print(json.dumps(metrics, indent=2))
        return metrics


def main(model_path, test_data_path):
    vlm_model = VLMModel(model_path)
    evaluator = Evaluator(vlm_model, test_data_path)
    evaluator.evaluate()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Model inference and evaluation script')
    parser.add_argument('--model', type=str, required=True, help='Path to model weights')
    parser.add_argument('--test_data', type=str, required=True, help='Path to test data (JSON)')
    args = parser.parse_args()
    main(args.model, args.test_data)
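
# --- Usage notes (illustrative; the filename and all paths below are placeholders) ---
# Example invocation, assuming this file is saved as evaluate.py:
#   python evaluate.py --model /path/to/Qwen2.5-VL-checkpoint --test_data test.json
#
# Expected test-data layout, inferred from the fields accessed in Evaluator.evaluate():
# a JSON list of objects, each with "image", "problem", and "solution" keys, e.g.
#   [
#     {
#       "image": "images/sample_001.png",
#       "problem": "Is there a defect in this image?",
#       "solution": "No Defect"
#     }
#   ]
# Ground-truth labels other than "No Defect" are counted toward defect_recall.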