import sys
import os
import venv
import subprocess
import signal  # Added for CLI stop handling

# --- Venv Setup ---
# Determine if we need to set up or reactivate the virtual environment
VENV_DIR = ".venv"
REQUIRED_PACKAGES = [
    "gradio",
    "pandas",
    "requests",
    "tenacity",
    "Pillow",  # For image handling (needed for dummy image in CLI test)
]


def ensure_venv():
    """Checks for venv, creates/installs if needed, and re-executes if not active."""
    venv_path = os.path.abspath(VENV_DIR)
    # Check if the current Python executable is from the target venv
    is_in_venv = sys.prefix == venv_path
    venv_exists = os.path.isdir(venv_path)

    if is_in_venv:
        # Already running in the correct venv, proceed
        print(f"Running inside the '{VENV_DIR}' virtual environment.")
        return True  # Indicate we are ready to proceed

    print(f"Not running inside the target '{VENV_DIR}' virtual environment.")
    if not venv_exists:
        print(f"Creating virtual environment in '{venv_path}'...")
        try:
            venv.create(venv_path, with_pip=True)
            print("Virtual environment created successfully.")
        except Exception as e:
            print(f"Error creating virtual environment: {e}", file=sys.stderr)
            sys.exit(1)  # Exit if creation fails

    # Determine the Python executable path within the venv
    if sys.platform == "win32":
        python_executable = os.path.join(venv_path, "Scripts", "python.exe")
        pip_executable = os.path.join(venv_path, "Scripts", "pip.exe")
    else:
        python_executable = os.path.join(venv_path, "bin", "python")
        pip_executable = os.path.join(venv_path, "bin", "pip")

    if not os.path.exists(python_executable):
        print(f"Error: Python executable not found at '{python_executable}'. Venv creation might have failed.", file=sys.stderr)
        sys.exit(1)
    if not os.path.exists(pip_executable):
        print(f"Error: Pip executable not found at '{pip_executable}'. Venv creation might have failed.", file=sys.stderr)
        sys.exit(1)

    # Install requirements into the venv using pip from the venv
    print(f"Installing/checking required packages in '{venv_path}'...")
    install_command = [pip_executable, "install"] + REQUIRED_PACKAGES
    try:
        # Run pip install, capture output for clarity/debugging
        result = subprocess.run(install_command, check=True, capture_output=True, text=True, encoding='utf-8')
        print("Packages installed/verified successfully.")
        # print(result.stdout)  # Uncomment to see pip output
        if result.stderr:
            # Show pip's stderr for warnings etc.
            print("--- pip stderr ---\n", result.stderr, "\n--- end pip stderr ---")
    except subprocess.CalledProcessError as e:
        print(f"Error installing packages using command: {' '.join(e.cmd)}", file=sys.stderr)
        print(f"Pip stdout:\n{e.stdout}", file=sys.stderr)
        print(f"Pip stderr:\n{e.stderr}", file=sys.stderr)
        sys.exit(1)  # Exit if installation fails
    except Exception as e:
        print(f"An unexpected error occurred during package installation: {e}", file=sys.stderr)
        sys.exit(1)

    # Re-execute the script using the venv's Python interpreter
    print(f"\nRestarting script using Python from '{venv_path}'...\n{'='*20}\n")
    script_path = os.path.abspath(sys.argv[0])
    # os.execv replaces the current process, inheriting stdio etc.
    # Arguments must include the executable name as argv[0] for the new process
    try:
        os.execv(python_executable, [python_executable, script_path] + sys.argv[1:])
        # If execv is successful, this line is never reached
    except OSError as e:
        print(f"Error restarting script with '{python_executable}': {e}", file=sys.stderr)
        # Fallback attempt with subprocess if execv fails (less ideal)
        print("Attempting restart with subprocess as fallback...")
        try:
            subprocess.check_call([python_executable, script_path] + sys.argv[1:])
            sys.exit(0)  # Exit cleanly if subprocess worked
        except Exception as sub_e:
            print(f"Subprocess restart also failed: {sub_e}", file=sys.stderr)
            sys.exit(1)
    except Exception as e:
        print(f"Unexpected error during script restart: {e}", file=sys.stderr)
        sys.exit(1)

    # This should not be reached if re-execution happens
    return False  # Indicate re-execution was attempted
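
# Illustrative only: ensure_venv() is presumably called before the third-party
# imports below are actually exercised. The call site is not shown here, but a
# typical (hypothetical) pattern would be:
#
#     if __name__ == "__main__":
#         if ensure_venv():
#             main()  # hypothetical entry point, only reached inside the venv
#
# If ensure_venv() re-executes the script via os.execv, main() is never reached
# in the original process.
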
# --- Original Script Imports (ensure they are accessible after venv check) ---
# It's generally okay to keep imports here, as the script restarts if not in venv
import gradio as gr
import json
import logging
import time
import pandas as pd
# import os  # Already imported above
import re
import requests
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from tenacity import retry, stop_after_attempt, wait_exponential
import csv
import io
import tempfile  # Added for JSONL download
from urllib.parse import urlparse
import base64
import mimetypes
# import signal  # Already imported above

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("model_tester")

@dataclass
class ModelEndpoint:
    """Simple model endpoint configuration."""
    name: str
    api_url: str
    api_key: Optional[str]  # API key can be optional (e.g., for local Ollama)
    model_id: str
    max_tokens: int = 1024
    temperature: float = 0.0
    file_upload_method: str = "JSON (Embedded Data)"  # Options: "JSON (Embedded Data)", "Multipart Form Data"

    def to_dict(self):
        """Convert to dictionary."""
        return {
            "name": self.name,
            "api_url": self.api_url,
            "model_id": self.model_id,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }
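
# Illustrative example (all values are placeholders, not part of the original script):
# an endpoint pointed at a local Ollama server, using the default JSON upload method.
#
#     champion = ModelEndpoint(
#         name="local-llama",
#         api_url="http://localhost:11434/api/generate",
#         api_key=None,
#         model_id="llama3",
#         max_tokens=512,
#         temperature=0.2,
#     )
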
@dataclass
class TestCase:
    """Test case containing a key to query and actual value for evaluation."""
    key: str  # The input/prompt for the model
    value: str  # The reference value/ground truth
    image_path_or_url: Optional[str] = None  # Path or URL to an image for multimodal input
    id: Optional[str] = None  # Unique ID for the test case


@dataclass
class ModelResponse:
    """Model response for a test case."""
    test_id: str
    model_name: str
    output: str
    latency: float


@dataclass
class EvaluationResult:
    """Evaluation result from the LM judge."""
    test_id: str
    champion_output: str
    challenger_output: str
    winner: str  # "MODEL_A_WINS", "MODEL_B_WINS", or "TIE" (extracted from reasoning)
    confidence: float  # Extracted confidence score (e.g., 4/5 -> 0.8)
    reasoning: str  # Full response from the judge model

# Global preprocessing settings (can be updated through UI)
PREPROCESS_ENABLED = True
MAX_LENGTH = 8000
REMOVE_SPECIAL_CHARS = True
NORMALIZE_WHITESPACE = True

# CSV preprocessing settings (specific to CSV format)
DETECT_DELIMITER = True
FIX_QUOTES = True
REMOVE_CONTROL_CHARS = True
NORMALIZE_NEWLINES = True
SKIP_BAD_LINES = True
SHOW_SAMPLE = True  # Show sample data after loading & preprocessing

# Global flag to signal stopping the test run
STOP_REQUESTED = False


# --- Text Preprocessing Function ---
def preprocess_text(text, max_length=None, remove_special_chars=None, normalize_whitespace=None):
    """
    Preprocess text (key or value) before using in prompts or comparisons.
    - Truncate to prevent token limits
    - Remove problematic characters
    - Normalize whitespace
    """
    # Use global settings if not specified
    if max_length is None: max_length = MAX_LENGTH
    if remove_special_chars is None: remove_special_chars = REMOVE_SPECIAL_CHARS
    if normalize_whitespace is None: normalize_whitespace = NORMALIZE_WHITESPACE

    # Skip preprocessing if disabled globally
    if not PREPROCESS_ENABLED:
        return str(text) if text is not None else ""

    if text is None: return ""
    text = str(text)  # Ensure it's a string

    # Truncate
    if len(text) > max_length:
        text = text[:max_length] + "... [truncated]"

    if remove_special_chars:
        # Remove control characters and other potentially problematic characters
        text = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', text)
        # Remove any XML/HTML-like tags that might interfere
        text = re.sub(r'<[^>]+>', '', text)

    if normalize_whitespace:
        # Normalize whitespace (multiple spaces, tabs, newlines to single space)
        text = re.sub(r'\s+', ' ', text)
        # But preserve paragraph breaks for readability (optional, maybe confusing)
        # text = re.sub(r'\n\s*\n', '\n\n', text)

    text = text.strip()
    return text
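
# Illustrative behaviour with the default settings (example input is made up):
#
#     preprocess_text("  Hello   <b>world</b>,  this is a   test  ")
#     # -> "Hello world, this is a test"
#
# Note that control characters (including tabs and newlines) are removed outright
# by the special-character pass before whitespace is collapsed.
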
# --- Model Runner Class ---
class ModelRunner:
    """Handles model API calls."""

    def __init__(self, endpoint: ModelEndpoint, prompt_template: str):
        self.endpoint = endpoint
        self.prompt_template = prompt_template

    def _load_and_encode_file(self, file_path_or_url: str) -> Tuple[Optional[str], Optional[str]]:
        """Loads file from path/URL, base64 encodes it, and returns (base64_string, mime_type) or (None, None)."""
        try:
            file_bytes = None
            if urlparse(file_path_or_url).scheme in ['http', 'https']:
                logger.info(f"Downloading file from URL: {file_path_or_url}")
                response = requests.get(file_path_or_url, stream=True, timeout=20)  # Increased timeout
                response.raise_for_status()
                file_bytes = response.content
                logger.info(f"Successfully downloaded {len(file_bytes)} bytes from URL.")
                # Try to get mime type from headers first
                mime_type = response.headers.get('content-type')
            else:
                logger.info(f"Reading file from local path: {file_path_or_url}")
                if not os.path.exists(file_path_or_url):
                    raise FileNotFoundError(f"File not found at path: {file_path_or_url}")
                with open(file_path_or_url, "rb") as f:
                    file_bytes = f.read()
                mime_type, _ = mimetypes.guess_type(file_path_or_url)

            if not file_bytes:
                raise ValueError("Failed to load file bytes.")

            # Use application/octet-stream as a generic default if type cannot be guessed
            mime_type = mime_type or 'application/octet-stream'
            base64_data = base64.b64encode(file_bytes).decode('utf-8')
            logger.info(f"Successfully encoded file to base64. Mime type: {mime_type}")
            logger.info(f"Successfully loaded and encoded file from {file_path_or_url[:50]}... Type: {mime_type}, Size: {len(base64_data)} chars base64")
            return base64_data, mime_type
        except FileNotFoundError:
            logger.error(f"File not found: {file_path_or_url}")
            return None, None
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to download file from URL {file_path_or_url}: {e}")
            return None, None
        except Exception as e:
            logger.error(f"Failed to load or encode file {file_path_or_url}: {e}")
            return None, None

    def _prepare_base64_data(self, file_path_or_url: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Loads file from path/URL, determines mime type, base64 encodes it.
        Returns (base64_string, mime_type) or (None, None) on error.
        (This is essentially the same logic as the original _load_and_encode_file)
        """
        # Re-using the logic from _load_and_encode_file for now.
        # Consider consolidating later if _load_and_encode_file is removed.
        return self._load_and_encode_file(file_path_or_url)

    def _prepare_local_file_path(self, file_path_or_url: str) -> Optional[str]:
        """
        Ensures a local file path exists for the given input.
        If input is a URL, downloads it to a temporary file.
        Returns the local path or None on error. Tracks temp files for cleanup.
        """
        try:
            parsed_url = urlparse(file_path_or_url)
            if parsed_url.scheme in ['http', 'https']:
                logger.info(f"Downloading URL for multipart upload: {file_path_or_url}")
                response = requests.get(file_path_or_url, stream=True, timeout=30)  # Increased timeout for downloads
                response.raise_for_status()
                # Create a temporary file
                # Suffix might help identify the file type, but not strictly necessary
                suffix = os.path.splitext(parsed_url.path)[1]
                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)  # Keep file after close
                with temp_file:
                    for chunk in response.iter_content(chunk_size=8192):
                        temp_file.write(chunk)
                local_path = temp_file.name
                logger.info(f"Downloaded URL to temporary file: {local_path}")
                # Track temporary files for cleanup (needs instance variable)
                if not hasattr(self, '_temp_files'):
                    self._temp_files = []
                self._temp_files.append(local_path)
                return local_path
            else:
                # It's already a local path, verify it exists
                if os.path.exists(file_path_or_url):
                    logger.info(f"Using existing local file path for multipart: {file_path_or_url}")
                    return file_path_or_url
                else:
                    logger.error(f"Local file path not found: {file_path_or_url}")
                    return None
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to download URL {file_path_or_url} for multipart: {e}")
            return None
        except Exception as e:
            logger.error(f"Error preparing local file path {file_path_or_url}: {e}")
            return None

    def _cleanup_temp_files(self):
        """Removes any temporary files created during URL downloads."""
        if hasattr(self, '_temp_files'):
            for temp_path in self._temp_files:
                try:
                    os.remove(temp_path)
                    logger.info(f"Cleaned up temporary file: {temp_path}")
                except OSError as e:
                    logger.warning(f"Failed to clean up temporary file {temp_path}: {e}")
            self._temp_files = []  # Reset the list

    def generate(
        self,
        test_case: TestCase,
        # Allow passing pre-loaded data, e.g., from the judge
        base64_data: Optional[str] = None,
        mime_type_override: Optional[str] = None
    ) -> ModelResponse:
        """Call the model API with the test case, potentially including file data based on endpoint configuration."""
        start_time = time.time()
        # Variables to hold prepared file data based on method
        base64_data_loaded = None
        mime_type_loaded = None
        local_file_path_for_multipart = None  # Path to file (potentially temporary) for multipart upload
        logger.info(f"Inside generate for model '{self.endpoint.name}', test_id: {test_case.id}. File path/URL from test case: {test_case.image_path_or_url}")
        try:
            # Determine the required upload method early
            upload_method = self.endpoint.file_upload_method
            logger.info(f"Using file upload method: '{upload_method}' for endpoint '{self.endpoint.name}'")

            # Preprocess the input key using the global settings
            preprocessed_key = preprocess_text(test_case.key)

            # Format prompt using the preprocessed key
            prompt = ""
            try:
                # For judge prompts, the "key" is already the full prompt
                if test_case.id and test_case.id.startswith("judge"):
                    prompt = preprocessed_key  # Use directly, assume already preprocessed if needed
                else:
                    # Use simple replacement first, escaping existing braces in the key
                    safe_key = preprocessed_key.replace("{", "{{").replace("}", "}}")
                    prompt = self.prompt_template.replace("{key}", safe_key)
            except Exception as e:
                logger.warning(f"Error formatting prompt template with replace: {str(e)}. Falling back.")
                try:
                    prompt = self.prompt_template.format(key=preprocessed_key)
                except Exception as e2:
                    logger.error(f"Error formatting prompt template: {str(e2)}. Using concatenation.")
                    prompt = f"{self.prompt_template}\n\nINPUT: {preprocessed_key}"

            # --- File Handling Logic (Prepare based on selected method) ---
            if test_case.image_path_or_url:
                logger.info(f"Test case {test_case.id} includes file reference: {test_case.image_path_or_url}. Preparing based on method '{upload_method}'.")
                if upload_method == "JSON (Embedded Data)":
                    # Use existing logic for now, will move to _prepare_base64_data helper later
                    # Prioritize pre-loaded data if provided (e.g., by judge)
                    if base64_data:
                        logger.info(f"Using pre-loaded base64 data for JSON method (test case {test_case.id}). Mime type override: {mime_type_override}")
                        base64_data_loaded = base64_data
                        mime_type_loaded = mime_type_override or 'application/octet-stream'  # Use override or default
                    else:
                        logger.info(f"Loading file for JSON method: {test_case.image_path_or_url}")
                        base64_data_loaded, mime_type_loaded = self._load_and_encode_file(test_case.image_path_or_url)
                        logger.info(f"Result from _load_and_encode_file - Has data: {bool(base64_data_loaded)}, Mime type: {mime_type_loaded}")
                    if base64_data_loaded is None:
                        # Handle file loading failure
                        logger.error(f"Failed to load file for JSON method (test case {test_case.id}). Returning error.")
                        return ModelResponse(test_id=test_case.id or "unknown", model_name=self.endpoint.name, output=f"Error: Failed to load file {test_case.image_path_or_url} for JSON method", latency=time.time() - start_time)
                elif upload_method == "Multipart Form Data":
                    # Call the actual helper function
                    local_file_path_for_multipart = self._prepare_local_file_path(test_case.image_path_or_url)
                    if local_file_path_for_multipart is None:
                        logger.error(f"Failed to prepare local file for Multipart method (test case {test_case.id}). Returning error.")
                        return ModelResponse(test_id=test_case.id or "unknown", model_name=self.endpoint.name, output=f"Error: Failed to prepare file {test_case.image_path_or_url} for Multipart method", latency=time.time() - start_time)
                else:
                    logger.error(f"Unknown file_upload_method configured: {upload_method}")
                    return ModelResponse(test_id=test_case.id or "unknown", model_name=self.endpoint.name, output=f"Error: Invalid file_upload_method '{upload_method}'", latency=time.time() - start_time)
            # else: No file involved in this test case, proceed with text-only call

            # --- API Call Routing ---
            response_text = ""
            try:
                if upload_method == "JSON (Embedded Data)":
                    logger.info("Routing to _call_json_api.")
                    # Call the new JSON API wrapper function
                    response_text = self._call_json_api(prompt, base64_data_loaded, mime_type_loaded)
                elif upload_method == "Multipart Form Data":
                    logger.info("Routing to _call_multipart_api.")
                    # Call the new Multipart API function
                    response_text = self._call_multipart_api(prompt, local_file_path_for_multipart)
                else:
                    # Should have been caught during file prep, but defensive check
                    logger.error(f"Invalid file_upload_method '{upload_method}' reached API call routing.")
                    response_text = f"Error: Invalid file upload method configuration '{upload_method}'."
            # Exception handling remains the same, but applies to the new call structure
            except requests.exceptions.RequestException as req_err:
                logger.error(f"API request failed for {self.endpoint.name}: {req_err}")
                if hasattr(req_err, 'response') and req_err.response is not None:
                    logger.error(f"Response status: {req_err.response.status_code}, Response text: {req_err.response.text[:500]}")
                response_text = f"Error: API request failed. Details: {str(req_err)}"
            except (KeyError, IndexError, TypeError, json.JSONDecodeError, ValueError) as parse_err:
                logger.error(f"Failed to parse response or invalid response structure from {self.endpoint.name}: {parse_err}")
                response_text = f"Error: Failed to parse API response. Details: {str(parse_err)}"
            except Exception as e:
                logger.error(f"Unexpected error calling API for {self.endpoint.name}: {str(e)}", exc_info=True)
                response_text = f"Error: An unexpected error occurred. Details: {str(e)}"

            end_time = time.time()
            # Clean up temporary files regardless of success or failure
            self._cleanup_temp_files()
            return ModelResponse(
                test_id=test_case.id or "unknown",  # Ensure test_id is never None
                model_name=self.endpoint.name,
                output=str(response_text),  # Ensure output is always string
                latency=end_time - start_time,
            )
        except Exception as e:
            logger.error(f"Unexpected error in generate method for {self.endpoint.name}: {str(e)}", exc_info=True)
            # Re-raise to trigger tenacity retry
            raise
        finally:
            # Ensure cleanup happens even if retry fails or other unexpected errors occur
            self._cleanup_temp_files()
    def _prepare_headers(self, is_json_request=True):
        """Prepares common headers. Adjusts Content-Type based on request type."""
        # Start with common headers
        headers = {}
        # Only add an auth header if api_key is present and not empty
        if self.endpoint.api_key and self.endpoint.api_key.strip():
            # Check for specific API key types if needed (e.g., Anthropic uses x-api-key)
            if "anthropic" in self.endpoint.api_url.lower():
                headers["x-api-key"] = self.endpoint.api_key
                headers["anthropic-version"] = "2023-06-01"  # Required header
            elif "generativelanguage.googleapis.com" in self.endpoint.api_url.lower():
                # Gemini API key is usually in the URL, not header
                pass
            else:
                # Default to Bearer token for OpenAI compatible and others
                headers["Authorization"] = f"Bearer {self.endpoint.api_key}"
        # Set Content-Type based on whether it's a JSON request or not (e.g., multipart)
        if is_json_request:
            headers["Content-Type"] = "application/json"
        # For multipart/form-data, requests library handles Content-Type automatically when 'files' param is used.
        # Add OpenRouter specific headers if applicable
        if self.endpoint.api_url and "openrouter.ai" in self.endpoint.api_url.lower():
            # These might be optional now, but good practice
            headers["HTTP-Referer"] = "http://localhost"  # Can be anything, localhost is common
            headers["X-Title"] = "Model A/B Testing Tool"
        return headers
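
    # Illustrative result (endpoint values are made up): for an OpenAI-compatible
    # endpoint with an API key set, _prepare_headers(is_json_request=True) returns
    # roughly:
    #
    #     {"Authorization": "Bearer sk-...", "Content-Type": "application/json"}
    #
    # whereas an Anthropic endpoint gets "x-api-key" and "anthropic-version" instead.
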
    # --- New API Call Functions ---
    def _call_json_api(self, prompt: str, base64_data: Optional[str], mime_type: Optional[str]) -> str:
        """
        Wrapper for JSON-based API calls. Detects API type and calls the appropriate formatter.
        """
        logger.info(f"Executing JSON API call for endpoint: {self.endpoint.name}")
        headers = self._prepare_headers(is_json_request=True)  # Ensure JSON content type

        # Determine API type based on URL (similar to previous logic)
        api_url_lower = self.endpoint.api_url.lower() if self.endpoint.api_url else ""
        is_openai_compatible = "/v1/chat/completions" in api_url_lower or \
                               "openai" in api_url_lower or \
                               "openrouter.ai" in api_url_lower or \
                               "mistral" in api_url_lower or \
                               "together.ai" in api_url_lower or \
                               "groq.com" in api_url_lower or \
                               "fireworks.ai" in api_url_lower or \
                               "deepinfra.com" in api_url_lower or \
                               "lmstudio.ai" in api_url_lower or \
                               ":1234/v1" in api_url_lower
        is_anthropic_compatible = "/v1/messages" in api_url_lower or "anthropic" in api_url_lower
        is_gemini = "generativelanguage.googleapis.com" in api_url_lower
        is_ollama = ("/api/generate" in api_url_lower and
                     ("localhost:11434" in api_url_lower or "127.0.0.1:11434" in api_url_lower)) or \
                    ("ollama" in api_url_lower and "/api/generate" in api_url_lower)

        payload = {}
        api_url_to_call = self.endpoint.api_url  # Default URL

        # Call appropriate formatting function
        if is_openai_compatible:
            payload = self._format_openai_json(prompt, base64_data, mime_type)
        elif is_anthropic_compatible:
            payload = self._format_anthropic_json(prompt, base64_data, mime_type)
        elif is_gemini:
            payload = self._format_gemini_json(prompt, base64_data, mime_type)
            # Gemini API key goes in URL parameter
            if self.endpoint.api_key:
                api_url_to_call = f"{self.endpoint.api_url}?key={self.endpoint.api_key}"
            else:
                raise ValueError("Gemini API key is required but not provided.")
        elif is_ollama:
            payload = self._format_ollama_json(prompt, base64_data)
        else:
            # Fallback: Use generic API call logic (which currently assumes OpenAI text-only)
            logger.warning(f"Could not determine specific JSON API type for {self.endpoint.api_url}. Using generic fallback.")
            # The generic call handles its own request formatting and execution
            return self._call_generic_api(prompt)  # Return directly

        # Make the actual request
        try:
            payload_size_kb = len(json.dumps(payload)) / 1024
            logger.info(f"Sending JSON request to {api_url_to_call}. Payload size: {payload_size_kb:.2f} KB")
            if payload_size_kb > 4000: logger.warning(f"Payload size ({payload_size_kb:.2f} KB) is large.")
            response = requests.post(api_url_to_call, headers=headers, json=payload, timeout=180)
            response.raise_for_status()
            result = response.json()

            # Parse response based on API type (This parsing logic should ideally move too)
            # TODO: Move response parsing into dedicated functions or handle within formatters?
            if is_openai_compatible:
                if not result.get("choices") or not result["choices"][0].get("message"): raise ValueError("Invalid OpenAI response format")
                content = result["choices"][0]["message"].get("content")
                return content if content is not None else f"Error: Response content was null (Finish Reason: {result['choices'][0].get('finish_reason')})"
            elif is_anthropic_compatible:
                if not result.get("content"): raise ValueError("Invalid Anthropic response format")
                text_content = next((block.get("text", "") for block in result.get("content", []) if block.get("type") == "text"), "")
                return text_content if text_content else "[No text content found in response]"
            elif is_gemini:
                if not result.get("candidates"): raise ValueError("Invalid Gemini response format")
                candidate = result["candidates"][0]
                if not candidate.get("content") or not candidate["content"].get("parts"): raise ValueError("Invalid Gemini response format")
                text_response = "".join(part["text"] for part in candidate["content"]["parts"] if "text" in part)
                return text_response if text_response else "[No text content found in response]"
            elif is_ollama:
                if "response" in result: return result["response"]
                elif "error" in result: raise ValueError(f"Ollama API Error: {result['error']}")
                else: raise ValueError("Invalid Ollama response format")
            else:
                # Should not be reached if generic fallback worked
                raise ValueError("Unhandled API type in JSON response parsing.")
        except requests.exceptions.RequestException as e:
            logger.error(f"JSON API request failed: {str(e)}")
            if hasattr(e, 'response') and e.response is not None: logger.error(f"Response content: {e.response.text[:500]}")
            raise  # Re-raise to be caught by the main generate method's handler
        except (KeyError, IndexError, ValueError, json.JSONDecodeError) as e:
            logger.error(f"Failed to parse JSON API response: {str(e)}")
            logger.error(f"Full response: {result if 'result' in locals() else 'Response not available'}")
            raise  # Re-raise
    def _call_multipart_api(self, prompt: str, local_file_path: Optional[str]) -> str:
        """
        Handles API calls using multipart/form-data.
        Currently shaped around OpenAI Whisper-style transcription endpoints; other
        multipart APIs may need different form fields and response parsing.
        """
        logger.info(f"Executing Multipart API call for endpoint: {self.endpoint.name}")
        if not local_file_path:
            return "Error: No local file path provided for multipart upload."

        # Prepare headers (Requests handles Content-Type for multipart)
        headers = self._prepare_headers(is_json_request=False)

        # Prepare data and files dictionary - THIS IS HIGHLY API-SPECIFIC
        # Example for OpenAI Whisper:
        data = {'model': self.endpoint.model_id}
        if prompt:  # Whisper uses 'prompt' for context/hints
            data['prompt'] = prompt
        # Add other potential fields like 'language', 'response_format' based on API

        files = {}
        try:
            # Use a context manager to ensure the file is closed; the request is made
            # while the file is still open so requests can read its contents.
            with open(local_file_path, 'rb') as f:
                files['file'] = (os.path.basename(local_file_path), f)
                logger.info(f"Preparing multipart request with file: {os.path.basename(local_file_path)}")
                # Make the request
                response = requests.post(
                    self.endpoint.api_url,
                    headers=headers,
                    data=data,
                    files=files,
                    timeout=180  # Timeout for upload + processing
                )
            response.raise_for_status()
            result = response.json()
            # Parse the response - AGAIN, API-SPECIFIC
            # Example for Whisper:
            if 'text' in result:
                return result['text']
            else:
                raise ValueError(f"Unexpected response format from multipart API: {result}")
        except requests.exceptions.RequestException as e:
            logger.error(f"Multipart API request failed: {str(e)}")
            if hasattr(e, 'response') and e.response is not None: logger.error(f"Response content: {e.response.text[:500]}")
            raise
        except (KeyError, ValueError, json.JSONDecodeError) as e:
            logger.error(f"Failed to parse Multipart API response: {str(e)}")
            logger.error(f"Full response: {result if 'result' in locals() else 'Response not available'}")
            raise
        except FileNotFoundError:
            logger.error(f"File not found for multipart upload: {local_file_path}")
            return f"Error: File not found at {local_file_path}"
        except Exception as e:
            logger.error(f"Unexpected error during multipart call: {e}", exc_info=True)
            raise
    # --- JSON Payload Formatting Functions ---
    def _format_openai_json(self, prompt: str, base64_data: Optional[str], mime_type: Optional[str]) -> Dict[str, Any]:
        """Formats the payload for OpenAI-compatible chat completion APIs."""
        logger.debug("Formatting payload for OpenAI JSON")
        messages = []
        if base64_data:
            mime_type = mime_type or 'image/jpeg'  # Default mime type
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_data}"}}
                ]
            })
        else:
            messages.append({"role": "user", "content": prompt})
        return {
            "model": self.endpoint.model_id,
            "messages": messages,
            "max_tokens": self.endpoint.max_tokens,
            "temperature": self.endpoint.temperature,
        }
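
    # Illustrative text-only payload produced by _format_openai_json (the model id
    # comes from the endpoint configuration; the value shown here is made up):
    #
    #     {
    #         "model": "gpt-4o-mini",
    #         "messages": [{"role": "user", "content": "Summarize: ..."}],
    #         "max_tokens": 1024,
    #         "temperature": 0.0,
    #     }
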
    def _format_anthropic_json(self, prompt: str, base64_data: Optional[str], mime_type: Optional[str]) -> Dict[str, Any]:
        """Formats the payload for Anthropic messages API."""
        logger.debug("Formatting payload for Anthropic JSON")
        content = [{"type": "text", "text": prompt}]
        if base64_data:
            mime_type = mime_type or 'image/jpeg'
            supported_mime_types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']
            if mime_type not in supported_mime_types:
                logger.warning(f"MIME type '{mime_type}' may not be directly supported by Claude.")
            content.append({
                "type": "image",
                "source": {"type": "base64", "media_type": mime_type, "data": base64_data}
            })
        return {
            "model": self.endpoint.model_id,
            "messages": [{"role": "user", "content": content}],
            "max_tokens": self.endpoint.max_tokens,
            "temperature": self.endpoint.temperature,
        }

    def _format_gemini_json(self, prompt: str, base64_data: Optional[str], mime_type: Optional[str]) -> Dict[str, Any]:
        """Formats the payload for Google Gemini API."""
        logger.debug("Formatting payload for Gemini JSON")
        parts = [{"text": prompt}]
        if base64_data:
            mime_type = mime_type or 'application/octet-stream'  # Gemini supports various types
            parts.append({"inline_data": {"mime_type": mime_type, "data": base64_data}})
        return {
            "contents": [{"parts": parts}],
            "generationConfig": {
                "temperature": self.endpoint.temperature,
                "maxOutputTokens": self.endpoint.max_tokens,
            }
        }

    def _format_ollama_json(self, prompt: str, base64_data: Optional[str]) -> Dict[str, Any]:
        """Formats the payload for Ollama generate API."""
        logger.debug("Formatting payload for Ollama JSON")
        data = {
            "model": self.endpoint.model_id,
            "prompt": prompt,
            "stream": False,
        }
        if base64_data:
            data["images"] = [base64_data]  # Ollama expects a list
        return data
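
    # Illustrative Ollama payload for the /api/generate endpoint (model name made up):
    #
    #     {"model": "llava", "prompt": "Describe this image.", "stream": False,
    #      "images": ["<base64-encoded file bytes>"]}
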

class LMJudge:
    """Uses a language model to judge between champion and challenger outputs."""

    DEFAULT_EVALUATION_PROMPT = """
# Model Response Evaluation
You are evaluating two AI model responses based on the input query, potentially an accompanying image, and potentially a reference value.
## Input Query
{key}
{image_context_section}
{reference_section}
## Model A (Champion: {champion_name}) Response
{champion_output}
## Model B (Challenger: {challenger_name}) Response
{challenger_output}
## Evaluation Instructions
Compare Model A and Model B based on the Input Query{reference_value_instruction}. Consider:
1. Relevance and accuracy in addressing the Input Query.
{reference_value_criteria}
{clarity_criteria_number}. Clarity, conciseness, and quality of the response.
{overall_criteria_number}. Overall usefulness.
## Required Response Format
You MUST start your response with a clear verdict and confidence rating:
VERDICT: [Choose ONE: MODEL_A_WINS, MODEL_B_WINS, or TIE]
CONFIDENCE: [Number]/5 (where 1=low confidence, 5=high confidence)
Then provide a detailed explanation of your reasoning. Be explicit about which model performed better and why, or why they were tied. Include specific examples from each response that influenced your decision.
Example format:
VERDICT: MODEL_A_WINS
CONFIDENCE: 4/5
[Your detailed reasoning here...]
"""

    def __init__(
        self,
        endpoint: ModelEndpoint,
        evaluation_prompt_template: str = DEFAULT_EVALUATION_PROMPT,
    ):
        self.endpoint = endpoint
        self.evaluation_prompt_template = evaluation_prompt_template
        # The judge runner uses a simple placeholder template, as the full prompt
        # is formatted within the evaluate method before being passed as the 'key'.
        # Judge model runner needs access to the file loading method.
        self.model_runner = ModelRunner(endpoint, "{key}")  # Pass-through template for prompt
    def evaluate(
        self,
        test_case: TestCase,
        champion_response: ModelResponse,
        challenger_response: ModelResponse
    ) -> EvaluationResult:
        """Evaluate champion vs challenger outputs using a dynamically built prompt."""
        # Preprocess all inputs to ensure they're clean strings
        # Use the same preprocess_text function for consistency
        # Note: the image itself is not embedded in this text prompt; it is passed to
        # the judge separately via the judge test case constructed below.
        preprocessed_key = preprocess_text(test_case.key)
        preprocessed_value = preprocess_text(test_case.value)  # Preprocess reference value too
        preprocessed_champion = preprocess_text(champion_response.output)
        preprocessed_challenger = preprocess_text(challenger_response.output)

        # Prepare context for the evaluation prompt template
        has_reference = bool(preprocessed_value)
        reference_section_text = f"\n## Reference Value\n\n{preprocessed_value}\n" if has_reference else "\n## Reference Value\nN/A"
        reference_value_instruction_text = ' and Reference Value' if has_reference else ''
        reference_value_criteria_text = '2. Factual correctness compared to the Reference Value (if provided).' if has_reference else ''
        clarity_criteria_number_text = '3' if has_reference else '2'
        overall_criteria_number_text = '4' if has_reference else '3'

        # Add image context section if an image was provided in the original test case
        has_image = bool(test_case.image_path_or_url)
        image_context_section_text = "\n## Input Image\nAn image was provided with the input query. Consider it as context when evaluating the responses.\n" if has_image else ""

        # Format the evaluation prompt using the template and context
        try:
            evaluation_prompt = self.evaluation_prompt_template.format(
                key=preprocessed_key,
                image_context_section=image_context_section_text,  # Added image context
                reference_section=reference_section_text,
                champion_name=champion_response.model_name,
                champion_output=preprocessed_champion,
                challenger_name=challenger_response.model_name,
                challenger_output=preprocessed_challenger,
                reference_value_instruction=reference_value_instruction_text,
                reference_value_criteria=reference_value_criteria_text,
                clarity_criteria_number=clarity_criteria_number_text,
                overall_criteria_number=overall_criteria_number_text
            )
        except KeyError as e:
            logger.error(f"Missing key in judge prompt template: {e}. Using default prompt structure.")
            # Fallback to a basic structure if formatting fails
            evaluation_prompt = f"Evaluate Model A vs Model B.\nInput: {preprocessed_key}\nRef: {preprocessed_value}\nA: {preprocessed_champion}\nB: {preprocessed_challenger}\nFormat: VERDICT: [MODEL_A_WINS/MODEL_B_WINS/TIE]\nCONFIDENCE: [1-5]/5\nReasoning: ..."
        except Exception as e:
            logger.error(f"Error formatting judge prompt template: {e}. Using basic prompt.")
            evaluation_prompt = f"Evaluate Model A vs Model B.\nInput: {preprocessed_key}\nRef: {preprocessed_value}\nA: {preprocessed_champion}\nB: {preprocessed_challenger}\nFormat: VERDICT: [MODEL_A_WINS/MODEL_B_WINS/TIE]\nCONFIDENCE: [1-5]/5\nReasoning: ..."

        # Log the prompt for debugging (truncated)
        logger.info(f"Using Judge evaluation prompt (truncated): {evaluation_prompt[:500]}...")

        # Create a TestCase specifically for the judge call.
        # Crucially, pass the original image_path_or_url from the test_case.
        # The judge's model_runner.generate method will handle loading the file
        # based on the judge's endpoint.file_upload_method configuration.
        judge_test_case = TestCase(
            key=evaluation_prompt,
            value="",  # No value needed for judge call itself
            image_path_or_url=test_case.image_path_or_url,  # Pass original file reference
            id=f"judge-{test_case.id or 'unknown'}"
        )

        # Call the judge's generate method. It will handle file loading internally.
        # We no longer need to pass base64_data or mime_type_override here.
        judge_response_obj = self.model_runner.generate(
            test_case=judge_test_case
        )

        # Log the response for debugging (truncated)
        logger.info(f"Judge raw response (truncated): {judge_response_obj.output[:500]}...")

        # Parse the judge's decision from the raw output string
        parsed_result = self.parse_judge_response(judge_response_obj.output)

        return EvaluationResult(
            test_id=test_case.id or "unknown",
            champion_output=champion_response.output,  # Store original, not preprocessed
            challenger_output=challenger_response.output,  # Store original, not preprocessed
            winner=parsed_result["winner"],
            confidence=parsed_result["confidence"],
            reasoning=judge_response_obj.output,  # Store the full raw response as reasoning
        )

    def parse_judge_response(self, response_text: str) -> Dict[str, Any]:
        """
        Parse the judge's raw response string to extract verdict and confidence.
        Uses more flexible regex patterns to handle various response formats.
        """
        verdict = "UNDETERMINED"
        confidence = 0.0

        # Log the first part of the response for debugging
        logger.debug(f"Parsing judge response (first 100 chars): {response_text[:100]}")

        # 1. Extract VERDICT (Case-insensitive search for the explicit line)
        verdict_match = re.search(r"^\s*VERDICT:\s*(MODEL_A_WINS|MODEL_B_WINS|TIE)\s*$", response_text, re.IGNORECASE | re.MULTILINE)
        if verdict_match:
            verdict = verdict_match.group(1).upper()
            logger.info(f"Parsed VERDICT line: {verdict}")
        else:
            # Fallback: Look for bracketed verdicts (common LLM-as-judge pattern)
            bracket_match = re.search(r"\[\[\s*(MODEL_A_WINS|MODEL_B_WINS|TIE)\s*\]\]", response_text, re.IGNORECASE)
            if bracket_match:
                verdict = bracket_match.group(1).upper()
                logger.info(f"Parsed bracketed verdict: {verdict}")
            else:
                # Fallback: Look for simpler A/B/TIE in brackets
                simple_bracket_match = re.search(r"\[\[\s*([AB]|TIE)\s*\]\]", response_text, re.IGNORECASE)
                if simple_bracket_match:
                    verdict_text = simple_bracket_match.group(1).upper()
                    if verdict_text == "A": verdict = "MODEL_A_WINS"
                    elif verdict_text == "B": verdict = "MODEL_B_WINS"
                    else: verdict = "TIE"
                    logger.info(f"Parsed simple bracketed verdict: {verdict}")

        # 2. Extract CONFIDENCE (Case-insensitive search for the explicit line)
        confidence_match = re.search(r"^\s*CONFIDENCE:\s*(\d(?:\.\d)?)\s*/\s*5\s*$", response_text, re.IGNORECASE | re.MULTILINE)
        if confidence_match:
            try:
                confidence_score = float(confidence_match.group(1))
                # Normalize the 1-5 score to a 0.2-1.0 range (clamped)
                confidence = max(0.2, min(1.0, confidence_score / 5.0))
                logger.info(f"Parsed CONFIDENCE line: {confidence_score}/5 -> {confidence}")
            except ValueError:
                logger.warning(f"Could not parse CONFIDENCE value: {confidence_match.group(1)}")
        else:
            # Fallback: Look for rating/score patterns if confidence line missing
            score_match = re.search(r"(?:rating|score)[:\s]*(\d(?:\.\d)?)\s*/\s*(\d+)", response_text, re.IGNORECASE)
            if score_match:
                try:
                    score = float(score_match.group(1))
                    scale = float(score_match.group(2))
                    if scale > 0:
                        # Normalize to 0-1 range, clamping between 0.2 and 1.0
                        confidence = max(0.2, min(1.0, score / scale))
                        logger.info(f"Parsed score/rating: {score}/{scale} -> {confidence}")
                except ValueError:
                    pass  # Ignore if parsing fails

        # 3. Final checks and fallbacks if parsing failed
        if verdict == "UNDETERMINED":
            logger.warning(f"Could not reliably parse VERDICT from judge response: {response_text[:200]}...")
            # Simple keyword check as a last resort (less reliable)
            if "model a wins" in response_text.lower() and "model b wins" not in response_text.lower():
                verdict = "MODEL_A_WINS"
            elif "model b wins" in response_text.lower() and "model a wins" not in response_text.lower():
                verdict = "MODEL_B_WINS"
            elif "tie" in response_text.lower() or "comparable" in response_text.lower():
                verdict = "TIE"

        # If we have a verdict but no confidence, assign a default moderate confidence
        if verdict != "UNDETERMINED" and confidence == 0.0:
            confidence = 0.6  # Default confidence when parsing fails but verdict is found
            logger.info(f"Could not parse CONFIDENCE, assigning default {confidence} for verdict {verdict}")

        # Log the final parsed values
        logger.info(f"Final parsed judge result - Winner: {verdict}, Confidence: {confidence:.2f}")
        return {
            "winner": verdict,
            "confidence": confidence,
            # Reasoning is the full response text, handled in evaluate method
        }
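
    # Illustrative parse (the judge text below is made up):
    #
    #     parse_judge_response("VERDICT: MODEL_B_WINS\nCONFIDENCE: 4/5\nModel B was ...")
    #     # -> {"winner": "MODEL_B_WINS", "confidence": 0.8}
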

class ResultAggregator:
    """Collects evaluation results and calculates summary statistics."""

    def aggregate(self, evaluation_results: List[EvaluationResult]) -> Dict[str, Any]:
        """Aggregates results, calculating counts and percentages."""
        total_evaluations = len(evaluation_results)
        verdict_counts = {"MODEL_A_WINS": 0, "MODEL_B_WINS": 0, "TIE": 0, "UNDETERMINED": 0, "JUDGE_ERROR": 0}
        confidence_sum = 0
        valid_verdicts = 0
        # Track which test cases had undetermined verdicts for logging
        undetermined_cases = []
        judge_error_cases = []

        for result in evaluation_results:
            verdict = result.winner  # Use the pre-parsed winner
            if verdict in verdict_counts:
                verdict_counts[verdict] += 1
                if verdict != "UNDETERMINED" and verdict != "JUDGE_ERROR":
                    confidence_sum += result.confidence
                    valid_verdicts += 1
                elif verdict == "UNDETERMINED":
                    undetermined_cases.append(result.test_id)
                elif verdict == "JUDGE_ERROR":
                    judge_error_cases.append(result.test_id)
            else:
                # Should not happen if parsing is robust, but handle defensively
                logger.warning(f"Unexpected verdict '{verdict}' encountered for test_id {result.test_id}. Counting as UNDETERMINED.")
                verdict_counts["UNDETERMINED"] += 1
                undetermined_cases.append(result.test_id)

        # Log summary of problematic cases
        if undetermined_cases:
            logger.warning(f"Found {len(undetermined_cases)} undetermined verdicts: {undetermined_cases[:5]}" +
                           (f"... and {len(undetermined_cases)-5} more" if len(undetermined_cases) > 5 else ""))
        if judge_error_cases:
            logger.warning(f"Found {len(judge_error_cases)} judge errors: {judge_error_cases[:5]}" +
                           (f"... and {len(judge_error_cases)-5} more" if len(judge_error_cases) > 5 else ""))

        # Calculate percentages based on determined verdicts only (excluding UNDETERMINED and JUDGE_ERROR)
        determined_verdicts = total_evaluations - verdict_counts["UNDETERMINED"] - verdict_counts["JUDGE_ERROR"]
        verdict_percentages = {}
        if determined_verdicts > 0:
            verdict_percentages["MODEL_A_WINS"] = round(
                (verdict_counts["MODEL_A_WINS"] / determined_verdicts) * 100, 2
            )
            verdict_percentages["MODEL_B_WINS"] = round(
                (verdict_counts["MODEL_B_WINS"] / determined_verdicts) * 100, 2
            )
            verdict_percentages["TIE"] = round(
                (verdict_counts["TIE"] / determined_verdicts) * 100, 2
            )
        else:
            verdict_percentages = {"MODEL_A_WINS": 0, "MODEL_B_WINS": 0, "TIE": 0}

        average_confidence = (confidence_sum / valid_verdicts) if valid_verdicts > 0 else 0

        # Convert EvaluationResult objects to dictionaries for JSON serialization
        raw_eval_dicts = []
        for res in evaluation_results:
            try:
                # Assuming EvaluationResult is a dataclass or has a simple structure
                raw_eval_dicts.append({
                    "test_id": res.test_id,
                    "winner": res.winner,
                    "confidence": res.confidence,
                    "champion_output": res.champion_output,
                    "challenger_output": res.challenger_output,
                    "reasoning": res.reasoning,
                })
            except AttributeError as e:
                logger.error(f"Error converting EvaluationResult to dict for test_id {res.test_id}: {e}")
                # Add a placeholder or skip
                raw_eval_dicts.append({"test_id": getattr(res, 'test_id', 'unknown'), "error": "Failed to serialize result"})

        return {
            "total_evaluations": total_evaluations,
            "verdict_counts": verdict_counts,
            "verdict_percentages": verdict_percentages,  # Based on determined verdicts
            "average_confidence": round(average_confidence, 3),  # Avg confidence for non-undetermined/error
            "raw_evaluations": raw_eval_dicts  # Keep raw for output (as dicts)
        }
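
    # Illustrative shape of the returned summary (numbers are made up; with one
    # UNDETERMINED case, percentages are computed over the 9 determined verdicts):
    #
    #     {
    #         "total_evaluations": 10,
    #         "verdict_counts": {"MODEL_A_WINS": 4, "MODEL_B_WINS": 3, "TIE": 2,
    #                            "UNDETERMINED": 1, "JUDGE_ERROR": 0},
    #         "verdict_percentages": {"MODEL_A_WINS": 44.44, "MODEL_B_WINS": 33.33, "TIE": 22.22},
    #         "average_confidence": 0.73,
    #         "raw_evaluations": [...],
    #     }
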
| class ModelTester: | |
| """Main class that orchestrates the A/B testing process.""" | |
| def __init__( | |
| self, | |
| champion_endpoint: ModelEndpoint, | |
| challenger_endpoint: ModelEndpoint, | |
| judge_endpoint: ModelEndpoint, | |
| model_prompt_template: str, | |
| judge_prompt_template: str = LMJudge.DEFAULT_EVALUATION_PROMPT # Add judge template param | |
| ): | |
| self.champion_runner = ModelRunner(champion_endpoint, model_prompt_template) | |
| self.challenger_runner = ModelRunner(challenger_endpoint, model_prompt_template) | |
| # Pass the judge prompt template to the LMJudge constructor | |
| self.judge = LMJudge(judge_endpoint, evaluation_prompt_template=judge_prompt_template) | |
| self.aggregator = ResultAggregator() # Aggregator just collects/counts | |
| self.champion_endpoint = champion_endpoint | |
| self.challenger_endpoint = challenger_endpoint | |
| self.judge_endpoint = judge_endpoint | |
| def run_test( | |
| self, | |
| test_cases: List[TestCase], | |
| batch_size: int = 5, | |
| progress=None, | |
| batch_retry_attempts: int = 0, # Number of retry attempts for batches | |
| batch_backoff_factor: float = 2.0, # Exponential backoff factor | |
| batch_max_wait: int = 60, # Maximum wait time between retries in seconds | |
| batch_retry_trigger_strings: Optional[List[str]] = None # Strings that trigger a retry | |
| ) -> Dict[str, Any]: | |
| """ | |
| Run the complete test process: generate responses, evaluate, aggregate. | |
| Includes batch retry mechanism for transient errors or problematic responses. | |
| Args: | |
| test_cases: List of test cases (potentially including image paths/URLs) | |
| batch_size: Number of test cases per batch | |
| progress: Gradio progress callback | |
| batch_retry_attempts: Max retries per batch | |
| batch_backoff_factor: Exponential backoff factor | |
| batch_max_wait: Max wait time between retries | |
| batch_retry_trigger_strings: List of strings triggering retry if found in outputs/reasoning | |
| """ | |
| all_evaluation_results: List[EvaluationResult] = [] | |
| champion_metrics = {"total_latency": 0.0, "total_output_chars": 0, "success_count": 0, "error_count": 0, "image_load_errors": 0} | |
| challenger_metrics = {"total_latency": 0.0, "total_output_chars": 0, "success_count": 0, "error_count": 0, "image_load_errors": 0} | |
| judge_metrics = {"total_latency": 0.0, "total_output_chars": 0, "success_count": 0, "error_count": 0} | |
| num_cases = len(test_cases) | |
| if num_cases == 0: | |
| logger.warning("No test cases provided to run_test.") | |
| return {"evaluations": [], "summary": {"error": "No test cases loaded."}} | |
| total_batches = (num_cases + batch_size - 1) // batch_size | |
| processed_case_count = 0 # Track actual processed cases for progress | |
| global STOP_REQUESTED # Access the global flag | |
| previous_update_payload = None # Store the last yielded update | |
| # Process in batches | |
| for i in range(0, num_cases, batch_size): | |
| if STOP_REQUESTED: | |
| logger.warning(f"Stop requested. Finishing early after processing {processed_case_count} cases.") | |
| if progress: | |
| progress(processed_case_count / num_cases, f"Stopping early after {processed_case_count} cases...") | |
| break # Exit the batch loop | |
| current_batch = test_cases[i:min(i + batch_size, num_cases)] | |
| batch_num = i // batch_size + 1 | |
| logger.info(f"--- Processing Batch {batch_num}/{total_batches} (Cases {i+1}-{min(i+batch_size, num_cases)}) ---") | |
| # Initialize retry counter and success flag for this batch | |
| retry_count = 0 | |
| batch_success = False | |
| batch_eval_results: List[EvaluationResult] = [] # Store results for *this successful batch attempt* | |
| # Process this batch with retries if configured | |
| while not batch_success and retry_count <= batch_retry_attempts: | |
| if retry_count > 0: | |
| # Calculate backoff delay with exponential increase, capped at max_wait | |
| delay = min(batch_backoff_factor ** (retry_count - 1), batch_max_wait) | |
| logger.info(f"Retrying batch {batch_num} (attempt {retry_count}/{batch_retry_attempts}) after {delay:.2f}s delay") | |
| if progress is not None: | |
| # Update progress based on already processed cases, not 'i' | |
| progress(processed_case_count / num_cases, f"Retrying Batch {batch_num} ({retry_count}/{batch_retry_attempts})") | |
| time.sleep(delay) | |
| else: | |
| if progress is not None: | |
| progress(processed_case_count / num_cases, f"Running Batch {batch_num}/{total_batches}") | |
| # Reset batch-specific stores for this attempt | |
| current_attempt_champ_responses: Dict[str, ModelResponse] = {} | |
| current_attempt_chall_responses: Dict[str, ModelResponse] = {} | |
| current_attempt_eval_results: List[EvaluationResult] = [] | |
| has_trigger_string_in_attempt = False | |
| # 1. Get responses from Champion and Challenger models for the current batch attempt | |
| for batch_idx, test_case in enumerate(current_batch): | |
| # Check if stop requested before processing this case | |
| if STOP_REQUESTED: | |
| logger.warning(f"Stop requested. Skipping remaining cases in batch {batch_num}.") | |
| break # Exit the inner loop for this batch | |
| # Generate a consistent ID if not present, using overall index 'i' + batch_idx | |
| case_id = test_case.id or f"case-{i + batch_idx + 1}" | |
| test_case.id = case_id # Ensure the test case object has the ID | |
| # --- Yield previous results before starting current case --- | |
| if previous_update_payload: | |
| # Check stop flag before yielding previous update | |
| if STOP_REQUESTED: | |
| logger.info("Stop requested before yielding previous update. Breaking batch.") | |
| break # Exit inner loop for this batch | |
| yield previous_update_payload # Indented correctly inside the if | |
| # --- Champion --- | |
| try: | |
| champ_resp = self.champion_runner.generate(test_case) | |
| current_attempt_champ_responses[case_id] = champ_resp | |
| # Only count metrics if not an image loading error generated by our code | |
| if not champ_resp.output.startswith("Error: Failed to load image"): | |
| champion_metrics["total_latency"] += champ_resp.latency | |
| champion_metrics["total_output_chars"] += len(champ_resp.output) | |
| if not champ_resp.output.startswith("Error:"): champion_metrics["success_count"] += 1 | |
| else: champion_metrics["error_count"] += 1 | |
| else: | |
| champion_metrics["image_load_errors"] += 1 | |
| champion_metrics["error_count"] += 1 # Count as an error | |
| except Exception as e: | |
| logger.error(f"Critical error generating champion response for case {case_id}: {e}", exc_info=True) | |
| current_attempt_champ_responses[case_id] = ModelResponse(case_id, self.champion_endpoint.name, f"Error: Generation failed critically - {e}", 0) | |
| champion_metrics["error_count"] += 1 | |
| # --- Challenger --- | |
| try: | |
| chall_resp = self.challenger_runner.generate(test_case) | |
| current_attempt_chall_responses[case_id] = chall_resp | |
| if not chall_resp.output.startswith("Error: Failed to load image"): | |
| challenger_metrics["total_latency"] += chall_resp.latency | |
| challenger_metrics["total_output_chars"] += len(chall_resp.output) | |
| if not chall_resp.output.startswith("Error:"): challenger_metrics["success_count"] += 1 | |
| else: challenger_metrics["error_count"] += 1 | |
| else: | |
| challenger_metrics["image_load_errors"] += 1 | |
| challenger_metrics["error_count"] += 1 | |
| except Exception as e: | |
| logger.error(f"Critical error generating challenger response for case {case_id}: {e}", exc_info=True) | |
| current_attempt_chall_responses[case_id] = ModelResponse(case_id, self.challenger_endpoint.name, f"Error: Generation failed critically - {e}", 0) | |
| challenger_metrics["error_count"] += 1 | |
| # --- Yield intermediate update --- | |
| # Check if the test case actually had an image associated | |
| image_to_display = test_case.image_path_or_url if test_case.image_path_or_url else None # Use None if no image | |
| # Calculate combined latency (handle potential errors where response might be missing) | |
| champ_resp = current_attempt_champ_responses.get(case_id) | |
| chall_resp = current_attempt_chall_responses.get(case_id) | |
| champ_lat = champ_resp.latency if champ_resp else 0 | |
| chall_lat = chall_resp.latency if chall_resp else 0 | |
| combined_latency = round(champ_lat + chall_lat, 3) | |
| # Removed the intermediate yield from here. It will be moved inside the evaluation loop below. | |
| # 2. Evaluate with LM Judge for the current batch attempt | |
| if progress is not None: | |
| progress((processed_case_count + len(current_batch) * 0.5) / num_cases, f"Evaluating Batch {batch_num}") | |
| for test_case in current_batch: | |
| case_id = test_case.id # Should have been set above | |
| champ_response = current_attempt_champ_responses.get(case_id) | |
| chall_response = current_attempt_chall_responses.get(case_id) | |
| # Skip evaluation if either model failed critically or had image load error | |
| if not champ_response or not chall_response or \ | |
| champ_response.output.startswith("Error:") or \ | |
| chall_response.output.startswith("Error:"): | |
| logger.warning(f"Skipping evaluation for case {case_id} due to generation error in one or both models.") | |
| # One or both models failed (or an image failed to load) for this case; record a placeholder | |
| # result marked JUDGE_ERROR so the case still appears in aggregation and in the retry logic. | |
| eval_reason = f"Skipped: Champion Error: {champ_response.output[:100]}... Challenger Error: {chall_response.output[:100]}..." if champ_response and chall_response else "Skipped: Model generation failed." | |
| current_attempt_eval_results.append(EvaluationResult( | |
| test_id=case_id, | |
| champion_output=champ_response.output if champ_response else "GENERATION FAILED", | |
| challenger_output=chall_response.output if chall_response else "GENERATION FAILED", | |
| winner="JUDGE_ERROR", # Count as judge error if models failed | |
| confidence=0.0, | |
| reasoning=eval_reason | |
| )) | |
| judge_metrics["error_count"] += 1 | |
| continue # Skip to next test case in batch | |
| # Check for trigger strings in model responses *before* calling judge if retry is enabled | |
| if batch_retry_attempts > 0 and batch_retry_trigger_strings: | |
| for trigger in batch_retry_trigger_strings: | |
| if trigger.lower() in champ_response.output.lower() or trigger.lower() in chall_response.output.lower(): # Case-insensitive match | |
| logger.warning(f"Trigger string '{trigger}' found in model responses for case {case_id}. Batch will be retried.") | |
| has_trigger_string_in_attempt = True | |
| break # No need to check other triggers for this case | |
| if has_trigger_string_in_attempt: | |
| # Add a placeholder result indicating retry trigger | |
| current_attempt_eval_results.append(EvaluationResult( | |
| test_id=case_id, champion_output=champ_response.output, challenger_output=chall_response.output, | |
| winner="UNDETERMINED", confidence=0.0, reasoning=f"Retry triggered by model output string." | |
| )) | |
| continue # Skip judge call for this case if retry is triggered by models | |
| # If no model trigger, proceed to judge evaluation | |
| try: | |
| start_time = time.time() | |
| evaluation_result = self.judge.evaluate( | |
| test_case, | |
| champ_response, | |
| chall_response, | |
| ) | |
| judge_latency = time.time() - start_time | |
| judge_metrics["total_latency"] += judge_latency | |
| judge_metrics["total_output_chars"] += len(evaluation_result.reasoning) # Judge output length | |
| current_attempt_eval_results.append(evaluation_result) | |
| # --- Yield intermediate update AFTER judge evaluation for this case --- | |
| image_to_display = test_case.image_path_or_url if test_case.image_path_or_url else None | |
| champ_lat = round(champ_response.latency, 3) if champ_response else 0.0 | |
| chall_lat = round(chall_response.latency, 3) if chall_response else 0.0 | |
| # Use the latency calculated just before (judge_latency variable) | |
| judge_lat = round(judge_latency, 3) | |
| # Check stop flag before yielding intermediate update | |
| if STOP_REQUESTED: | |
| logger.info("Stop requested during evaluation loop. Breaking batch.") | |
| break # Exit evaluation loop for this batch | |
| # Yield the full evaluation result (as dict) for intermediate updates | |
| # Store this update to be yielded before the next case starts | |
| new_update_payload = { | |
| "type": "update", | |
| "image_path": image_to_display, | |
| "champ_latency": champ_lat, | |
| "chall_latency": chall_lat, | |
| "judge_latency": judge_lat, | |
| "evaluation": evaluation_result.__dict__ # Pass evaluation details as dict | |
| } | |
| previous_update_payload = new_update_payload # Store for next iteration | |
| # Yield the current update | |
| yield new_update_payload | |
| # Check for trigger strings in judge reasoning if retry is configured | |
| if batch_retry_attempts > 0 and batch_retry_trigger_strings and not has_trigger_string_in_attempt: | |
| for trigger in batch_retry_trigger_strings: | |
| if trigger.lower() in evaluation_result.reasoning.lower(): # Case-insensitive match | |
| logger.warning(f"Trigger string '{trigger}' found in judge reasoning for case {case_id}. Batch will be retried.") | |
| has_trigger_string_in_attempt = True | |
| # Overwrite the winner to UNDETERMINED if retry triggered by judge | |
| evaluation_result.winner = "UNDETERMINED" | |
| evaluation_result.reasoning += "\n[Retry triggered by judge reasoning]" | |
| break | |
| # Update judge success/error counts based on final verdict (after potential trigger overwrite) | |
| if evaluation_result.winner != "UNDETERMINED" and evaluation_result.winner != "JUDGE_ERROR": judge_metrics["success_count"] += 1 | |
| else: judge_metrics["error_count"] += 1 # Count undetermined/judge_error as errors for judge metrics | |
| except Exception as e: | |
| logger.error(f"Error during judge evaluation for case {case_id}: {e}", exc_info=True) | |
| # Create a placeholder eval result indicating judge failure | |
| current_attempt_eval_results.append(EvaluationResult( | |
| test_id=case_id, | |
| champion_output=champ_response.output, | |
| challenger_output=chall_response.output, | |
| winner="JUDGE_ERROR", | |
| confidence=0.0, | |
| reasoning=f"Error: Judge evaluation failed critically - {e}" | |
| )) | |
| judge_metrics["error_count"] += 1 | |
| # If judge fails critically, maybe trigger retry? For now, just mark as error. | |
| # has_trigger_string_in_attempt = True # Option: Trigger retry on judge exception | |
| # --- Batch Retry Logic --- | |
| if has_trigger_string_in_attempt and retry_count < batch_retry_attempts: | |
| logger.warning(f"Batch {batch_num} attempt {retry_count+1} failed due to trigger strings. Retrying...") | |
| retry_count += 1 | |
| # Clear temporary results for this failed attempt, metrics were already counted above | |
| current_attempt_eval_results = [] | |
| continue # Go to the next iteration of the while loop (retry) | |
| else: | |
| # Conditions to accept the batch results: | |
| # 1. No trigger strings were found in this attempt. | |
| # 2. Trigger strings were found, but we've exhausted retry attempts. | |
| batch_success = True | |
| batch_eval_results = current_attempt_eval_results # Store the results of the successful (or final) attempt | |
| if has_trigger_string_in_attempt and retry_count >= batch_retry_attempts: | |
| logger.warning(f"Accepting batch {batch_num} results despite trigger strings after exhausting {batch_retry_attempts} retry attempts. Some results may be marked UNDETERMINED.") | |
| # Log summary for the completed batch attempt | |
| batch_summary = self.aggregator.aggregate(batch_eval_results) # Aggregate results of this specific batch | |
| log_prefix = f"Batch {batch_num} completed" | |
| if retry_count > 0: log_prefix += f" after {retry_count} retries" | |
| logger.info(f"{log_prefix}. Verdict Counts: {batch_summary['verdict_counts']}") | |
| # --- End of Batch Processing --- | |
| # Add the results of the successful batch attempt to the overall list | |
| all_evaluation_results.extend(batch_eval_results) | |
| processed_case_count += len(current_batch) # Update processed count *after* successful batch completion | |
| # 3. Aggregate final results across all successful batches | |
| aggregated_summary = self.aggregator.aggregate(all_evaluation_results) | |
| # 4. Calculate final metrics (using totals accumulated across all attempts) | |
| # 'processed_cases' is the number of test cases actually attempted (may be fewer than loaded if stopped early) | |
| def calculate_avg_metrics(metrics, processed_cases): | |
| # Base counts on total cases attempted, errors include generation/image load issues | |
| total_attempts = metrics["success_count"] + metrics["error_count"] | |
| # Avg latency based on total attempts where latency was recorded (excludes critical failures before generation) | |
| valid_latency_runs = metrics["success_count"] + (metrics["error_count"] - metrics.get("image_load_errors", 0)) # Approx. | |
| avg_latency = round(metrics["total_latency"] / valid_latency_runs, 2) if valid_latency_runs > 0 else 0 | |
| # Avg chars based only on successful generations | |
| avg_chars = int(metrics["total_output_chars"] / metrics["success_count"]) if metrics["success_count"] > 0 else 0 | |
| # Success rate based on total test cases *processed* before stopping | |
| success_rate = round((metrics["success_count"] / processed_cases) * 100, 1) if processed_cases > 0 else 0 | |
| return { | |
| "avg_latency_s": avg_latency, | |
| "avg_output_chars": avg_chars, | |
| "success_rate_pct": success_rate, # Now calculated based on processed cases | |
| "errors": metrics["error_count"], | |
| "image_load_errors": metrics.get("image_load_errors", 0) | |
| } | |
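| # Illustrative example (hypothetical numbers): with success_count=8, error_count=2 (1 of them an image load error), | |
| # total_latency=27.0s and total_output_chars=4000 over processed_cases=10, this returns roughly | |
| # {"avg_latency_s": 3.0, "avg_output_chars": 500, "success_rate_pct": 80.0, "errors": 2, "image_load_errors": 1}. | |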
| # Use processed_case_count for denominators, as it reflects actual attempts before any early stop | |
| champion_avg_metrics = calculate_avg_metrics(champion_metrics, processed_case_count) | |
| challenger_avg_metrics = calculate_avg_metrics(challenger_metrics, processed_case_count) | |
| # Judge metrics are based on cases where evaluation was attempted | |
| judge_attempts = judge_metrics["success_count"] + judge_metrics["error_count"] | |
| judge_avg_metrics = calculate_avg_metrics(judge_metrics, judge_attempts) | |
| # 5. Determine overall decision based on aggregated results | |
| decision = "MAINTAIN_CHAMPION" # Default | |
| reason = "Insufficient data or challenger did not significantly outperform." | |
| win_margin_threshold = 5 # Challenger needs to win by at least 5 percentage points | |
| min_determined_verdicts = max(3, int(0.1 * processed_case_count)) # Need at least 3 or 10% determined verdicts | |
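| # e.g. (illustrative): with 50 processed cases this requires max(3, 5) = 5 determined verdicts; with 20 cases, max(3, 2) = 3. | |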
| percentages = aggregated_summary["verdict_percentages"] | |
| determined_verdicts = processed_case_count - aggregated_summary["verdict_counts"].get("UNDETERMINED", 0) - aggregated_summary["verdict_counts"].get("JUDGE_ERROR", 0) | |
| if determined_verdicts >= min_determined_verdicts: | |
| champ_wins_pct = percentages.get("MODEL_A_WINS", 0) | |
| chall_wins_pct = percentages.get("MODEL_B_WINS", 0) | |
| ties_pct = percentages.get("TIE", 0) | |
| # Calculate confidence-weighted percentages if we have confidence scores | |
| avg_confidence = aggregated_summary["average_confidence"] | |
| confidence_factor = f" with {avg_confidence:.2f} average confidence" if avg_confidence > 0 else "" | |
| if chall_wins_pct > champ_wins_pct + win_margin_threshold: | |
| decision = "REPLACE_WITH_CHALLENGER" | |
| reason = f"Challenger won {chall_wins_pct:.1f}% vs Champion's {champ_wins_pct:.1f}%{confidence_factor} (>{win_margin_threshold}% margin based on {determined_verdicts} determined verdicts)." | |
| elif champ_wins_pct > chall_wins_pct + win_margin_threshold: | |
| decision = "MAINTAIN_CHAMPION" | |
| reason = f"Champion won {champ_wins_pct:.1f}% vs Challenger's {chall_wins_pct:.1f}%{confidence_factor} (based on {determined_verdicts} determined verdicts)." | |
| else: | |
| # Closer results, consider ties or maintain status quo | |
| decision = "MAINTAIN_CHAMPION" | |
| reason = f"Results close ({champ_wins_pct:.1f}% vs {chall_wins_pct:.1f}%, {ties_pct:.1f}% ties){confidence_factor}. Challenger did not show clear superiority (based on {determined_verdicts} determined verdicts)." | |
| else: | |
| # Not enough determined verdicts for a reliable decision | |
| decision = "MAINTAIN_CHAMPION" | |
| reason = f"Insufficient determined verdicts ({determined_verdicts}/{processed_case_count}, need >= {min_determined_verdicts}) to make a reliable decision. Defaulting to maintaining champion." | |
| # Log final summary | |
| logger.info(f"--- Final Aggregated Results ({processed_case_count} cases processed) ---") | |
| logger.info(f"Verdict Counts: {aggregated_summary['verdict_counts']}") | |
| logger.info(f"Verdict Percentages (Determined Only): {aggregated_summary['verdict_percentages']}") | |
| logger.info(f"Average Confidence (Determined Only): {aggregated_summary['average_confidence']:.3f}") | |
| logger.info(f"Champion Metrics: {champion_avg_metrics}") | |
| logger.info(f"Challenger Metrics: {challenger_avg_metrics}") | |
| logger.info(f"Judge Metrics: {judge_avg_metrics}") | |
| logger.info(f"Decision: {decision} - {reason}") | |
| if progress is not None: | |
| final_status = "Testing completed" if not STOP_REQUESTED else "Testing stopped early" | |
| progress(1.0, final_status) | |
| final_summary = { | |
| "total_test_cases_processed": processed_case_count, | |
| "total_test_cases_loaded": num_cases, | |
| "verdicts": aggregated_summary["verdict_counts"], | |
| "verdict_percentages": aggregated_summary["verdict_percentages"], | |
| "average_confidence": aggregated_summary["average_confidence"], | |
| "decision": decision, | |
| "reason": reason, | |
| "champion_metrics": champion_avg_metrics, | |
| "challenger_metrics": challenger_avg_metrics, | |
| "judge_metrics": judge_avg_metrics, | |
| "champion_name": self.champion_endpoint.name, | |
| "challenger_name": self.challenger_endpoint.name, | |
| "judge_name": self.judge_endpoint.name, | |
| } | |
| # Yield the final results dictionary | |
| yield { | |
| "type": "final", | |
| "evaluations": aggregated_summary["raw_evaluations"], | |
| "summary": final_summary | |
| } | |
| # --- Gradio UI Components & Logic --- | |
| def parse_test_data( | |
| file_obj, | |
| text_data, | |
| key_field_name: str = "key", | |
| value_field_name: str = "value", | |
| image_field_name: str = "image_url" # Added image field name parameter | |
| ) -> List[TestCase]: | |
| """ | |
| Parses test data from Gradio file upload or text input. | |
| Uses specified field names for key, value, and image path/URL. | |
| """ | |
| test_cases = [] | |
| raw_data = None | |
| if file_obj is not None: | |
| # Use the temporary file path provided by Gradio | |
| file_path = file_obj.name | |
| logger.info(f"Loading test data from uploaded file: {file_path}") | |
| try: | |
| # Determine file type from extension, not relying on original name if temp name is different | |
| _, file_ext = os.path.splitext(file_path) | |
| file_ext = file_ext.lower() | |
| if file_ext == ".json": | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| raw_data = json.load(f) | |
| elif file_ext == ".csv": | |
| # Read CSV into pandas DataFrame first for easier handling | |
| try: | |
| # Try detecting delimiter, handle potential bad lines | |
| # Use sensible defaults, allow overriding later if needed | |
| df = pd.read_csv( | |
| file_path, | |
| sep=None, # Auto-detect | |
| engine='python', | |
| on_bad_lines='warn', | |
| quoting=csv.QUOTE_MINIMAL, # Default quoting | |
| escapechar='\\' # Common escape character | |
| ) | |
| logger.info(f"CSV loaded successfully. Columns: {df.columns.tolist()}") | |
| # Convert NaN/NaT to None for cleaner processing -> convert to empty string later | |
| df = df.fillna('') | |
| # Convert DataFrame rows to list of dictionaries | |
| raw_data = df.to_dict(orient='records') | |
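| # e.g. (illustrative, hypothetical column names): a CSV with columns name,caption,image_url becomes | |
| # [{"name": "...", "caption": "...", "image_url": "..."}, ...] | |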
| except Exception as e: | |
| logger.error(f"Error reading CSV file '{file_path}': {e}") | |
| raise ValueError(f"Error reading CSV: {e}") | |
| elif file_ext in (".jsonl", ".ndjson"): | |
| # Handle JSONL (newline-delimited JSON) | |
| raw_data = [] | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| for line_num, line in enumerate(f): | |
| line = line.strip() | |
| if not line: continue # Skip empty lines | |
| try: | |
| item = json.loads(line) | |
| raw_data.append(item) | |
| except json.JSONDecodeError: | |
| logger.warning(f"Skipping invalid JSON line #{line_num + 1} in file '{file_path}': {line[:100]}...") | |
| if not raw_data: | |
| raise ValueError("No valid JSON objects found in JSONL file.") | |
| else: | |
| allowed_extensions = ['.csv', '.json', '.jsonl', '.ndjson'] | |
| raise ValueError(f"Invalid file type ({file_ext}). Please upload a file that is one of these formats: {allowed_extensions}") | |
| except Exception as e: | |
| logger.error(f"Error processing uploaded file {file_path}: {e}", exc_info=True) | |
| raise ValueError(f"Failed to process file: {e}") | |
| elif text_data and text_data.strip(): | |
| logger.info("Loading test data from text input.") | |
| try: | |
| # Try parsing as JSON list first | |
| raw_data = json.loads(text_data) | |
| if not isinstance(raw_data, list): | |
| raise ValueError("Pasted text is valid JSON, but not a list of objects.") | |
| except json.JSONDecodeError as json_err: | |
| # If JSON fails, try treating it as line-delimited JSON (JSONL) | |
| logger.warning(f"Could not parse text as JSON list ({json_err}), trying as JSONL...") | |
| try: | |
| raw_data = [json.loads(line) for line in text_data.strip().splitlines() if line.strip()] | |
| if not raw_data: | |
| raise ValueError("No valid JSON objects found in text input lines.") | |
| except json.JSONDecodeError as line_err: | |
| logger.error(f"Invalid JSON format in text input (checked as list and line-by-line): {line_err}") | |
| raise ValueError(f"Invalid JSON format in text input. Ensure it's a list of objects `[ {{\"key\": ...}}, ... ]` or one JSON object per line.") | |
| except Exception as e: | |
| logger.error(f"Error processing text input data: {e}", exc_info=True) | |
| raise ValueError(f"Failed to process text data: {e}") | |
| else: | |
| raise ValueError("No test data provided. Please upload a file or paste JSON/JSONL.") | |
| # Convert raw_data (list of dicts) to TestCase objects | |
| if isinstance(raw_data, list): | |
| for i, item in enumerate(raw_data): | |
| if isinstance(item, dict): | |
| try: | |
| # Ensure the specified key field exists, value field is optional | |
| # 'id' is optional (defaults to None, ModelTester assigns later if needed) | |
| # 'image' field is optional | |
| key = item.get(key_field_name) | |
| if key is None: | |
| logger.warning(f"Skipping item {i+1} due to missing '{key_field_name}' field. Data: {item}") | |
| continue | |
| # Get image path/url if field exists and is not empty/None | |
| image_val = item.get(image_field_name) if image_field_name else None | |
| image_path_or_url = str(image_val).strip() if image_val and str(image_val).strip() else None | |
| test_cases.append(TestCase( | |
| id=str(item.get('id', f"item-{i+1}")), # Ensure ID is string, use item index | |
| key=str(key), # Ensure key is string | |
| value=str(item.get(value_field_name, '')), # Ensure value is string, default empty | |
| image_path_or_url=image_path_or_url, | |
| )) | |
| except Exception as e: | |
| logger.warning(f"Skipping item {i+1} due to error during TestCase creation: {e}. Data: {item}") | |
| else: | |
| logger.warning(f"Skipping item {i+1} as it is not a dictionary. Data: {item}") | |
| else: | |
| raise ValueError("Parsed data is not a list of test cases (expected list of dictionaries).") | |
| if not test_cases: | |
| raise ValueError("No valid test cases could be loaded from the provided data.") | |
| logger.info(f"Successfully loaded {len(test_cases)} test cases.") | |
| return test_cases | |
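| # Example call (illustrative; in the UI, the first argument comes from a Gradio gr.File component): | |
| #   cases = parse_test_data(None, '[{"key": "What is 2+2?", "value": "4"}]') | |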
| def format_summary_output(summary_data: Dict[str, Any]) -> str: | |
| """Formats the summary dictionary into a readable string.""" | |
| if not summary_data or summary_data.get("error"): | |
| return f"Error generating summary: {summary_data.get('error', 'Unknown error')}" | |
| output = f"--- Test Summary ---\n" | |
| output += f"Champion: {summary_data.get('champion_name', 'N/A')}\n" | |
| output += f"Challenger: {summary_data.get('challenger_name', 'N/A')}\n" | |
| output += f"Judge: {summary_data.get('judge_name', 'N/A')}\n" | |
| output += f"Test Cases Loaded: {summary_data.get('total_test_cases_loaded', 'N/A')}\n" | |
| output += f"Test Cases Processed: {summary_data.get('total_test_cases_processed', 'N/A')}\n" | |
| output += "\nVerdicts (Based on Processed Cases):\n" | |
| for verdict, count in summary_data.get('verdicts', {}).items(): | |
| output += f" {verdict}: {count}\n" | |
| output += "\nVerdict Percentages (Based on Determined Verdicts):\n" | |
| determined = summary_data.get('total_test_cases_processed', 0) - \ | |
| summary_data.get('verdicts', {}).get('UNDETERMINED', 0) - \ | |
| summary_data.get('verdicts', {}).get('JUDGE_ERROR', 0) | |
| output += f" (Calculated from {determined} determined verdicts)\n" | |
| for verdict, pct in summary_data.get('verdict_percentages', {}).items(): | |
| output += f" {verdict}: {pct:.1f}%\n" | |
| avg_conf = summary_data.get('average_confidence', 0) | |
| output += f"\nAverage Confidence (Determined Only): {avg_conf:.3f}\n" | |
| output += "\nMetrics (Avg Latency / Avg Output Chars / Success Rate / Errors / Image Load Errors):\n" | |
| champ_metrics = summary_data.get('champion_metrics', {}) | |
| chall_metrics = summary_data.get('challenger_metrics', {}) | |
| judge_metrics = summary_data.get('judge_metrics', {}) # Judge metrics are calculated differently | |
| output += (f" Champion: {champ_metrics.get('avg_latency_s', 0):.2f}s / " | |
| f"{champ_metrics.get('avg_output_chars', 0)} / " | |
| f"{champ_metrics.get('success_rate_pct', 0):.1f}% / " | |
| f"{champ_metrics.get('errors', 0)} / " | |
| f"{champ_metrics.get('image_load_errors', 0)}\n") | |
| output += (f" Challenger: {chall_metrics.get('avg_latency_s', 0):.2f}s / " | |
| f"{chall_metrics.get('avg_output_chars', 0)} / " | |
| f"{chall_metrics.get('success_rate_pct', 0):.1f}% / " | |
| f"{chall_metrics.get('errors', 0)} / " | |
| f"{chall_metrics.get('image_load_errors', 0)}\n") | |
| # Judge metrics are slightly different (no image errors, success based on valid eval) | |
| output += (f" Judge: {judge_metrics.get('avg_latency_s', 0):.2f}s / " | |
| f"{judge_metrics.get('avg_output_chars', 0)} / " | |
| f"{judge_metrics.get('success_rate_pct', 0):.1f}% / " | |
| f"{judge_metrics.get('errors', 0)} (Errors + Undetermined)\n") | |
| output += f"\nDecision: {summary_data.get('decision', 'N/A')}\n" | |
| output += f"Reason: {summary_data.get('reason', 'N/A')}\n" | |
| return output | |
| def run_test_from_ui( | |
| # Model Configs (18 inputs now, including upload method) | |
| champ_name, champ_api_url, champ_model_id, champ_temp, champ_max_tokens, champ_file_upload_method, | |
| chall_name, chall_api_url, chall_model_id, chall_temp, chall_max_tokens, chall_file_upload_method, | |
| judge_name, judge_api_url, judge_model_id, judge_temp, judge_max_tokens, judge_file_upload_method, | |
| # API Key (1 input) | |
| api_key_input, | |
| # Prompts (2 inputs) | |
| model_prompt_template_input, | |
| judge_prompt_template_input, | |
| # Test Data (2 inputs) | |
| test_data_file, | |
| test_data_text, | |
| # Parameters (5 inputs) | |
| batch_size_input, | |
| batch_retry_attempts_input, | |
| batch_backoff_factor_input, | |
| batch_max_wait_input, | |
| batch_retry_trigger_strings_input, | |
| # Data Field Names (3 inputs) - Added image field name | |
| key_field_name_input, | |
| value_field_name_input, | |
| image_field_name_input, # Added image field name input | |
| # Gradio progress object | |
| progress=gr.Progress(track_tqdm=True) | |
| ): | |
| """ | |
| Handles the logic for running the A/B test triggered by the Gradio UI button. | |
| """ | |
| global STOP_REQUESTED | |
| STOP_REQUESTED = False # Reset stop flag at the beginning of each UI run | |
| logger.info("Starting test run from Gradio UI...") | |
| progress(0, desc="Initializing...") | |
| try: | |
| # 1. Get API Key from UI input (Treat as optional, let endpoint logic handle needs) | |
| # Also check environment variable as a fallback/override | |
| api_key_env = os.getenv("OPENROUTER_API_KEY") or os.getenv("ANTHROPIC_API_KEY") or os.getenv("GOOGLE_API_KEY") # Add other common key names if needed | |
| api_key_ui = str(api_key_input).strip() if api_key_input else None | |
| # Prioritize UI input if provided, otherwise use environment variable | |
| api_key = api_key_ui if api_key_ui else api_key_env | |
| if api_key_ui and api_key_env and api_key_ui != api_key_env: | |
| logger.warning("API Key provided via UI input overrides the environment variable.") | |
| elif api_key: | |
| logger.info(f"API Key found ({'UI input' if api_key_ui else 'environment variable'}).") | |
| else: | |
| logger.info("API Key not provided via UI input or environment variables. Only local/keyless endpoints will work.") | |
| progress(0.1, desc="Loading test data...") | |
| # 2. Load Test Cases (Pass field names from UI) | |
| try: | |
| key_field = str(key_field_name_input).strip() or "key" | |
| value_field = str(value_field_name_input).strip() or "value" | |
| image_field = str(image_field_name_input).strip() or "image_url" # Default if empty | |
| logger.info(f"Using data fields - Key: '{key_field}', Value: '{value_field}', Image: '{image_field}'") | |
| test_cases = parse_test_data(test_data_file, test_data_text, key_field, value_field, image_field) | |
| logger.info(f"Loaded {len(test_cases)} test cases.") | |
| except ValueError as e: | |
| logger.error(f"Failed to load test data: {e}") | |
| raise gr.Error(f"Test Data Error: {e}") | |
| except Exception as e: | |
| logger.exception("Unexpected error loading test data.") | |
| raise gr.Error(f"Unexpected error loading test data: {e}") | |
| if not test_cases: | |
| raise gr.Error("No valid test cases were loaded.") | |
| progress(0.2, desc="Configuring models...") | |
| # 3. Create Model Endpoints (Pass the potentially found api_key and upload method) | |
| try: | |
| # Helper to create endpoint, ensuring types and including upload method | |
| def create_ep(name, url, model_id, temp, max_tok, key, upload_method): | |
| # Strip whitespace from URL and model ID | |
| url = str(url).strip() if url else "" | |
| model_id = str(model_id).strip() if model_id else "" | |
| # Basic validation | |
| if not name: raise ValueError("Model Display Name cannot be empty.") | |
| if not url: raise ValueError(f"API URL cannot be empty for model '{name}'.") | |
| if not model_id: raise ValueError(f"Model ID cannot be empty for model '{name}'.") | |
| # Validate upload method | |
| if upload_method not in ["JSON (Embedded Data)", "Multipart Form Data"]: | |
| raise ValueError(f"Invalid file upload method '{upload_method}' for model '{name}'.") | |
| # Use the potentially loaded key and upload method | |
| return ModelEndpoint( | |
| name=str(name), api_url=url, api_key=key, model_id=model_id, | |
| temperature=float(temp), max_tokens=int(max_tok), | |
| file_upload_method=str(upload_method) # Add file upload method | |
| ) | |
| champion_endpoint = create_ep(champ_name, champ_api_url, champ_model_id, champ_temp, champ_max_tokens, api_key, champ_file_upload_method) | |
| challenger_endpoint = create_ep(chall_name, chall_api_url, chall_model_id, chall_temp, chall_max_tokens, api_key, chall_file_upload_method) | |
| judge_endpoint = create_ep(judge_name, judge_api_url, judge_model_id, judge_temp, judge_max_tokens, api_key, judge_file_upload_method) | |
| # Log endpoints being used (mask key if present, add upload method) | |
| logger.info(f"Champion Endpoint: {champion_endpoint.name}, URL: {champion_endpoint.api_url}, Model: {champion_endpoint.model_id}, Upload: {champion_endpoint.file_upload_method}, Key Provided: {'Yes' if champion_endpoint.api_key else 'No'}") | |
| logger.info(f"Challenger Endpoint: {challenger_endpoint.name}, URL: {challenger_endpoint.api_url}, Model: {challenger_endpoint.model_id}, Upload: {challenger_endpoint.file_upload_method}, Key Provided: {'Yes' if challenger_endpoint.api_key else 'No'}") | |
| logger.info(f"Judge Endpoint: {judge_endpoint.name}, URL: {judge_endpoint.api_url}, Model: {judge_endpoint.model_id}, Upload: {judge_endpoint.file_upload_method}, Key Provided: {'Yes' if judge_endpoint.api_key else 'No'}") | |
| except ValueError as ve: | |
| logger.error(f"Model Configuration Error: {ve}") | |
| raise gr.Error(f"Model Configuration Error: {ve}") | |
| except Exception as e: | |
| logger.error(f"Error creating ModelEndpoint objects: {e}", exc_info=True) | |
| raise gr.Error(f"Model Configuration Error: {e}") | |
| # 4. Instantiate ModelTester | |
| try: | |
| tester = ModelTester( | |
| champion_endpoint=champion_endpoint, | |
| challenger_endpoint=challenger_endpoint, | |
| judge_endpoint=judge_endpoint, | |
| model_prompt_template=str(model_prompt_template_input), | |
| judge_prompt_template=str(judge_prompt_template_input) | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error instantiating ModelTester: {e}", exc_info=True) | |
| raise gr.Error(f"Tester Initialization Error: {e}") | |
| # 5. Run the Test | |
| batch_size = int(batch_size_input) if batch_size_input is not None and batch_size_input > 0 else 1 | |
| logger.info(f"Running test with {len(test_cases)} cases, batch size {batch_size}...") | |
| progress(0.3, desc="Running A/B test...") | |
| try: | |
| # Process batch retry parameters | |
| batch_retry_attempts = int(batch_retry_attempts_input) if batch_retry_attempts_input is not None else 0 | |
| batch_backoff_factor = float(batch_backoff_factor_input) if batch_backoff_factor_input is not None else 2.0 | |
| batch_max_wait = int(batch_max_wait_input) if batch_max_wait_input is not None else 60 | |
| # Process trigger strings (convert from comma-separated string to list) | |
| batch_retry_trigger_strings = None | |
| if batch_retry_trigger_strings_input and batch_retry_trigger_strings_input.strip(): | |
| batch_retry_trigger_strings = [s.strip().lower() for s in batch_retry_trigger_strings_input.split(',') if s.strip()] # Lowercased; run_test also lowercases outputs, so matching is case-insensitive | |
| logger.info(f"Using batch retry trigger strings: {batch_retry_trigger_strings}") | |
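| # e.g. (illustrative): " Rate limit, ERROR ,timeout" -> ["rate limit", "error", "timeout"] | |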
| # Iterate through the generator yielded by run_test | |
| final_results = None | |
| last_image_path = None | |
| last_champ_latency = "" | |
| last_chall_latency = "" | |
| last_judge_latency = "" | |
| last_winner = "" | |
| running_eval_results = [] # List to store results incrementally | |
| current_summary = "" # Placeholder for summary updates | |
| current_details_df = pd.DataFrame() # Placeholder for details updates | |
| try: # Inner try for the loop | |
| for result_update in tester.run_test( | |
| test_cases, | |
| batch_size=batch_size, | |
| progress=progress, | |
| batch_retry_attempts=batch_retry_attempts, | |
| batch_backoff_factor=batch_backoff_factor, | |
| batch_max_wait=batch_max_wait, | |
| batch_retry_trigger_strings=batch_retry_trigger_strings | |
| ): | |
| if STOP_REQUESTED: # Check stop flag during iteration | |
| logger.info("Stop requested, halting UI updates.") | |
| break # Exit the for loop | |
| if result_update.get("type") == "update": | |
| # Extract monitoring values | |
| last_image_path = result_update.get("image_path") | |
| last_champ_latency = str(result_update.get("champ_latency", "")) | |
| last_chall_latency = str(result_update.get("chall_latency", "")) | |
| last_judge_latency = str(result_update.get("judge_latency", "")) | |
| # Process evaluation details for incremental updates | |
| evaluation = result_update.get("evaluation") # Get the full evaluation dict | |
| if evaluation: | |
| running_eval_results.append(evaluation) # Add current result to list | |
| last_winner = str(evaluation.get("winner", "N/A")) | |
| # Format summary for the current case | |
| current_summary = ( | |
| f"--- Last Processed Case ---\n" | |
| f"Case ID: {evaluation.get('test_id', 'N/A')}\n" | |
| f"Winner: {last_winner}\n" | |
| f"Confidence: {evaluation.get('confidence', 0.0):.2f}\n" | |
| f"Reasoning Snippet: {evaluation.get('reasoning', '')[:200]}..." | |
| ) | |
| # Update DataFrame with all results so far | |
| try: | |
| display_columns_update = ['test_id', 'winner', 'confidence', 'champion_output', 'challenger_output', 'reasoning'] | |
| current_details_df = pd.DataFrame(running_eval_results) | |
| for col in display_columns_update: | |
| if col not in current_details_df.columns: | |
| current_details_df[col] = None | |
| current_details_df = current_details_df[display_columns_update] # Reorder/select columns | |
| except Exception as df_err: | |
| logger.error(f"Error creating incremental DataFrame: {df_err}") | |
| current_details_df = pd.DataFrame([{"Error": "Failed to update details"}]) | |
| else: | |
| last_winner = "Update Error" | |
| current_summary = "Error: No evaluation data in update." | |
| # Yield 8 values for incremental update | |
| yield current_summary, current_details_df, last_image_path, last_champ_latency, last_chall_latency, last_judge_latency, last_winner, running_eval_results # Yield current results to state | |
| elif result_update.get("type") == "final": | |
| # Store final results | |
| final_results = result_update | |
| # Don't break here, let the loop finish naturally to reach finally | |
| else: | |
| logger.warning(f"Received unexpected update type from run_test: {result_update.get('type')}") | |
| # End of the main 'for' loop | |
| except Exception as loop_err: # Catch errors *during* loop execution | |
| if STOP_REQUESTED: | |
| logger.warning(f"Caught exception during test loop after stop request (likely progress bar issue): {loop_err}") | |
| pass # Suppress error, proceed to finally | |
| else: | |
| logger.exception("An unexpected error occurred during the test execution loop.") | |
| raise loop_err # Re-raise unexpected errors | |
| finally: # Executes after loop finishes (normally or via break) or if suppressed exception occurred | |
| # --- Post-Loop Processing --- | |
| if 'result_update' not in locals() and not STOP_REQUESTED: # Handle case where loop didn't run and wasn't stopped | |
| final_results = None | |
| if STOP_REQUESTED: | |
| logger.info("Test run stopped by user.") | |
| summary_output = "Test run stopped by user." | |
| details_df = current_details_df if 'current_details_df' in locals() and not current_details_df.empty else pd.DataFrame(columns=['test_id', 'winner', 'confidence', 'champion_output', 'challenger_output', 'reasoning']) | |
| raw_evals = running_eval_results if 'running_eval_results' in locals() else [] | |
| elif 'final_results' not in locals() or final_results is None: | |
| logger.error("Test run finished, but no final results structure was received.") | |
| summary_output = "Test Execution Error: No final results generated." | |
| details_df = pd.DataFrame(columns=['test_id', 'winner', 'confidence', 'champion_output', 'challenger_output', 'reasoning']) | |
| raw_evals = [] | |
| last_image_path = last_image_path if 'last_image_path' in locals() else None | |
| last_champ_latency = last_champ_latency if 'last_champ_latency' in locals() else "" | |
| last_chall_latency = last_chall_latency if 'last_chall_latency' in locals() else "" | |
| last_judge_latency = last_judge_latency if 'last_judge_latency' in locals() else "" | |
| last_winner = last_winner if 'last_winner' in locals() else "" | |
| else: | |
| # Completed normally, process final results | |
| logger.info("Test run completed normally.") | |
| summary_data = final_results.get("summary", {}) | |
| raw_evals = final_results.get("evaluations", []) | |
| summary_output = format_summary_output(summary_data) | |
| display_columns = ['test_id', 'winner', 'confidence', 'champion_output', 'challenger_output', 'reasoning'] | |
| try: | |
| if raw_evals: | |
| details_df = pd.DataFrame(raw_evals) | |
| for col in display_columns: | |
| if col not in details_df.columns: | |
| details_df[col] = None | |
| details_df = details_df[display_columns] | |
| else: | |
| details_df = pd.DataFrame(columns=display_columns) | |
| summary_output += "\n\nNote: No evaluation results were generated." | |
| except Exception as df_err: | |
| logger.error(f"Error creating DataFrame from final results: {df_err}") | |
| summary_output += f"\n\nError displaying detailed results: {df_err}" | |
| details_df = pd.DataFrame(columns=display_columns) | |
| raw_evals = [] | |
| # Ensure monitoring variables exist | |
| last_image_path = last_image_path if 'last_image_path' in locals() else None | |
| last_champ_latency = last_champ_latency if 'last_champ_latency' in locals() else "" | |
| last_chall_latency = last_chall_latency if 'last_chall_latency' in locals() else "" | |
| last_judge_latency = last_judge_latency if 'last_judge_latency' in locals() else "" | |
| last_winner = last_winner if 'last_winner' in locals() else "" | |
| raw_evals = raw_evals if 'raw_evals' in locals() else [] | |
| # Final yield statement - yields 8 values (partial results are included if the run was stopped early) | |
| yield summary_output, details_df, last_image_path, last_champ_latency, last_chall_latency, last_judge_latency, last_winner, raw_evals | |
| # End of inner try...except...finally | |
| except Exception as test_exec_err: # Catches errors raised inside the inner test-execution block above | |
| logger.exception("An error occurred during the main test execution phase.") | |
| # Potentially yield an error state back to the UI if needed | |
| # For now, just re-raise to be caught by the outermost handler | |
| raise test_exec_err | |
| # --- Outer Exception Handling --- | |
| except gr.Error as e: # Catch Gradio-specific errors first | |
| logger.error(f"Gradio Error: {e}") | |
| error_message = str(e) | |
| error_df = pd.DataFrame([{"Error": error_message}]) | |
| yield error_message, error_df, None, None, None, None, None, None # Yield 8 error values | |
| except Exception as e: # Catch any other exceptions (setup, re-raised from inner loop) | |
| logger.exception("An unexpected error occurred in run_test_from_ui.") | |
| error_message = f"An unexpected error occurred: {e}" | |
| error_df = pd.DataFrame([{"Error": error_message}]) | |
| yield error_message, error_df, None, None, None, None, None, None # Yield 8 error values | |
| finally: # Outer finally | |
| # Ensure the stop requested flag is reset regardless of how the function exits | |
| STOP_REQUESTED = False | |
| # --- Stop Request Handling (called by the Stop button) --- | |
| def request_stop(): | |
| """Sets the global STOP_REQUESTED flag and returns a status message.""" | |
| global STOP_REQUESTED | |
| status_message = "" | |
| if not STOP_REQUESTED: | |
| STOP_REQUESTED = True | |
| logger.warning("Stop requested via UI button.") | |
| status_message = "Stop request received. Finishing current batch..." | |
| else: | |
| logger.warning("Stop already requested.") | |
| status_message = "Stop already requested. Please wait..." | |
| return status_message | |
| # --- Helper Functions for Downloads --- | |
| from typing import Optional # Optional is used in the download helper signatures below | |
| # Function to generate a JSONL file for download | |
| def generate_jsonl_download(results_list: Optional[List[Dict[str, Any]]]) -> Optional[gr.File]: | |
| """ | |
| Takes an optional list of evaluation result dictionaries, saves them as a JSONL file, | |
| and returns a Gradio File object for download. An empty or missing list (e.g., from a stopped run) produces an empty file. | |
| """ | |
| if not results_list: | |
| # Still handle the case of a completed run with zero results | |
| logger.warning("generate_jsonl_download called with empty results list (run completed but no results).") | |
| results_list = [] # Ensure it's an empty list for file creation below | |
| logger.info(f"generate_jsonl_download received results_list: type={type(results_list)}, len={len(results_list) if results_list is not None else 'None'}") | |
| try: | |
| # Use io.StringIO to build the JSONL string in memory | |
| jsonl_content = io.StringIO() | |
| for result in results_list: | |
| # Ensure result is serializable (convert dataclasses if needed) | |
| # The results passed should already be dicts from ModelTester.run_test final yield | |
| if isinstance(result, dict): | |
| serializable_result = result | |
| else: | |
| # Handle unexpected types if necessary | |
| logger.warning(f"Skipping non-dict item in results: {type(result)}") | |
| continue | |
| jsonl_content.write(json.dumps(serializable_result) + '\n') | |
| # Get the string content | |
| jsonl_string = jsonl_content.getvalue() | |
| jsonl_content.close() | |
| logger.info(f"Generated JSONL string length: {len(jsonl_string)}") | |
| # Create a temporary file path | |
| timestamp = time.strftime("%Y%m%d-%H%M%S") | |
| temp_dir = tempfile.gettempdir() | |
| file_path = os.path.join(temp_dir, f"abx_results_{timestamp}.jsonl") | |
| # Write the string content to the file | |
| with open(file_path, "w", encoding="utf-8") as f: | |
| f.write(jsonl_string) | |
| logger.info(f"Generated JSONL file for download at: {file_path}") | |
| logger.info(f"--- generate_jsonl_download returning file: {file_path}") | |
| # Return the file path wrapped in gr.File for download | |
| # Note: Gradio handles the cleanup of the temp file after download | |
| return gr.File(value=file_path, label="Download Results (JSONL)") | |
| except Exception as e: | |
| logger.error(f"Error generating JSONL file: {e}", exc_info=True) | |
| # How to signal error to user? Gradio download button doesn't easily show errors. | |
| # Maybe return None or raise an error that Gradio might catch? | |
| # Returning None might just make the download fail silently. | |
| # Let's re-raise for now, Gradio might handle it. | |
| raise gr.Error(f"Failed to generate JSONL download: {e}") | |
| def _generate_download_wrapper(results_state, *args): | |
| """Wrapper to call generate_jsonl_download, ignoring extra args from .then().""" | |
| # *args will capture any extra positional arguments Gradio might pass | |
| logger.info(f"Download wrapper called. results_state type: {type(results_state)}, len: {len(results_state) if isinstance(results_state, list) else 'N/A'}. Ignoring {len(args)} extra args.") | |
| return generate_jsonl_download(results_state) | |
| # --- UI Creation --- | |
| def create_ui(): | |
| """Creates the Gradio web interface for the A/B testing tool.""" | |
| logger.info("Creating Gradio UI...") | |
| # Default values for UI components | |
| default_api_url_openrouter = "https://openrouter.ai/api/v1/chat/completions" | |
| default_api_url_ollama = "http://localhost:11434/api/generate" # Default Ollama URL | |
| default_model_prompt = "User: {key}\nAssistant:" # Example prompt | |
| # Use the default judge prompt from the LMJudge class | |
| default_judge_prompt = LMJudge.DEFAULT_EVALUATION_PROMPT | |
| css = """ | |
| .model-config-group .gr-form { background-color: #f0f0f0; padding: 10px; border-radius: 5px; margin-bottom: 10px; } | |
| .model-config-group .gr-form > :first-child { font-weight: bold; margin-bottom: 5px; } /* Style the label */ | |
| .results-box { border: 1px solid #ccc; padding: 15px; border-radius: 5px; margin-top: 15px; } | |
| .api-key-warning { color: #cc5500; font-weight: bold; margin-bottom: 15px; } | |
| """ | |
| with gr.Blocks(css=css, theme=gr.themes.Soft()) as iface: | |
| gr.Markdown("# A/B x Judge: AI Testing & Auto-Evaluation") | |
| gr.Markdown( | |
| "1) Configure Champion, Challenger, and Judge.\n" | |
| "2) Provide Test Data & Reference Input.\n" | |
| "3) Run Evaluations & Compare Performance." | |
| ) | |
| gr.Markdown( | |
| """**API Key**: Enter as needed for Cloud Endpoints; env defaults auto-evaluated (see: [code](https://github.com/rabbidave/ZeroDay.Tools/blob/main/ABxJudge.py))""", | |
| elem_classes="api-key-warning" | |
| ) | |
| gr.Markdown( | |
| """**Multimodal Input**: | |
| 1. Ensure your test data (CSV/JSON/JSONL) includes a column/field containing the **local path** or **public URL** to the file. | |
| 2. Specify this column/field name in the 'Input Field Name' box below. | |
| 3. Ensure your models and endpoints support multimodal input | |
| 4. Prompt should contextualize the input (e.g., 'Describe this image.', 'Transcribe the audio.').""", | |
| elem_classes="api-key-warning" | |
| ) | |
| with gr.Tabs(): | |
| with gr.TabItem("Configuration"): | |
| # Add API Key input field | |
| with gr.Row(): | |
| api_key_input = gr.Textbox( | |
| label="API Key (Optional)", # Simplified label | |
| type="password", | |
| placeholder="Enter key if required for cloud endpoints", | |
| info="Overrides environment variables (e.g., OPENROUTER_API_KEY). Leave blank to use ENV or for local models." | |
| ) | |
| with gr.Row(): | |
| # Champion Model Configuration | |
| with gr.Column(scale=1): | |
| with gr.Group(elem_classes="model-config-group"): | |
| gr.Label("Champion Model (Model A)") | |
| # Default example: LM Studio (OpenAI-compatible endpoint) serving Gemma 3 12B | |
| champ_name = gr.Textbox(label="Display Name", value="Champion (LM Studio Gemma 3 12B)") | |
| champ_api_url = gr.Textbox(label="API URL", value="http://localhost:1234/v1/chat/completions") # LM Studio OpenAI endpoint | |
| champ_model_id = gr.Textbox(label="Model ID", value="gemma-3-12b-it") # User specified | |
| champ_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.1, value=0.1) | |
| champ_max_tokens = gr.Number(label="Max Tokens", value=8192, precision=0) | |
| champ_file_upload_method = gr.Dropdown( | |
| label="File Upload Method", | |
| choices=["JSON (Embedded Data)", "Multipart Form Data"], | |
| value="JSON (Embedded Data)", | |
| info="How to send file data (if any) to this endpoint." | |
| ) | |
| # Challenger Model Configuration | |
| with gr.Column(scale=1): | |
| with gr.Group(elem_classes="model-config-group"): | |
| gr.Label("Challenger Model (Model B)") | |
| # Default example: LM Studio (OpenAI-compatible endpoint) serving Gemma 3 4B | |
| chall_name = gr.Textbox(label="Display Name", value="Challenger (LM Studio Gemma 3 4B)") | |
| chall_api_url = gr.Textbox(label="API URL", value="http://localhost:1234/v1/chat/completions") # LM Studio OpenAI endpoint | |
| chall_model_id = gr.Textbox(label="Model ID", value="gemma-3-4b-it") # User specified | |
| chall_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.1, value=0.1) | |
| chall_max_tokens = gr.Number(label="Max Tokens", value=8192, precision=0) | |
| chall_file_upload_method = gr.Dropdown( | |
| label="File Upload Method", | |
| choices=["JSON (Embedded Data)", "Multipart Form Data"], | |
| value="JSON (Embedded Data)", | |
| info="How to send file data (if any) to this endpoint." | |
| ) | |
| # Judge Model Configuration | |
| with gr.Column(scale=1): | |
| with gr.Group(elem_classes="model-config-group"): | |
| gr.Label("Judge Model") | |
| judge_name = gr.Textbox(label="Display Name", value="Judge (LM Studio Gemma 3 27B)") | |
| judge_api_url = gr.Textbox(label="API URL", value="http://localhost:1234/v1/chat/completions") # LM Studio OpenAI endpoint | |
| judge_model_id = gr.Textbox(label="Model ID", value="gemma-3-27b-it") # User specified | |
| judge_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.0) # Judge usually deterministic | |
| judge_max_tokens = gr.Number(label="Max Tokens", value=8192, precision=0) # Judge might need more tokens | |
| judge_file_upload_method = gr.Dropdown( | |
| label="File Upload Method", | |
| choices=["JSON (Embedded Data)", "Multipart Form Data"], | |
| value="JSON (Embedded Data)", | |
| info="How to send file data (if any) to this endpoint." | |
| ) | |
| with gr.Row(): | |
| # Model Prompt Template | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Model Prompt Template") | |
| model_prompt_template_input = gr.Textbox( | |
| label="Template for Champion/Challenger (use {key} for input)", | |
| value="{key}\nUser: Provide a detailed description\nAssistant:", | |
| lines=5, | |
| show_copy_button=True | |
| ) | |
| # Judge Prompt Template | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Judge Prompt Template") | |
| judge_prompt_template_input = gr.Textbox( | |
| label="Template for Judge (see code/docs for available placeholders)", | |
| value=default_judge_prompt, | |
| lines=15, | |
| show_copy_button=True | |
| ) | |
| with gr.Row(): | |
| # Test Data Input | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Test Data") | |
| gr.Markdown("Upload a CSV/JSON/JSONL file or paste data below. Specify the field names containing the model input (key), optional reference answer (value), and optional input path/URL. Add an `id` field for stable identification (recommended).") | |
| test_data_file = gr.File(label="Upload Test Data (CSV, JSON, JSONL/NDJSON)", file_types=[".csv", ".json", ".jsonl", ".ndjson"]) | |
| test_data_text = gr.Textbox(label="Or Paste Test Data (JSON list or JSONL format)", lines=8, placeholder='[{"id": "t1", "prompt": "Describe image", "image_url": "/path/to/img.jpg", "reference": "..."}]\n{"id": "t2", "prompt": "Question text", "image_url": null, "reference": "..."}') | |
| with gr.Row(): | |
| key_field_name_input = gr.Textbox(label="Key Field Name", value="name", info="Field containing the text input/prompt.") # Swapped default | |
| value_field_name_input = gr.Textbox(label="Value Field Name", value="caption", info="Field containing the reference/ground truth (optional).") # Swapped default | |
| image_field_name_input = gr.Textbox(label="Image Field Name", value="image_url", info="Field containing the image path or URL (optional).") # Added image field input | |
| # Test Run Parameters | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Test Parameters") | |
| batch_size_input = gr.Number(label="Batch Size", value=5, minimum=1, precision=0) | |
| # Add batch retry parameters | |
| gr.Markdown("#### Batch Retry Settings") | |
| batch_retry_attempts_input = gr.Number( | |
| label="Batch Retry Attempts", | |
| value=1, # Default to 1 retry | |
| precision=0, | |
| minimum=0, | |
| info="Number of times to retry a batch if trigger strings are found (0 = no retries)" | |
| ) | |
| batch_backoff_factor_input = gr.Slider( | |
| label="Backoff Factor", | |
| value=2.0, | |
| minimum=1.0, | |
| maximum=5.0, | |
| step=0.1, | |
| info="Factor for exponential backoff between retries (e.g., 2.0 waits 1s, 2s, 4s...)" | |
| ) | |
| batch_max_wait_input = gr.Number( | |
| label="Maximum Wait Time (seconds)", | |
| value=60, | |
| precision=0, | |
| minimum=1, | |
| info="Maximum wait time between retries in seconds" | |
| ) | |
| batch_retry_trigger_strings_input = gr.Textbox( | |
| label="Retry Trigger Strings (Comma-separated)", | |
| placeholder="e.g., rate limit,error,timeout,empty response", | |
| info="Retry batch if these strings appear in model/judge output (case-insensitive check)" | |
| ) | |
| # Add preprocessing options later if needed | |
| with gr.TabItem("Monitoring"): | |
| gr.Markdown("### Test Execution & Results") | |
| with gr.Row(): | |
| run_button = gr.Button("Run A/B Test", variant="primary", scale=3) # Adjusted scale | |
| stop_button = gr.Button("Stop Test", variant="stop", scale=1) | |
| with gr.Row(): # New row for results display | |
| with gr.Column(scale=1): # Column for the new status window | |
| with gr.Group(elem_classes="results-box"): | |
| gr.Markdown("#### Last Processed") | |
| # Placeholder for the most recent image | |
| last_image_display = gr.Image(label="Last Image", type="filepath", interactive=False, height=200, value=None, show_label=True) # Explicitly show label, ensure preview area renders | |
| # Placeholder for the runtime of the last image | |
| # Replace single runtime display with three separate ones | |
| last_champ_latency_display = gr.Textbox(label="Champion Latency (s)", value="", interactive=False) | |
| last_chall_latency_display = gr.Textbox(label="Challenger Latency (s)", value="", interactive=False) | |
| last_judge_latency_display = gr.Textbox(label="Judge Latency (s)", value="", interactive=False) | |
| with gr.Column(scale=2): # Column for winner and download | |
| with gr.Group(elem_classes="results-box"): | |
| gr.Markdown("#### Last Winner") | |
| # Textbox to show the last winner | |
| last_winner_output = gr.Textbox(label="Last Test Case Winner", lines=2, show_copy_button=True, interactive=False) | |
| with gr.Group(elem_classes="results-box"): | |
| gr.Markdown("#### Download Results") | |
| # State to hold the raw evaluation results (list of dicts) for download | |
| results_state = gr.State([]) | |
| # Download Button | |
| download_button = gr.DownloadButton(label="Download All Evaluations (JSONL)", value=None) # This button is triggered later | |
| # Add placeholders for final summary and details display | |
| with gr.Group(elem_classes="results-box"): | |
| gr.Markdown("#### Overall Results") | |
| summary_output = gr.Textbox(label="Summary", lines=10, interactive=False) | |
| # Moved Detailed Evaluations DataFrame here | |
| detailed_evaluations_output = gr.DataFrame(label="Individual Case Results", interactive=False) | |
| # Removed the separate "Detailed Evaluations" group | |
| # Define interactions & state | |
| # Define the single, complete run event listener | |
| run_event = run_button.click( | |
| fn=run_test_from_ui, | |
| inputs=[ # Add the new dropdown inputs | |
| champ_name, champ_api_url, champ_model_id, champ_temp, champ_max_tokens, champ_file_upload_method, | |
| chall_name, chall_api_url, chall_model_id, chall_temp, chall_max_tokens, chall_file_upload_method, | |
| judge_name, judge_api_url, judge_model_id, judge_temp, judge_max_tokens, judge_file_upload_method, | |
| api_key_input, | |
| model_prompt_template_input, | |
| judge_prompt_template_input, | |
| test_data_file, | |
| test_data_text, | |
| batch_size_input, | |
| batch_retry_attempts_input, | |
| batch_backoff_factor_input, | |
| batch_max_wait_input, | |
| batch_retry_trigger_strings_input, | |
| key_field_name_input, | |
| value_field_name_input, | |
| image_field_name_input | |
| ], | |
| outputs=[ | |
| # Map the 8 yielded values from run_test_from_ui | |
| summary_output, # 1. Final summary text / Stop message | |
| detailed_evaluations_output, # 2. Final details dataframe / Last partial DF | |
| last_image_display, # 3. Intermediate image path / Last image | |
| last_champ_latency_display, # 4. Intermediate champ latency / Last latency | |
| last_chall_latency_display, # 5. Intermediate chall latency / Last latency | |
| last_judge_latency_display, # 6. Intermediate judge latency / Last latency | |
| last_winner_output, # 7. Intermediate winner text / Last winner | |
| results_state # 8. Hidden state for final/partial raw evaluations | |
| ], | |
| # cancels=[run_event] # Remove self-cancellation | |
| ) | |
| # Now, trigger the download file generation *after* the run completes, using the state | |
| run_event.then( | |
| fn=_generate_download_wrapper, | |
| inputs=[results_state], | |
| outputs=[download_button] # Output the gr.File object to the button itself | |
| ) | |
| # --- Add Stop Button Interaction --- | |
| # Connect the stop button to the request_stop function and make it cancel the run_event | |
| stop_event = stop_button.click( | |
| fn=request_stop, | |
| inputs=None, # request_stop takes no inputs from UI | |
| outputs=None, # request_stop doesn't need to update UI directly anymore | |
| cancels=[run_event] # Make the stop button cancel the main test run | |
| ) | |
| return iface | |
| def run_cli_test(): | |
| """Runs the A/B test from the command line using hardcoded examples.""" | |
| logger.info("Starting CLI execution of ModelTester...") | |
| # --- Configuration (API Key optional for local models) --- | |
| # Load API keys from .env file if it exists | |
| try: | |
| from dotenv import load_dotenv | |
| if load_dotenv(): | |
| logger.info("Loaded environment variables from .env file.") | |
| else: | |
| logger.info(".env file not found or empty, relying on system environment variables or UI input.") | |
| except ImportError: | |
| logger.warning("python-dotenv not installed, cannot load .env file. Run 'pip install python-dotenv' or ensure packages are installed.") | |
| OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") # Still useful for judge or cloud models | |
| OLLAMA_API_URL = "http://localhost:11434/api/generate" | |
| # Define Model Endpoints | |
| # Using Ollama Mistral as Champion | |
| champion_model = ModelEndpoint( | |
| name="Champion (Ollama Mistral)", | |
| api_url=OLLAMA_API_URL, | |
| api_key=None, # No key needed for local Ollama | |
| model_id="mistral:latest", # Adjust if your model name is different | |
| temperature=0.1 | |
| ) | |
| # Using a local Ollama Gemma 3 27B (QAT, q4_0 GGUF) build as Challenger | |
| challenger_model = ModelEndpoint( | |
| name="Challenger (Ollama Gemma 3 27B QAT)", | |
| api_url=OLLAMA_API_URL, | |
| api_key=None, # No key needed for local Ollama | |
| model_id="hf.co/stduhpf/google-gemma-3-27b-it-qat-q4_0-gguf-small:latest", # Must match the name shown by `ollama list` | |
| temperature=0.1, | |
| max_tokens=2048 | |
| ) | |
| # Using OpenRouter GPT-4o Mini as Judge (Requires API Key) | |
| if not OPENROUTER_API_KEY: | |
| logger.warning("OPENROUTER_API_KEY not set (checked ENV and .env). Using Champion model as Judge (less ideal).") | |
| import copy | |
| judge_model = copy.copy(champion_model) # Fallback judge: copy so renaming it does not also rename the champion endpoint | |
| judge_model.name = "Judge (Fallback - Ollama Mistral)" | |
| else: | |
| logger.info("Using OpenRouter API Key for Judge.") | |
| judge_model = ModelEndpoint( | |
| name="Judge (GPT-4o Mini - OR)", | |
| api_url="https://openrouter.ai/api/v1/chat/completions", | |
| api_key=OPENROUTER_API_KEY, | |
| model_id="openai/gpt-4o-mini", | |
| temperature=0.0, | |
| max_tokens=2048 | |
| ) | |
| # Define Model Prompt Template | |
| model_prompt = "User: {key}\nAssistant:" | |
| # Define Sample Test Cases (Including Multimodal Example) | |
| # Create a dummy image file for the multimodal test case if it doesn't already exist | |
| dummy_image_path = "dummy_test_image.png" | |
| created_dummy_image = False # Track whether this run created the file so cleanup only removes our own artifact | |
| if not os.path.exists(dummy_image_path): | |
| try: | |
| from PIL import Image, ImageDraw | |
| img = Image.new('RGB', (100, 50), color = (73, 109, 137)) # Blueish background | |
| d = ImageDraw.Draw(img) | |
| d.text((10,10), "Test Img", fill=(255,255,0)) # Yellow text | |
| img.save(dummy_image_path) | |
| created_dummy_image = True | |
| logger.info(f"Created dummy image file: {dummy_image_path}") | |
| except ImportError: | |
| logger.warning("Pillow (PIL) not installed, cannot create dummy image for CLI test. Ensure packages are installed via venv setup.") | |
| dummy_image_path = None # Cannot use image test case | |
| except Exception as e: | |
| logger.error(f"Failed to create dummy image: {e}") | |
| dummy_image_path = None | |
| test_cases = [ | |
| TestCase(id="q1", key="What is the capital of France?", value="Paris", image_path_or_url=None), | |
| TestCase(id="q2", key="Summarize the plot of the movie 'Inception'.", value="A thief steals information by entering people's dreams.", image_path_or_url=None), | |
| ] | |
| if dummy_image_path: | |
| test_cases.append(TestCase(id="img1", key=f"Describe this image.", value="Blue rectangle with yellow text 'Test Img'", image_path_or_url=dummy_image_path)) | |
| else: | |
| logger.warning("Skipping multimodal test case in CLI as dummy image could not be created.") | |
| # --- Execution --- | |
| try: | |
| # Instantiate the tester | |
| tester = ModelTester( | |
| champion_endpoint=champion_model, | |
| challenger_endpoint=challenger_model, | |
| judge_endpoint=judge_model, | |
| model_prompt_template=model_prompt, | |
| judge_prompt_template=LMJudge.DEFAULT_EVALUATION_PROMPT # Use default judge prompt | |
| ) | |
| logger.info(f"Running CLI test with {len(test_cases)} test cases...") | |
| # Run the test | |
| results_generator = tester.run_test( # run_test returns a generator; the final yielded item holds the complete results | |
| test_cases, | |
| batch_size=2, | |
| batch_retry_attempts=1, | |
| batch_backoff_factor=2.0, | |
| batch_max_wait=30, | |
| batch_retry_trigger_strings=["rate limit", "error", "timeout"] | |
| ) | |
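| # For context on the batch settings above (an assumption about run_test's retry schedule, not | |
| # something verified here): a backoff factor of 2.0 roughly doubles the wait between retry | |
| # attempts, batch_max_wait=30 caps that wait in seconds, and a batch is retried only when a | |
| # response contains one of the trigger strings. | |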
| # Consume the generator to get the final result | |
| final_results_dict = None | |
| for res in results_generator: | |
| # Intermediate yields are treated as progress updates (uncomment the debug line to log them) | |
| # logger.debug(f"Intermediate result/progress: {res}") | |
| final_results_dict = res # Keep overwriting so the last yield wins | |
| if final_results_dict is None: | |
| logger.error("Test run generator did not yield a final result.") | |
| # Fall back to an empty dict so the reporting below degrades gracefully instead of crashing | |
| final_results_dict = {} | |
| # --- Output Results --- | |
| logger.info("Test completed. Final Results:") | |
| # Use the formatter function | |
| summary_output = format_summary_output(final_results_dict.get("summary", {})) | |
| print("\n" + summary_output) | |
| # Optionally save full results to JSON | |
| results_filename = f"cli_results_{time.strftime('%Y%m%d-%H%M%S')}.json" | |
| try: | |
| # Need to ensure results are serializable (dataclasses might need conversion) | |
| # The aggregator already converts raw evals to dicts. Summary should be fine. | |
| with open(results_filename, 'w', encoding='utf-8') as f: | |
| json.dump(final_results_dict, f, indent=2, ensure_ascii=False) | |
| logger.info(f"Full results saved to {results_filename}") | |
| print(f"\nFull results saved to: {results_filename}") | |
| except TypeError as e: | |
| logger.error(f"Failed to save results to JSON due to serialization issue: {e}") | |
| print(f"\nWarning: Could not save full results to JSON: {e}") | |
| except Exception as e: | |
| logger.error(f"Failed to save results to JSON: {e}") | |
| print(f"\nWarning: Could not save full results to JSON: {e}") | |
| except Exception as e: | |
| logger.exception("An error occurred during the CLI execution.") | |
| print(f"\nAn error occurred during CLI execution: {e}") | |
| finally: | |
| # Clean up the dummy image only if this run created it | |
| if dummy_image_path and created_dummy_image and os.path.exists(dummy_image_path): | |
| try: | |
| os.remove(dummy_image_path) | |
| logger.info(f"Removed dummy image file: {dummy_image_path}") | |
| except Exception as e: | |
| logger.warning(f"Could not remove dummy image file {dummy_image_path}: {e}") | |
| # ============================================================================== | |
| # Main Execution Logic | |
| # ============================================================================== | |
| def main(): | |
| """Main function to parse arguments and run either CLI or UI.""" | |
| # Basic argument parsing: run CLI test by default, or launch UI with --ui flag | |
| import argparse | |
| parser = argparse.ArgumentParser(description="Model A/B Testing Tool") | |
| parser.add_argument("--ui", action="store_true", help="Launch the Gradio web UI instead of running the CLI test.") | |
| args = parser.parse_args() | |
| if args.ui: | |
| logger.info("Launching Gradio UI...") | |
| iface = create_ui() | |
| if iface: | |
| # share=True requests a temporary public Gradio link; add auth=("user", "pass") to restrict access | |
| # Add server_name="0.0.0.0" to listen on all interfaces when running in Docker or on a remote host | |
| iface.launch(share=True) | |
| else: | |
| logger.error("Failed to create Gradio UI.") | |
| print("Error: Could not create the Gradio UI.") | |
| else: | |
| # Set up signal handler for CLI stop (Ctrl+C) | |
| def signal_handler(sig, frame): | |
| global STOP_REQUESTED | |
| if not STOP_REQUESTED: | |
| print("\nCtrl+C detected. Requesting stop after current batch...") | |
| logger.warning("Stop requested via Ctrl+C.") | |
| STOP_REQUESTED = True | |
| else: | |
| # Allow force exit on second Ctrl+C | |
| print("\nCtrl+C detected again. Forcing exit.") | |
| logger.error("Forced exit via second Ctrl+C.") | |
| sys.exit(1) | |
| signal.signal(signal.SIGINT, signal_handler) | |
| # Run the command-line test | |
| run_cli_test() | |
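| # Usage (assuming this script is saved as app.py; adjust the filename to match your copy): | |
| #   python app.py        # run the hardcoded CLI A/B test (Ctrl+C requests a stop after the current batch) | |
| #   python app.py --ui   # launch the Gradio web UI instead | |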
| if __name__ == "__main__": | |
| # Ensure we are running in the correct virtual environment | |
| # ensure_venv() will handle creation, installation, and re-execution if necessary. | |
| # If ensure_venv() returns True, it means we are now in the correct venv. | |
| if ensure_venv(): | |
| # Now that we are confirmed to be in the venv, execute the main logic | |
| main() | |
| # If ensure_venv() returned False (or exited), the script either failed or restarted itself. | |