File size: 13,072 Bytes
8f5972f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51c630c
8f5972f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
import os
import orjson
import concurrent.futures
import random
import torch
import threading
import time
import uuid
import glob
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download, hf_hub_download, HfApi

from riichienv import RiichiEnv, GameRule

# 分别导入两个不同架构的加载函数,防止命名冲突
from model3pLOCAL import load_model as load_model_local
from model3pNEW import load_model as load_model_new

# ==========================================
# 0a. Core match-up switch (flip the mode here)
# ==========================================
# True:  one NEW-architecture model (TEST_MODEL) vs two LOCAL-architecture models (EXAMINER_MODEL)
# False: one LOCAL-architecture model (TEST_MODEL) vs two NEW-architecture models (EXAMINER_MODEL)
ONE_NEW_VS_TWO_LOCAL = True

# ==========================================
# 0b. Multi-node deployment and cloud persistence settings
# ==========================================
# Dataset repo that stores the per-node result files.
DATA_REPO_ID = "ffzeroHua/mj-eval-results"
# Model repo that stores the weight files.
MODEL_REPO_ID = "ffzeroHua/Riichi-Model-Repo"
HF_TOKEN = os.getenv("HF_TOKEN")

# Unique ID for this node: the WORKER_ID environment variable when set,
# otherwise the first six characters of a freshly generated UUID.
WORKER_ID = os.getenv("WORKER_ID", str(uuid.uuid4())[:6])

# The report-file prefix follows the switch above so that the two match-up
# modes never write into the same files.
BASE_REPORT_PREFIX = 'D57k_vs_9070_eval_report'
REPORT_FILE_PREFIX = (
    BASE_REPORT_PREFIX
    if ONE_NEW_VS_TWO_LOCAL
    else f"inverse_{BASE_REPORT_PREFIX}"
)

# Result file written by this node only.
REPORT_FILE = f"{REPORT_FILE_PREFIX}_{WORKER_ID}.txt"

api = HfApi()
# Read by the background evaluation loop; it stops scheduling games once this is False.
EVAL_RUNNING = True

# The two weight files pulled from the model repo and played against each other.
TEST_MODEL = "StudentSanma_Distilled_Step57000.pth"
EXAMINER_MODEL = "Elite4z9070.pth"

def sync_models_from_hub():
    """Download the weight files of both competing models from the model repo.

    Nothing is downloaded when HF_TOKEN is unset or MODEL_REPO_ID still
    contains the placeholder user name; a warning is printed instead and any
    weight files already on disk are left to be used as-is. Download errors
    are reported on stdout and never raised.
    """
    hub_ready = HF_TOKEN and "你的用户名" not in MODEL_REPO_ID
    if not hub_ready:
        print("⚠️ 未配置有效 HF_TOKEN 或未修改 MODEL_REPO_ID,将尝试使用本地已存在的模型文件。")
        return

    def _fetch(weight_file):
        # Both models live in the same repo and land in the working directory.
        hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=weight_file,
            repo_type="model",
            local_dir=".",
            token=HF_TOKEN,
        )

    print(f"☁️ 正在从模型仓库 [{MODEL_REPO_ID}] 拉取评估模型...")
    try:
        _fetch(TEST_MODEL)
        print(f"✅ 成功拉取测试模型: {TEST_MODEL}")

        _fetch(EXAMINER_MODEL)
        print(f"✅ 成功拉取考官模型: {EXAMINER_MODEL}")

        print("🎉 模型环境准备完毕!")
    except Exception as err:
        print(f"❌ 拉取模型失败,请检查文件名或仓库权限: {err}")

