| | import os |
| | import json |
| | import re |
| | import torch |
| | from tqdm import tqdm |
| | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor |
| | from qwen_vl_utils import process_vision_info |
| |
|
class VLMModel:
    """Wraps a Qwen2.5-VL checkpoint for single-image question answering.

    Loads the model in bfloat16 with flash-attention-2 and keeps the matching
    processor alongside it.
    """

    def __init__(self, model_path, device='auto'):
        print(f"Loading model weights from: {model_path}")
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map=device,
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def infer(self, image_path, problem):
        """Answer `problem` about the image at `image_path`.

        Returns the text inside the first <answer>...</answer> pair of the
        generation, "NO_ANSWER_FOUND" if no such pair exists, or
        "IMAGE_FILE_NOT_FOUND" if the image path does not exist.
        """
        if not os.path.exists(image_path):
            return "IMAGE_FILE_NOT_FOUND"

        conversation = self._build_conversation(image_path, problem)
        prompt = self.processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True)
        images, videos = process_vision_info(conversation)
        model_inputs = self.processor(
            text=[prompt], images=images, videos=videos,
            padding=True, return_tensors="pt").to(self.model.device)

        output_ids = self.model.generate(**model_inputs, max_new_tokens=256, use_cache=True)
        # Drop the prompt prefix so only the newly generated tokens get decoded.
        continuations = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(model_inputs.input_ids, output_ids)
        ]
        decoded = self.processor.batch_decode(
            continuations, skip_special_tokens=True,
            clean_up_tokenization_spaces=False)[0]

        match = re.search(r'<answer>(.*?)</answer>', decoded, re.DOTALL)
        return match.group(1).strip() if match else "NO_ANSWER_FOUND"

    @staticmethod
    def _build_conversation(image_path, problem):
        # Single-turn user message carrying the image plus the text question.
        return [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": problem},
            ],
        }]
| |
|
class Evaluator:
    """Runs a VLM over a JSON test set and reports defect-detection metrics."""

    def __init__(self, model, test_data_path):
        self.model = model
        self.test_data_path = test_data_path

    def load_data(self):
        """Parse and return the test set stored as JSON at `test_data_path`."""
        with open(self.test_data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def evaluate(self):
        """Run inference on every sample, then compute and return summary metrics."""
        records = [
            {
                "image": sample["image"],
                "problem": sample["problem"],
                "solution": self.model.infer(sample["image"], sample["problem"]),
                "ground_truth": sample["solution"],
            }
            for sample in tqdm(self.load_data())
        ]
        return self.calculate_metrics(records)

    @staticmethod
    def calculate_metrics(results):
        """Compute overall accuracy plus per-class recall.

        Samples whose ground truth equals "No Defect" form the negative class;
        everything else counts as a defect. Prints the metrics as JSON and
        returns them as a dict. Empty classes yield a recall of 0.
        """
        def hits(records):
            # Number of records where the prediction matches the label.
            return sum(r['solution'] == r['ground_truth'] for r in records)

        negatives = [r for r in results if r['ground_truth'] == "No Defect"]
        positives = [r for r in results if r['ground_truth'] != "No Defect"]

        total = len(results)
        metrics = {
            "accuracy": hits(results) / total if total else 0,
            "defect_recall": hits(positives) / len(positives) if positives else 0,
            "no_defect_recall": hits(negatives) / len(negatives) if negatives else 0,
            "total_samples": total,
        }

        print("\nEvaluation Results:")
        print(json.dumps(metrics, indent=2))
        return metrics
| |
|
| |
|
def main(model_path, test_data_path):
    """Build the model and evaluator, then run the full evaluation pass."""
    model = VLMModel(model_path)
    Evaluator(model, test_data_path).evaluate()
| |
|
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| | parser = argparse.ArgumentParser(description='Model inference and evaluation script') |
| | parser.add_argument('--model', type=str, required=True, help='Path to model weights') |
| | parser.add_argument('--test_data', type=str, required=True, help='Path to test data (JSON)') |
| | args = parser.parse_args() |
| |
|
| | main(args.model, args.test_data) |