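"""
Full LLM benchmark runner.

Runs a candidate model against a judge model across the BenchmarkSuite test
categories, streaming prompt/response/judge output and a score summary to the
terminal via rich.

Usage:
    python test_benchmark.py [detailed]
"""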
import time
from rich.console import Console
from rich.panel import Panel
from benchmarks.benchmark_suite import BenchmarkSuite

console = Console()

def run_full_benchmark(model_name: str, judge_model: str, num_iterations: int = 1):
    """
    Runs the full benchmark with real-time progress and Q&A output using rich.
    """
    console.print(Panel("[bold magenta]LLM Full Benchmark Test[/bold magenta]", expand=False))
    console.print(f"[bold blue]Running full benchmark: {model_name} vs {judge_model}[/bold blue]")

    benchmark_suite = BenchmarkSuite(model_name, judge_model)
    results = {}

    # Test categories with their display names
    test_categories = [
        ("Logical Reasoning", "test_logical_reasoning"),
        ("Code Generation", "test_code_generation"),
        ("Mathematical Problem Solving", "test_math_solving"),
        ("Context Understanding", "test_context_understanding"),
        ("Performance Metrics", "test_performance")
    ]

    # Store test cases data
    test_cases = []
    max_test_cases = 5  # Limit the number of test cases displayed
    
    # Track the last update time for each test case to throttle updates
    last_updates = {}
    update_interval = 0.5  # Minimum seconds between updates per test case
    
    # Callback function to output Q&A information
    def update_qa_output_callback(prompt: str, model_response: str, judge_response: str, model_name: str, judge_model_name: str):
        # Create a key for this test case
        row_key = f"{model_name}_{hash(prompt) % 1000}"
        
        # Check if we should update based on throttling
        current_time = time.time()
        if row_key in last_updates:
            time_since_last = current_time - last_updates[row_key]
            if time_since_last < update_interval:
                # Skip update if not enough time has passed
                return
        last_updates[row_key] = current_time
        
        # Check whether this test case already exists
        existing_case = None
        for case in test_cases:
            if case.get('key') == row_key:
                existing_case = case
                break
        
        if existing_case is None:
            # Add a new test case, evicting the oldest one once the display limit is reached
            if len(test_cases) >= max_test_cases:
                test_cases.pop(0)  # Remove the oldest
            test_cases.append({
                'key': row_key,
                'model_name': model_name,
                'prompt': prompt,
                'model_response': model_response,
                'judge_response': judge_response
            })
        else:
            # Update the existing test case
            existing_case['model_response'] = model_response
            existing_case['judge_response'] = judge_response
        
        # Output the Q&A information in rich text
        console.print(f"[bold blue]Model:[/bold blue] {model_name}")
        console.print(f"[bold cyan]Prompt:[/bold cyan] {prompt}")
        console.print(f"[bold green]Response:[/bold green] {model_response}")
        console.print(f"[bold yellow]Judge:[/bold yellow] {judge_response}")
        console.print("-" * 50)  # Separator line

    for category_name, method_name in test_categories:
        console.print(f"[bold green]Running {category_name} Benchmark...[/bold green]")
        
        # Show model loading/processing
        console.print(f"[magenta]  Loading models for {category_name}...[/magenta]")
        time.sleep(0.5) # Simulate loading time

        try:
            # Run the actual test, passing the callback
            test_func = getattr(benchmark_suite, method_name)
            result = test_func(num_iterations, update_qa_output_callback)  # Pass the Q&A output callback
            results[category_name] = result
            
            console.print(f"[green]✓ {category_name} completed: {result.get('score', 0):.1f}/10[/green]")
        except Exception as e:
            console.print(f"[red]✗ {category_name} failed: {str(e)}[/red]")
            results[category_name] = {"score": 0, "error": str(e)}

    # Calculate final scores
    overall_score = sum(r.get("score", 0) for r in results.values()) / len(results) if results else 0

    # Print summary
    console.print(Panel("[bold magenta]Benchmark Results Summary[/bold magenta]", expand=False))
    for test_name, result in results.items():
        score = result.get('score', 0)
        if 'error' in result:
            console.print(f"  {test_name}: [red]Error - {result['error']}[/red]")
        else:
            console.print(f"  {test_name}: {score:.1f}/10")
    console.print(f"[bold blue]Overall Score: {overall_score:.1f}/10[/bold blue]")

    return results

if __name__ == "__main__":
    import sys
    
    model_to_test = "qwen3:8b"
    judge_model = "deepscaler:latest"
    iterations = 1

    if len(sys.argv) > 1:
        if sys.argv[1] == "detailed":
            run_full_benchmark(model_to_test, judge_model, iterations)
        else:
            console.print("[red]Invalid argument. Use 'python test_benchmark.py detailed'[/red]")
    else:
        run_full_benchmark(model_to_test, judge_model, iterations)