def sync_data_from_hub():
    """Download every node's result file for the current mode from the dataset repo.

    Only files matching ``<REPORT_FILE_PREFIX>_*.txt`` are fetched, into the
    working directory. The call is a silent no-op when HF_TOKEN is unset or
    DATA_REPO_ID still contains the placeholder user name. Download errors are
    reported on stdout and never raised.
    """
    hub_ready = HF_TOKEN and "你的用户名" not in DATA_REPO_ID
    if not hub_ready:
        return

    shard_pattern = f"{REPORT_FILE_PREFIX}_*.txt"
    try:
        print(f"🔄 正在从 Hub 拉取全局历史战绩数据 (前缀匹配: {REPORT_FILE_PREFIX})...")
        snapshot_download(
            repo_id=DATA_REPO_ID,
            repo_type="dataset",
            local_dir=".",
            allow_patterns=shard_pattern,
            token=HF_TOKEN,
        )
        print("✅ 历史数据拉取完成。")
    except Exception as err:
        print(f"⚠️ 拉取历史战绩失败: {err}")

def sync_data_to_hub():
    """Upload this node's result file to the dataset repo under the same name.

    The call is a silent no-op when HF_TOKEN is unset or DATA_REPO_ID still
    contains the placeholder user name. Upload errors are reported on stdout
    and never raised.
    """
    hub_ready = HF_TOKEN and "你的用户名" not in DATA_REPO_ID
    if not hub_ready:
        return

    upload_args = {
        "path_or_fileobj": REPORT_FILE,
        "path_in_repo": REPORT_FILE,
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "token": HF_TOKEN,
    }
    try:
        api.upload_file(**upload_args)
        print(f"☁️ 节点 {WORKER_ID} 战绩已同步至 Hub: {time.strftime('%H:%M:%S')}")
    except Exception as err:
        print(f"❌ 同步失败: {err}")

# ==========================================
# 1. 高频及模型加载逻辑
# ==========================================
def patch_event_fast(event_str):
    """Rewrite one JSON event string from the environment before it reaches the model.

    Two adjustments are made:

    * every ``"kita"`` token is renamed to ``"nukidora"``;
    * for events whose text mentions ``"start_kyoku"`` or ``"deltas"``, the
      ``scores`` / ``tehais`` lists of a ``start_kyoku`` event and any
      ``deltas`` list are padded to four entries (presumably because the
      model expects four-seat shaped events; confirm against the model code).

    Strings that need no padding are returned without being parsed, which
    keeps the common path to plain substring checks.
    """
    text = event_str.replace('"kita"', '"nukidora"')

    if '"start_kyoku"' not in text and '"deltas"' not in text:
        return text

    def _pad_to_four(seq, make_filler):
        # Append filler entries until the list has four items; longer lists are left alone.
        shortfall = 4 - len(seq)
        if shortfall > 0:
            seq.extend(make_filler() for _ in range(shortfall))

    payload = orjson.loads(text)
    if payload.get('type') == 'start_kyoku':
        _pad_to_four(payload.setdefault('scores', []), lambda: 0)
        # Each missing hand becomes its own list of 13 unknown tiles.
        _pad_to_four(payload.setdefault('tehais', []), lambda: ["?"] * 13)
    if 'deltas' in payload:
        _pad_to_four(payload['deltas'], lambda: 0)
    return orjson.dumps(payload).decode('utf-8')

def patch_resp_fast(resp_str):
    """Rename ``"nukidora"`` back to ``"kita"`` in a model response string.

    Falsy responses (``None`` or an empty string) are returned unchanged.
    """
    if resp_str:
        return resp_str.replace('"nukidora"', '"kita"')
    return resp_str

# Loaded models, keyed by (player_id, model_file, arch_type).
_MODEL_CACHE = {}


def get_cached_model(player_id: int, model_file: str, arch_type: str):
    """Return the model for this seat/file/architecture, loading it on first use.

    ``arch_type == 'new'`` selects ``load_model_new``; any other value selects
    ``load_model_local``. The torch thread count is set to 1 right before a
    model is actually loaded; cache hits do not touch it.
    """
    cache_key = (player_id, model_file, arch_type)
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]

    torch.set_num_threads(1)
    loader = load_model_new if arch_type == 'new' else load_model_local
    _MODEL_CACHE[cache_key] = loader(player_id, model_file)
    return _MODEL_CACHE[cache_key]

