import argparse
import glob
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from PIL import Image
from llava.conversation import conv_templates
from tqdm import tqdm


def process_image(image_path: str, target_size: int = 640) -> Image.Image:
    """Process and resize an image to match model requirements.

    Args:
        image_path: Path to the input image file
        target_size: Target size for both width and height in pixels

    Returns:
        PIL.Image: Processed and padded image with dimensions (target_size, target_size)
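
    Example (illustrative; the path is hypothetical):
        >>> img = process_image("MedMAX/data/figures/case1_fig1.png")
        >>> img.size
        (640, 640)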
| """ |
| image = Image.open(image_path) |
| if image.mode != "RGB": |
| image = image.convert("RGB") |
|
|
| |
| ratio = min(target_size / image.width, target_size / image.height) |
| new_size = (int(image.width * ratio), int(image.height * ratio)) |
|
|
| |
| image = image.resize(new_size, Image.LANCZOS) |
|
|
| |
| new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0)) |
| |
| offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2) |
| new_image.paste(image, offset) |
|
|
| return new_image |
|
|
|
|
def validate_answer(response_text: str) -> Optional[str]:
    """Extract and validate a single-letter answer from the model's output.

    Handles multiple response formats and edge cases.

    Args:
        response_text: The full text output from the model

    Returns:
        A single letter answer (A-F) or None if no valid answer found
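
    Examples:
        >>> validate_answer("The SINGLE LETTER answer is: C")
        'C'
        >>> validate_answer("I believe the answer is B because...")
        'B'
        >>> validate_answer("")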
| """ |
| if not response_text: |
| return None |
|
|
| |
| cleaned = response_text.strip() |
|
|
| |
| extraction_patterns = [ |
| |
| r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b", |
| |
| r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b", |
| r"\b(?:answer|option)\s*([A-F])[):]\s*", |
| |
| r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b", |
| r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b", |
| |
| r"characteriz[e]?d?\s+by\s+([A-F])\b", |
| r"indicat[e]?s?\s+([A-F])\b", |
| |
| r"Option\s*([A-F])\b", |
| r"\b([A-F])\)\s*", |
| |
| r"^\s*([A-F])\s*$", |
| ] |

    for pattern in extraction_patterns:
        matches = re.findall(pattern, cleaned, re.IGNORECASE)
        for match in matches:
            # re.findall returns tuples when a pattern has multiple groups;
            # keep the first non-empty group in that case.
            if isinstance(match, tuple):
                match = next((m for m in match if m), None)
            if match and match.upper() in "ABCDEF":
                return match.upper()

    # Last resort: take the first standalone A-F letter anywhere in the text.
    # Note this can false-positive on the article "A" in free-form answers.
    context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
    if context_matches:
        return context_matches[0]

    return None


def load_benchmark_questions(case_id: str) -> List[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: The ID of the medical case

    Returns:
        List of paths to question JSON files
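
    Expected layout (implied by the glob below):
        MedMAX/benchmark/questions/<case_id>/<case_id>_<n>.json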
| """ |
| benchmark_dir = "MedMAX/benchmark/questions" |
| return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json") |
|
|
|
|
def count_total_questions() -> Tuple[int, int]:
    """Count the total number of cases and questions in the benchmark.

    Returns:
        Tuple containing (total_cases, total_questions)
    """
    total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("MedMAX/benchmark/questions")
    )
    return total_cases, total_questions


def create_inference_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    worker_addr: str,
    model_name: str,
    raw_output: bool = False,
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
    """Create and send an inference request to the worker.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        worker_addr: Address of the worker endpoint
        model_name: Name of the model to use
        raw_output: Whether to return raw model output

    Returns:
        If raw_output is False: tuple of (validated_answer, duration).
        If raw_output is True: dictionary with full inference details.
        In both modes, ("skipped", 0.0) is returned when no local images are
        found, and (None, None) on failure.
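
    Expected input shapes (inferred from usage below):
        question_data: {"question": str, "answer": "A".."F", "figures": str | list}
        case_details: {"figures": [{"number": "Figure 1", "subfigures": [
            {"number": str, "label": str, "local_path": str}]}]}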
| """ |
| system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'. |
| """ |

    # Note: only the question text is interpolated here; the case context
    # reaches the model through the attached image(s).
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """

    # question_data["figures"] may be a JSON-encoded list, a plain string,
    # or already a list; normalize it to a list of strings.
    try:
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

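    # Normalize references to the "Figure N[letter]" form, e.g. "1a" -> "Figure 1a".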
    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    # Resolve each required figure to local image paths, honoring an optional
    # subfigure letter (e.g. the "a" in "Figure 1a").
    image_paths = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            if figure_letter:
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "local_path" in subfig:
                    image_paths.append("MedMAX/data/" + subfig["local_path"])

    if not image_paths:
        print(f"No local images found for case {case_id}, question {question_id}")
        return "skipped", 0.0

    try:
        start_time = time.time()

        processed_images = [process_image(path) for path in image_paths]

        conv = conv_templates["mistral_instruct"].copy()
        # Attach the answer-format instruction as the conversation's system
        # message so it actually reaches the model (it is also logged below).
        conv.system = system_prompt

        # Ensure the prompt carries an image token for the worker.
        if "<image>" not in prompt:
            text = prompt + "\n<image>"
        else:
            text = prompt

        # Only the first processed image is attached to the message;
        # conv.get_images() below collects it for the request payload.
        message = (text, processed_images[0], "Default")
        conv.append_message(conv.roles[0], message)
        conv.append_message(conv.roles[1], None)

        prompt = conv.get_prompt()
        headers = {"User-Agent": "LLaVA-Med Client"}
        pload = {
            "model": model_name,
            "prompt": prompt,
            "max_new_tokens": 150,
            "temperature": 0.5,
            "stop": conv.sep2,
            "images": conv.get_images(),
            "top_p": 1,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }
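
        # The worker streams incremental generations as NUL-delimited JSON
        # blobs; each blob carries the full text so far plus an error_code.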
        max_retries = 3
        retry_delay = 5
        response_text = None

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    worker_addr + "/worker_generate_stream",
                    headers=headers,
                    json=pload,
                    stream=True,
                    timeout=30,
                )

                complete_output = ""
                for chunk in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if chunk:
                        data = json.loads(chunk.decode("utf-8"))
                        if data["error_code"] == 0:
                            # Keep only the part after the instruction tag.
                            output = data["text"].split("[/INST]")[-1]
                            complete_output = output
                        else:
                            print(f"\nError: {data['text']} (error_code: {data['error_code']})")
                            # Discard partial output so a retry starts clean.
                            complete_output = ""
                            if attempt < max_retries - 1:
                                time.sleep(retry_delay)
                                break
                            return None, None

                if complete_output:
                    response_text = complete_output
                    break

            except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
                if attempt < max_retries - 1:
                    print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    print(f"\nFailed after {max_retries} attempts: {str(e)}")
                    return None, None

        duration = time.time() - start_time

        if raw_output:
            return {
                "raw_output": response_text,
                "validated_answer": validate_answer(response_text),
                "duration": duration,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "image_paths": image_paths,
                "payload": pload,
            }

        return validate_answer(response_text), duration

    except Exception as e:
        print(f"Error in inference request: {str(e)}")
        return None, None


def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Remove image data from the payload to keep the log lean.

    Args:
        payload: Original request payload dictionary

    Returns:
        Cleaned payload dictionary with the large image data removed
    """
    if not payload:
        return None

    cleaned_payload = payload.copy()
    # The base64-encoded images dominate the payload size; drop them.
    cleaned_payload.pop("images", None)
    return cleaned_payload


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
    parser.add_argument("--output-dir", type=str, default="benchmark_results")
    parser.add_argument(
        "--raw-output", action="store_true", help="Return raw model output without validation"
    )
    parser.add_argument(
        "--num-cases",
        type=int,
        help="Number of cases to process when inspecting raw outputs",
        default=2,
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
    final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")

    # Start the live log as a JSON array built incrementally.
    with open(live_log_filename, "w") as live_log_file:
        live_log_file.write("[\n")
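    # Each entry is appended as "<json>,\n"; the trailing comma is replaced
    # by the closing "]" at the end of the run.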

    # Resolve the worker address, either given directly or via the controller.
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        try:
            requests.post(args.controller_address + "/refresh_all_workers")
            ret = requests.post(
                args.controller_address + "/get_worker_address", json={"model": args.model_name}
            )
            worker_addr = ret.json()["address"]
            print(f"Worker address: {worker_addr}")
        except requests.exceptions.RequestException as e:
            print(f"Failed to connect to controller: {e}")
            return

    if worker_addr == "":
        print("No available worker")
        return

    with open("MedMAX/data/updated_cases.json", "r") as file:
        data = json.load(file)

    total_cases, total_questions = count_total_questions()
    print(f"\nStarting benchmark with {args.model_name}")
    print(f"Found {total_cases} cases with {total_questions} total questions")

    results = {
        "model": args.model_name,
        "timestamp": datetime.now().isoformat(),
        "total_cases": total_cases,
        "total_questions": total_questions,
        "results": [],
    }

    cases_processed = 0
    questions_processed = 0
    correct_answers = 0
    skipped_questions = 0
    total_processed_entries = 0

    for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in tqdm(
            question_files, desc=f"Processing questions for case {case_id}", leave=False
        ):
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1

            # Always request raw output so the live log captures full details.
            inference_result = create_inference_request(
                question_data,
                case_details,
                case_id,
                question_id,
                worker_addr,
                args.model_name,
                raw_output=True,
            )

            if inference_result == ("skipped", 0.0):
                skipped_questions += 1
                print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")

                skipped_entry = {
                    "case_id": case_id,
                    "question_id": question_id,
                    "status": "skipped",
                    "reason": "No images found",
                }
                with open(live_log_filename, "a") as live_log_file:
                    json.dump(skipped_entry, live_log_file, indent=2)
                    live_log_file.write(",\n")

                continue

            # A failed request returns (None, None) rather than the raw-output
            # dict; skip the question so the dict lookups below cannot crash.
            if not isinstance(inference_result, dict):
                print(f"\nCase {case_id}, Question {question_id}: Failed (no model response)")
                continue

            answer = inference_result["validated_answer"]
            duration = inference_result["duration"]

            log_entry = {
                "case_id": case_id,
                "question_id": question_id,
                "question": question_data["question"],
                "correct_answer": question_data["answer"],
                "raw_output": inference_result["raw_output"],
                "validated_answer": answer,
                "model_answer": answer,
                "is_correct": answer == question_data["answer"] if answer else False,
                "duration": duration,
                "system_prompt": inference_result["system_prompt"],
                "input_prompt": inference_result["prompt"],
                "image_paths": inference_result["image_paths"],
                "payload": clean_payload(inference_result["payload"]),
            }

            with open(live_log_filename, "a") as live_log_file:
                json.dump(log_entry, live_log_file, indent=2)
                live_log_file.write(",\n")

            print(f"\nCase {case_id}, Question {question_id}")
            print(f"Model Answer: {answer}")
            print(f"Correct Answer: {question_data['answer']}")
            print(f"Time taken: {duration:.2f}s")

            if answer == question_data["answer"]:
                correct_answers += 1

            results["results"].append(log_entry)
            total_processed_entries += 1

        # In raw-output mode, stop once the requested number of cases has
        # been processed (each case still runs all of its questions).
        if args.raw_output and cases_processed == args.num_cases:
            break

    # Writes in "a" mode ignore seek(), so reopen the live log read/write to
    # replace the trailing ",\n" left by the last entry with the closing "]".
    with open(live_log_filename, "rb+") as live_log_file:
        end = live_log_file.seek(0, os.SEEK_END)
        if total_processed_entries + skipped_questions > 0:
            live_log_file.seek(end - 2)
            live_log_file.truncate()
        live_log_file.write(b"\n]")

    answered_questions = questions_processed - skipped_questions
    results["summary"] = {
        "cases_processed": cases_processed,
        "questions_processed": questions_processed,
        "total_processed_entries": total_processed_entries,
        "correct_answers": correct_answers,
        "skipped_questions": skipped_questions,
        "accuracy": correct_answers / answered_questions if answered_questions > 0 else 0,
    }

    with open(final_results_filename, "w") as f:
        json.dump(results, f, indent=2)

    print("\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Processed Entries: {total_processed_entries}")
    print(f"Correct Answers: {correct_answers}")
    print(f"Skipped Questions: {skipped_questions}")
    print(f"Accuracy: {results['summary']['accuracy'] * 100:.2f}%")
    print(f"\nResults saved to {args.output_dir}")
    print(f"Live log: {live_log_filename}")
    print(f"Final results: {final_results_filename}")


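# Example invocation (hypothetical script name):
#   python llava_med_benchmark.py --model-name llava-med-v1.5-mistral-7b \
#       --controller-address http://localhost:21001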
if __name__ == "__main__":
    main()