Spaces:
Sleeping
Sleeping
import json
import os
import sys
import types
from unittest.mock import MagicMock, patch

# Ensure the project root (parent of this file's directory) is importable.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# The stubs below are registered with sys.modules.setdefault, so a module
# that is already present in sys.modules is left untouched and the stub is
# used only as a fallback.

# Provide dummy litellm module so import succeeds
fake_litellm = types.ModuleType('litellm')
fake_litellm.completion = MagicMock()
sys.modules.setdefault('litellm', fake_litellm)

# Provide dummy dotenv module
fake_dotenv = types.ModuleType('dotenv')
fake_dotenv.load_dotenv = MagicMock()
sys.modules.setdefault('dotenv', fake_dotenv)

# Dummy gradio module so import succeeds.
# Interface(...) returns a mock whose launch() is itself a mock; the widget
# names are bound to the MagicMock class, so calling them yields a new mock.
fake_gradio = types.ModuleType('gradio')
fake_gradio.Interface = MagicMock(return_value=MagicMock(launch=MagicMock()))
fake_gradio.Textbox = MagicMock
fake_gradio.Number = MagicMock
fake_gradio.Checkbox = MagicMock
fake_gradio.Plot = MagicMock
sys.modules.setdefault('gradio', fake_gradio)
# Dummy tqdm module for write method
class FakeTqdmModule(types.ModuleType):
    """Callable module object standing in for ``tqdm``.

    The instance is named ``'tqdm'``, exposes ``write`` as a ``MagicMock``,
    and can itself be called like ``tqdm(iterable, ...)``.
    """

    def __init__(self):
        super().__init__('tqdm')
        self.write = MagicMock()

    def __call__(self, iterable=None, total=None, **kwargs):
        """Return ``iterable`` unchanged.

        ``total`` and any further keyword arguments (for example ``desc``)
        are accepted and ignored so that calls passing extra progress-bar
        options do not raise ``TypeError``.
        """
        return iterable


fake_tqdm_mod = FakeTqdmModule()
# Make ``from tqdm import tqdm`` resolve to the same callable object.
fake_tqdm_mod.tqdm = fake_tqdm_mod
# Registered only if no 'tqdm' entry already exists in sys.modules.
sys.modules.setdefault('tqdm', fake_tqdm_mod)
# Dummy matplotlib package with a pyplot submodule.
fake_matplotlib = types.ModuleType('matplotlib')
fake_plt = types.ModuleType('matplotlib.pyplot')

# figure() hands back the sentinel string 'fig'; the drawing calls are
# plain mocks.
fake_plt.figure = MagicMock(return_value='fig')
fake_plt.hist = MagicMock()
fake_plt.bar = MagicMock()
fake_plt.xticks = MagicMock()

# Expose the submodule both as an attribute of the package and under its
# dotted name; setdefault keeps any entry that is already registered.
fake_matplotlib.pyplot = fake_plt
sys.modules.setdefault('matplotlib.pyplot', fake_plt)
sys.modules.setdefault('matplotlib', fake_matplotlib)

# Imported only after the stub modules above have been registered.
import main
class DummyFuture:
    """Minimal synchronous stand-in for ``concurrent.futures.Future``.

    The wrapped callable is not run at construction time; it is invoked
    each time ``result()`` is called.
    """

    def __init__(self, func, *args, **kwargs):
        """Store the callable together with its positional/keyword arguments."""
        self._func = func
        self._args = args
        self._kwargs = kwargs

    def result(self, timeout=None):
        """Call the stored function and return its value.

        ``timeout`` is accepted for signature compatibility with
        ``Future.result`` and is ignored because the call is synchronous.
        Any exception raised by the function propagates to the caller.
        """
        return self._func(*self._args, **self._kwargs)