class MortalAgent:
    """Drives one seat of a game with a model obtained from ``get_cached_model``."""

    def __init__(self, player_id: int, model_file: str, arch_type: str):
        """Bind the agent to a seat and fetch (or load) its model.

        Args:
            player_id: Seat index this agent plays.
            model_file: Weight file passed to the model loader.
            arch_type: ``'new'`` or ``'local'``; forwarded to ``get_cached_model``.
        """
        self.player_id = player_id
        self.arch_type = arch_type
        self.model = get_cached_model(player_id, model_file, arch_type)

    def act(self, obs):
        """Feed all new events to the model and return the action it chose.

        Every event from ``obs.new_events()`` is patched and passed to
        ``self.model.react``; only the response to the last event is turned
        into an action via ``obs.select_action_from_mjai``.

        Raises:
            AssertionError: If the observation maps the response to no action.
        """
        resp = None
        for event in obs.new_events():
            resp = patch_resp_fast(self.model.react(patch_event_fast(event)))

        action = obs.select_action_from_mjai(resp)
        # Raised explicitly rather than via `assert`, which is stripped under
        # `python -O` and would let a None action flow into the environment.
        if action is None:
            raise AssertionError("Mortal must return a legal action")
        return action

# ==========================================
# 2. 核心对局任务
# ==========================================
def play_one_game(game_index):
    """Play one complete three-player game and report the challenger's result.

    The challenger (TEST_MODEL) is placed on a random seat; the other two
    seats are played by EXAMINER_MODEL. Which architecture each side uses is
    decided by ONE_NEW_VS_TWO_LOCAL.

    Args:
        game_index: Index supplied by the scheduler; not used by the game itself.

    Returns:
        A ``(rank, score)`` tuple for the challenger's seat, taken from
        ``env.ranks()`` and ``env.scores()``.
    """
    env = RiichiEnv(game_mode="3p-red-half", rule=GameRule.default_tenhou())
    new_seat = random.randrange(3)

    if ONE_NEW_VS_TWO_LOCAL:
        challenger_arch, examiner_arch = 'new', 'local'
    else:
        challenger_arch, examiner_arch = 'local', 'new'

    # Seats are filled in order 0, 1, 2: the challenger on `new_seat`, examiners elsewhere.
    agents = {
        seat: (
            MortalAgent(seat, TEST_MODEL, challenger_arch)
            if seat == new_seat
            else MortalAgent(seat, EXAMINER_MODEL, examiner_arch)
        )
        for seat in range(3)
    }

    observations = env.reset()
    while not env.done():
        # Only the seats present in the current observation dict get to act.
        chosen = {}
        for pid, obs in observations.items():
            chosen[pid] = agents[pid].act(obs)
        observations = env.step(chosen)

    final_scores = env.scores()
    final_ranks = env.ranks()
    return final_ranks[new_seat], final_scores[new_seat]

