Spaces:
Sleeping
Sleeping
| """ | |
| Test the MultiModalAgent. | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| import json | |
| # Add the current directory to sys.path to import local modules | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| # Import the MultiModalAgent | |
| from agent import MultiModalAgent | |
# Send INFO-and-above records to the root handler; every line carries the
# timestamp, logger name and level before the message text.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Dedicated named logger for this test script.
logger = logging.getLogger(name='test_agent')
def _load_test_questions(metadata_path, max_questions=5):
    """Read test questions that reference a file from a JSONL metadata file.

    Each non-blank line is parsed as JSON. Only JSON objects that have a
    ``Question`` key and a non-empty ``file_name`` are kept.

    Args:
        metadata_path: Path to the ``metadata.jsonl`` file.
        max_questions: Stop after collecting this many questions; ``None``
            means no limit.

    Returns:
        A list of dicts with keys ``task_id``, ``question``, ``file_name`` and
        ``expected_answer`` (taken from ``task_id``, ``Question``,
        ``file_name`` and ``Final answer``). ``task_id`` and
        ``expected_answer`` are ``None`` when absent from the entry.

    Raises:
        FileNotFoundError: If *metadata_path* does not exist.
    """
    questions = []
    with open(metadata_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            if max_questions is not None and len(questions) >= max_questions:
                break
            line = raw_line.strip()
            if not line:
                # json.loads('') raises, so blank lines must be skipped.
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(
                    "Skipping malformed JSON on line %d of %s", line_no, metadata_path
                )
                continue
            # Membership tests below only make sense on JSON objects.
            if not isinstance(entry, dict):
                continue
            if 'Question' in entry and entry.get('file_name'):
                questions.append({
                    'task_id': entry.get('task_id'),
                    'question': entry['Question'],
                    'file_name': entry['file_name'],
                    'expected_answer': entry.get('Final answer'),
                })
    return questions


def _generic_questions():
    """Return fallback questions used when no file-backed questions are found.

    A fresh list is built on every call so callers may mutate it freely.
    """
    return [
        {
            'task_id': None,
            'question': "What's the oldest Blu-Ray in the inventory spreadsheet?",
            'file_name': None,
            'expected_answer': None,
        },
        {
            'task_id': None,
            'question': "How many files are in the resource directory?",
            'file_name': None,
            'expected_answer': None,
        },
    ]


def _as_text(value):
    """Return *value* as a stripped string, or ``''`` when it is ``None``."""
    return '' if value is None else str(value).strip()


def main(resource_dir=None, max_questions=5):
    """Smoke-test the MultiModalAgent on sample questions and log the results.

    Questions are loaded from ``<resource_dir>/metadata.jsonl``. If that file
    is missing or yields no file-backed questions, two generic questions are
    used instead. When a question has an expected answer, the agent's answer
    is compared to it as text after stripping surrounding whitespace (exact
    match), and a correct/graded tally is logged at the end.

    Args:
        resource_dir: Directory holding ``metadata.jsonl``; it is also passed
            to the agent. Defaults to the ``resource`` folder next to this
            script.
        max_questions: Maximum number of questions to load from the metadata
            file (``None`` for no limit).
    """
    if resource_dir is None:
        resource_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'resource'
        )
    agent = MultiModalAgent(resource_dir=resource_dir)

    metadata_path = os.path.join(resource_dir, 'metadata.jsonl')
    try:
        test_questions = _load_test_questions(metadata_path, max_questions)
    except FileNotFoundError:
        # Fall through to the generic questions instead of crashing.
        logger.warning("Metadata file not found: %s", metadata_path)
        test_questions = []
    if not test_questions:
        test_questions = _generic_questions()

    graded = 0
    correct = 0
    for i, q in enumerate(test_questions, start=1):
        question = q['question']
        logger.info("Testing question %d: %s", i, question)
        try:
            answer = agent(question)
        except Exception:
            # Log with traceback and continue so one failing question does
            # not hide the results of the remaining ones.
            logger.exception("Agent failed on question %d", i)
            logger.info("-" * 80)
            continue
        logger.info("Answer: %s", answer)

        expected = _as_text(q.get('expected_answer'))
        if expected:
            graded += 1
            logger.info("Expected answer: %s", q['expected_answer'])
            # Compare as text: the agent may return None or a non-str value,
            # and the metadata may store numeric answers.
            if _as_text(answer) == expected:
                correct += 1
                logger.info("Correct answer!")
            else:
                logger.warning("Incorrect answer.")
        logger.info("-" * 80)

    if graded:
        logger.info("Correct answers: %d/%d", correct, graded)
# Run the smoke test only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()