class DummyExecutor:
    """Synchronous stand-in for ``ThreadPoolExecutor`` used by the tests.

    Constructor arguments are accepted and ignored. Work handed to
    ``submit`` runs when the returned future's ``result()`` is called;
    ``map`` evaluates lazily as its generator is consumed.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning a falsy value lets exceptions from the ``with`` body
        # propagate instead of being swallowed.
        return False

    def submit(self, func, *args, **kwargs):
        """Return a ``DummyFuture`` that will call ``func(*args, **kwargs)``."""
        if kwargs:
            # Bind keyword arguments here so the future only ever needs to
            # forward positional arguments.
            target = func

            def func(*call_args):
                return target(*call_args, **kwargs)

        return DummyFuture(func, *args)

    def map(self, func, *iterables, timeout=None, chunksize=1):
        """Lazily yield ``func`` applied across ``iterables`` in lockstep.

        ``timeout`` and ``chunksize`` are accepted for signature
        compatibility with ``Executor.map`` and are ignored.
        """
        for items in zip(*iterables):
            yield func(*items)
class DummyTqdm:
    """No-op replacement for ``tqdm``: callable, with a silent ``write``."""

    def __call__(self, iterable=None, total=None, **kwargs):
        """Return ``iterable`` unchanged.

        ``total`` and any other keyword options (for example ``desc``) are
        accepted and ignored.
        """
        return iterable

    def write(self, msg, *args, **kwargs):
        """Discard ``msg``; extra arguments are accepted and ignored."""
        pass
def test_run_tournament_full_loop():
    """Drive ``main.run_tournament`` with both filters enabled.

    Generation, scoring and pairwise prompts are mocked, the executor and
    tqdm are replaced with synchronous dummies, and the plotting calls are
    patched. The assertions inspect the last item yielded by the run.
    """
    dummy_tqdm = DummyTqdm()
    with patch('main.generate_players') as mock_gen, \
            patch('main.prompt_score') as mock_score, \
            patch('main.prompt_pairwise') as mock_pair, \
            patch('main.ThreadPoolExecutor', return_value=DummyExecutor()), \
            patch('main.as_completed', new=lambda futs: futs), \
            patch('main.tqdm', new=dummy_tqdm), \
            patch('main.plt.figure', return_value='fig'), \
            patch('main.plt.hist'), \
            patch('main.plt.bar'):
        mock_gen.return_value = (['p1', 'p2', 'p3', 'p4'], {'prompt_tokens':1,'completion_tokens':1})
        # Each player gets a distinct fixed score embedded in the verdict text.
        scores = {'p1':3, 'p2':2, 'p3':1, 'p4':0}
        mock_score.side_effect = lambda instr, cl, block, player, **kw: (
            f"Final verdict: [{scores[player]}]",
            {'prompt_tokens':1,'completion_tokens':1}
        )
        # Every pairwise comparison returns the verdict "A".
        mock_pair.side_effect = lambda instr, block, a, b, **kw: (
            "Final verdict: A",
            {'prompt_tokens':1,'completion_tokens':1}
        )
        results = list(main.run_tournament(
            api_base='b',
            api_token='k',
            generate_model='gm',
            score_model='sm',
            pairwise_model='pm',
            generate_temperature=1,
            score_temperature=1,
            pairwise_temperature=1,
            instruction_input='instr',
            criteria_input='c1,c2',
            n_gen=4,
            pool_size=2,
            num_top_picks=1,
            max_workers=1,
            enable_score_filter=True,
            enable_pairwise_filter=True,
            score_with_instruction=True,
            pairwise_with_instruction=True,
            generate_thinking=True,
            score_thinking=True,
            pairwise_thinking=True,
        ))
        # Only the final yielded tuple is checked.
        process_log, hist_fig, elo_fig, top_picks, usage = results[-1]
        assert 'Done' in process_log
        assert hist_fig == 'fig'
        assert elo_fig == 'fig'
        assert any(p in top_picks for p in {'p1', 'p2'})
        mock_gen.assert_called_once_with('instr', 4, model='gm', api_base='b', api_key='k', temperature=1, thinking=True, return_usage=True)
        assert 'Score completion' in process_log
        assert 'Pairwise completion' in process_log
        assert 'Prompt tokens' in usage
        assert mock_score.call_count == 4
        assert mock_pair.called
def test_run_tournament_pairwise_odd_players():
    """Drive ``main.run_tournament`` with three players and pairwise only.

    The score filter is disabled, the pairwise prompt is mocked to always
    return verdict "A", and the test expects exactly three pairwise calls.
    """
    dummy_tqdm = DummyTqdm()
    with patch('main.generate_players') as mock_gen, \
            patch('main.prompt_pairwise') as mock_pair, \
            patch('main.ThreadPoolExecutor', return_value=DummyExecutor()), \
            patch('main.as_completed', new=lambda futs: futs), \
            patch('main.tqdm', new=dummy_tqdm), \
            patch('main.plt.figure', return_value='fig'), \
            patch('main.plt.hist'), \
            patch('main.plt.bar'):
        mock_gen.return_value = (['p1', 'p2', 'p3'], {'prompt_tokens':1,'completion_tokens':1})
        mock_pair.side_effect = lambda instr, block, a, b, **kw: (
            "Final verdict: A",
            {'prompt_tokens':1,'completion_tokens':1}
        )
        results = list(main.run_tournament(
            api_base='b',
            api_token='k',
            generate_model='gm',
            score_model='sm',
            pairwise_model='pm',
            generate_temperature=1,
            score_temperature=1,
            pairwise_temperature=1,
            instruction_input='instr',
            criteria_input='c1,c2',
            n_gen=3,
            pool_size=3,
            num_top_picks=1,
            max_workers=1,
            enable_score_filter=False,
            enable_pairwise_filter=True,
            score_with_instruction=True,
            pairwise_with_instruction=True,
            generate_thinking=True,
            score_thinking=True,
            pairwise_thinking=True,
        ))
        # Only the final yielded tuple is checked.
        process_log, hist_fig, elo_fig, top_picks, usage = results[-1]
        assert 'Done' in process_log
        assert any(p in top_picks for p in {'p1', 'p2', 'p3'})
        assert mock_pair.call_count == 3