# ==========================================
# 3. 后台独立评估线程
# ==========================================
def background_eval_loop() -> None:
    """Run games continuously in a process pool and append results to REPORT_FILE.

    On start it pulls the model weights and the existing result files from
    the Hub, then keeps ``NUM_WORKERS * 2`` games queued. Each finished game
    appends one ``"<rank> <score>"`` line to REPORT_FILE. After every 100
    recorded games the local file is uploaded and the other nodes' files are
    re-downloaded. The loop ends when EVAL_RUNNING becomes false or no
    futures remain.

    NOTE(review): results recorded since the last 100-game sync are not
    uploaded when the loop exits; the upload only happens inside the loop.
    """
    sync_models_from_hub() # pull the competing models from the model repo at startup
    sync_data_from_hub()   # pull the historical results from the results repo at startup
    
    NUM_WORKERS = 1 
    
    mode_str = "1只 NEW 挑战 2只 LOCAL" if ONE_NEW_VS_TWO_LOCAL else "1只 LOCAL 挑战 2只 NEW"
    print(f"🚀 节点 [{WORKER_ID}] 后台对战线程已启动: 模式为 [{mode_str}]")
    
    # Create an empty report file so it exists before the first game finishes.
    if not os.path.exists(REPORT_FILE):
        open(REPORT_FILE, 'w').close()

    games_since_last_sync = 0

    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # Seed the pool with twice as many games as workers.
        futures = {executor.submit(play_one_game, i) for i in range(NUM_WORKERS * 2)}
        games_completed = 0
        
        while EVAL_RUNNING and futures:
            # `wait` returns (done, not_done); the pending set replaces `futures`.
            done, futures = concurrent.futures.wait(
                futures, return_when=concurrent.futures.FIRST_COMPLETED
            )
            
            with open(REPORT_FILE, "a") as f:
                for future in done:
                    try:
                        rank, score = future.result()
                        f.write(f"{rank} {score}\n")
                        # Flush per game so readers of the file see the result right away.
                        f.flush()
                        games_completed += 1
                        games_since_last_sync += 1
                        print(f"[节点 {WORKER_ID}] 完成 {games_completed} 局: 顺位 {rank}, 得点 {score}")
                    except Exception as e:
                        # A failed game is logged and not counted.
                        print(f"对局异常: {e}")
                    
                    # One replacement game is submitted per finished future, failed or not.
                    if EVAL_RUNNING:
                        futures.add(executor.submit(play_one_game, games_completed))
            
            # Every 100 recorded games: upload this node's file, then refresh the others.
            if games_since_last_sync >= 100:
                sync_data_to_hub()
                sync_data_from_hub()
                games_since_last_sync = 0

# ==========================================
# 4. 前端 Gradio 实时展示面板 (全局汇总)
# ==========================================
def _load_results(paths):
    """Parse ``"<rank> <score>"`` lines from the given report files.

    Lines that do not have exactly two fields, or whose fields are not
    numeric, are skipped so that one bad line cannot hide all other results.

    Returns:
        A ``(ranks, scores)`` pair of equally long lists.
    """
    ranks, scores = [], []
    for path in paths:
        with open(path, "r") as f:
            for line in f:
                parts = line.split()
                if len(parts) != 2:
                    continue
                try:
                    rank = int(float(parts[0]))
                    score = float(parts[1])
                except (ValueError, OverflowError):
                    continue
                ranks.append(rank)
                scores.append(score)
    return ranks, scores


def _render_figure(rates, scores):
    """Draw the rank-distribution bar chart and the score-trend line chart.

    The figure is always removed from pyplot's registry before returning, so
    repeated dashboard refreshes do not accumulate open figures. The returned
    ``Figure`` object itself remains usable for rendering.
    """
    fig = plt.figure(figsize=(10, 4))
    try:
        ax1 = fig.add_subplot(121)
        ax1.bar(['1st', '2nd', '3rd'], rates, color=['#FFD700', '#C0C0C0', '#CD7F32'])
        ax1.set_title(f'Rank Distribution for {TEST_MODEL}')
        ax1.set_ylim(0, max(100, max(rates + [0]) + 10))
        for i, v in enumerate(rates):
            ax1.text(i, v + 2, f"{v:.1f}%", ha='center')

        ax2 = fig.add_subplot(122)
        df = pd.DataFrame({'score': scores})
        # Window of up to 10 games; min_periods=1 gives a value from the first game on.
        df['ma'] = df['score'].rolling(window=min(10, max(1, len(df))), min_periods=1).mean()
        ax2.plot(df['score'], alpha=0.3, color='gray', label='Raw Score')
        ax2.plot(df['ma'], color='crimson', linewidth=2, label='Moving Avg (10)')
        ax2.set_title('Score Trend')
        ax2.legend()

        # Lay out this figure explicitly instead of pyplot's "current" figure.
        fig.tight_layout()
    finally:
        plt.close(fig)
    return fig


