# GT-r1 / data&evaluate / model_evaluate.py
# Source: konkazzz's upload "Upload 4 files" (commit 691a4e2, verified)
import os
import json
import re
import torch
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
class VLMModel:
    """Wrapper around a Qwen2.5-VL checkpoint for single-image visual QA inference."""

    # Compiled once at class level instead of per call: extracts the text the
    # fine-tuned model emits between <answer>...</answer> tags.
    _ANSWER_RE = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)

    def __init__(self, model_path, device='auto'):
        """Load model weights and processor.

        Args:
            model_path: Local path or hub id of the Qwen2.5-VL checkpoint.
            device: Forwarded to ``device_map`` ('auto' spreads across available GPUs).
        """
        print(f"Loading model weights from: {model_path}")
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            # NOTE(review): requires the flash-attn package; loading fails without it.
            attn_implementation="flash_attention_2",
            device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def infer(self, image_path, problem, max_new_tokens=256):
        """Run one image+text query and return the extracted <answer> text.

        Args:
            image_path: Path to the input image. A missing file is reported via
                the "IMAGE_FILE_NOT_FOUND" sentinel rather than raising.
            problem: The textual question/prompt for the model.
            max_new_tokens: Generation budget (default 256, the previous
                hard-coded value, so existing callers are unaffected).

        Returns:
            The stripped answer text, "IMAGE_FILE_NOT_FOUND" if the image is
            absent, or "NO_ANSWER_FOUND" if the output lacks <answer> tags.
        """
        if not os.path.exists(image_path):
            return "IMAGE_FILE_NOT_FOUND"
        messages = [{
            "role": "user",
            "content": [{"type": "image", "image": image_path},
                        {"type": "text", "text": problem}]
        }]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(text=[text], images=image_inputs, videos=video_inputs,
                                padding=True, return_tensors="pt").to(self.model.device)
        # inference_mode avoids building the autograd graph during generation.
        with torch.inference_mode():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
        # Drop the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        output_text = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True,
                                                  clean_up_tokenization_spaces=False)[0]
        answer_match = self._ANSWER_RE.search(output_text)
        return answer_match.group(1).strip() if answer_match else "NO_ANSWER_FOUND"
class Evaluator:
    """Runs a vision-language model over a JSON test set and reports metrics."""

    def __init__(self, model, test_data_path):
        self.model = model
        self.test_data_path = test_data_path

    def load_data(self):
        """Read and return the list of test samples from the JSON file."""
        with open(self.test_data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def evaluate(self):
        """Run inference on every test sample and return aggregate metrics."""
        predictions = [
            {
                "image": sample["image"],
                "problem": sample["problem"],
                "solution": self.model.infer(sample["image"], sample["problem"]),
                "ground_truth": sample["solution"],
            }
            for sample in tqdm(self.load_data())
        ]
        return self.calculate_metrics(predictions)

    @staticmethod
    def calculate_metrics(results):
        """Compute overall accuracy plus recall for defect and no-defect classes.

        Args:
            results: List of dicts with 'solution' (prediction) and
                'ground_truth' keys.

        Returns:
            Dict with accuracy, defect_recall, no_defect_recall, total_samples.
        """
        total = len(results)
        correct = 0
        defect_total = defect_correct = 0
        no_defect_total = no_defect_correct = 0
        for entry in results:
            hit = entry['solution'] == entry['ground_truth']
            correct += hit
            # "No Defect" samples are scored separately from all defect labels.
            if entry['ground_truth'] == "No Defect":
                no_defect_total += 1
                no_defect_correct += hit
            else:
                defect_total += 1
                defect_correct += hit
        metrics = {
            "accuracy": correct / total if total else 0,
            "defect_recall": defect_correct / defect_total if defect_total else 0,
            "no_defect_recall": no_defect_correct / no_defect_total if no_defect_total else 0,
            "total_samples": total
        }
        print("\nEvaluation Results:")
        print(json.dumps(metrics, indent=2))
        return metrics
def main(model_path, test_data_path):
    """Load the model, then evaluate it against the given JSON test set."""
    evaluator = Evaluator(VLMModel(model_path), test_data_path)
    evaluator.evaluate()
if __name__ == "__main__":
    import argparse

    # CLI entry point: both the checkpoint path and the test set are required.
    cli = argparse.ArgumentParser(description='Model inference and evaluation script')
    cli.add_argument('--model', type=str, required=True, help='Path to model weights')
    cli.add_argument('--test_data', type=str, required=True, help='Path to test data (JSON)')
    parsed = cli.parse_args()
    main(parsed.model, parsed.test_data)