| """ | |
| Script to evaluate LLaVA on Visual Question Answering datasets. | |
| """ | |

import os
import sys
import argparse

# Add the parent directory to the path so the models/ and utils/ packages resolve
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.llava import LLaVA
from utils.eval_utils import evaluate_vqa, visualize_results, compute_metrics


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate LLaVA on VQA")
    parser.add_argument("--vision-model", type=str, default="openai/clip-vit-large-patch14-336",
                        help="Vision model name or path")
    parser.add_argument("--language-model", type=str, default="lmsys/vicuna-7b-v1.5",
                        help="Language model name or path")
    parser.add_argument("--questions-file", type=str, required=True,
                        help="Path to questions JSON file")
    parser.add_argument("--image-folder", type=str, required=True,
                        help="Path to folder containing images")
    parser.add_argument("--output-file", type=str, default=None,
                        help="Path to save results")
| parser.add_argument("--load-8bit", action="store_true", | |
| help="Load language model in 8-bit mode") | |
| parser.add_argument("--load-4bit", action="store_true", | |
| help="Load language model in 4-bit mode") | |
| parser.add_argument("--device", type=str, default=None, | |
| help="Device to run the model on") | |
| parser.add_argument("--max-new-tokens", type=int, default=100, | |
| help="Maximum number of tokens to generate") | |
| parser.add_argument("--visualize", action="store_true", | |
| help="Visualize results") | |
| parser.add_argument("--num-examples", type=int, default=5, | |
| help="Number of examples to visualize") | |
| return parser.parse_args() | |


def main():
    args = parse_args()

    # Check that the inputs exist before loading the model
    if not os.path.exists(args.questions_file):
        raise FileNotFoundError(f"Questions file not found: {args.questions_file}")
    if not os.path.exists(args.image_folder):
        raise FileNotFoundError(f"Image folder not found: {args.image_folder}")

    # Initialize model
    print("Initializing LLaVA model...")
    model = LLaVA(
        vision_model_path=args.vision_model,
        language_model_path=args.language_model,
        device=args.device,
        load_in_8bit=args.load_8bit,
        load_in_4bit=args.load_4bit,
    )
    print(f"Model initialized on {model.device}")

    # Evaluate model; evaluate_vqa returns a dict with 'num_questions',
    # 'accuracy' (may be None), and per-question 'results'
    print(f"Evaluating on {args.questions_file}...")
    eval_results = evaluate_vqa(
        model=model,
        questions_file=args.questions_file,
        image_folder=args.image_folder,
        output_file=args.output_file,
        max_new_tokens=args.max_new_tokens,
    )

    # Print results
    print("\nEvaluation Results:")
    print(f"Number of questions: {eval_results['num_questions']}")
    if eval_results['accuracy'] is not None:
        print(f"Accuracy: {eval_results['accuracy']:.4f}")

    # Compute additional metrics
    metrics = compute_metrics(eval_results['results'])
    print("\nAdditional Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    # Visualize results if requested
    if args.visualize:
        print(f"\nVisualizing {args.num_examples} examples...")
        visualize_results(
            results=eval_results['results'],
            num_examples=args.num_examples,
            image_folder=args.image_folder,
        )


if __name__ == "__main__":
    main()