def read_and_analyze():
    """Aggregate every node's report file into a Markdown summary and a figure.

    Reads all files matching ``<REPORT_FILE_PREFIX>_*.txt`` in the working
    directory, in sorted name order so the score trend is stable between
    refreshes.

    Returns:
        ``(markdown_text, figure)``. ``figure`` is ``None`` while no result
        exists yet or when an error occurred; in those cases the text carries
        the status or error message.
    """
    all_files = sorted(glob.glob(f"{REPORT_FILE_PREFIX}_*.txt"))

    main_arch = "NEW架构" if ONE_NEW_VS_TWO_LOCAL else "LOCAL架构"
    opp_arch = "LOCAL架构" if ONE_NEW_VS_TWO_LOCAL else "NEW架构"

    if not all_files:
        return f"⏳ 正在拉取模型并等待 [{main_arch}] `{TEST_MODEL}` VS [{opp_arch}] `{EXAMINER_MODEL}` 第一局完成...", None

    try:
        ranks, scores = _load_results(all_files)
        total = len(ranks)
        if total == 0:
            return "⏳ 模型已就绪,正在进行第一局对抗...", None

        avg_rank = sum(ranks) / total
        avg_score = sum(scores) / total
        rank1_rate = ranks.count(1) / total * 100
        rank2_rate = ranks.count(2) / total * 100
        rank3_rate = ranks.count(3) / total * 100

        last_update = time.strftime('%Y-%m-%d %H:%M:%S')

        md_text = f"""
        ### 📊 对战简报
        - ⚔️ **对抗阵容:** 1只 `{TEST_MODEL}` ({main_arch}) **VS** 2只 `{EXAMINER_MODEL}` ({opp_arch})
        - 🧮 **总对局数:** {total} 局 (跨节点全局汇集)
        - 🏆 **平均顺位:** {avg_rank:.3f}
        - 💰 **平均得点:** {avg_score:.0f}
        ---
        - 🥇 **一位率:** {rank1_rate:.1f}%
        - 🥈 **二位率:** {rank2_rate:.1f}%
        - 🥉 **三位率:** {rank3_rate:.1f}%
        ---
        - 🌐 **当前节点 ID:** `{WORKER_ID}`
        - 🕒 **刷新时间:** {last_update}
        """

        fig = _render_figure([rank1_rate, rank2_rate, rank3_rate], scores)
        return md_text, fig

    except Exception as e:
        # Dashboard boundary: show the problem in the UI instead of failing the refresh.
        return f"❌ 数据解析出错: {e}", None

# ==========================================
# 5. 启动 Gradio 应用
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🀄 Mahjong AI 基准评估舱")

    # Architecture labels for the challenger and the examiners, per the mode switch.
    if ONE_NEW_VS_TWO_LOCAL:
        header_main, header_opp = "NEW架构", "LOCAL架构"
    else:
        header_main, header_opp = "LOCAL架构", "NEW架构"

    gr.Markdown(
        f"当前正在评估: 1名 **{TEST_MODEL} ({header_main})** 单挑 2名 "
        f"**{EXAMINER_MODEL} ({header_opp})**。启动时会自动拉取权重。"
    )

    # Left column: text summary and manual refresh; right column: the figure.
    with gr.Row():
        with gr.Column(scale=1):
            stats_output = gr.Markdown("🚀 正在初始化基准环境并连接模型仓库...")
            refresh_btn = gr.Button("🔄 手动刷新全局战绩")
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # Refresh once on page load ...
    demo.load(
        fn=read_and_analyze,
        inputs=None,
        outputs=[stats_output, plot_output],
    )

    # ... then on every tick of a 15-second timer ...
    timer = gr.Timer(15)
    timer.tick(
        fn=read_and_analyze,
        inputs=None,
        outputs=[stats_output, plot_output],
    )

    # ... and whenever the refresh button is clicked.
    refresh_btn.click(
        fn=read_and_analyze,
        inputs=None,
        outputs=[stats_output, plot_output],
    )

if __name__ == "__main__":
    # The evaluation loop runs on a daemon thread, so it does not keep the
    # process alive once the Gradio server stops.
    t = threading.Thread(target=background_eval_loop)
    t.daemon = True
    t.start()

    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(),
    )