# NewContest6 / app.py
# Author: ffzeroHua
# Last commit: "Update app.py" (2275b82, verified)
import os
import orjson
import concurrent.futures
import random
import torch
import threading
import time
import uuid
import glob
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download, hf_hub_download, HfApi
from riichienv import RiichiEnv, GameRule
# 分别导入两个不同架构的加载函数,防止命名冲突
from model3pLOCAL import load_model as load_model_local
from model3pNEW import load_model as load_model_new
# ==========================================
# 0. Core matchup switch (change the mode here)
# ==========================================
# True:  1 NEW-architecture agent (TEST_MODEL) vs 2 LOCAL-architecture agents (EXAMINER_MODEL)
# False: 1 LOCAL-architecture agent (TEST_MODEL) vs 2 NEW-architecture agents (EXAMINER_MODEL)
ONE_NEW_VS_TWO_LOCAL = True

# ==========================================
# 0. Multi-node deployment and Hub persistence settings
# ==========================================
DATA_REPO_ID = "ffzeroHua/mj-eval-results"  # 📊 dataset repo that stores result shards
MODEL_REPO_ID = "ffzeroHua/Riichi-Model-Repo"  # 🧠 model repo that stores the weights
HF_TOKEN = os.getenv("HF_TOKEN")

# Identifier of this node: $WORKER_ID if set, otherwise the first 6 chars of a random UUID.
WORKER_ID = os.getenv("WORKER_ID", str(uuid.uuid4())[:6])

# The shard prefix depends on the matchup direction so the two modes never share files.
BASE_REPORT_PREFIX = 'D57k_vs_9070_eval_report'
REPORT_FILE_PREFIX = (
    BASE_REPORT_PREFIX if ONE_NEW_VS_TWO_LOCAL else f"inverse_{BASE_REPORT_PREFIX}"
)
# Every node appends to its own shard: <prefix>_<WORKER_ID>.txt
REPORT_FILE = f"{REPORT_FILE_PREFIX}_{WORKER_ID}.txt"

api = HfApi()
# Loop flag polled by background_eval_loop (not cleared anywhere in the visible code).
EVAL_RUNNING = True

# 🚀 The two checkpoints pulled from MODEL_REPO_ID and pitted against each other
TEST_MODEL = "StudentSanma_Distilled_Step57000.pth"
EXAMINER_MODEL = "Elite4z9070.pth"
def sync_models_from_hub():
    """Download both competing checkpoints from MODEL_REPO_ID into the working dir.

    Each checkpoint (TEST_MODEL, then EXAMINER_MODEL) is fetched independently,
    so a failure on one of them (wrong filename, missing permission) does not
    prevent the other from being downloaded. Errors are printed, never raised;
    on failure the file already present in the working directory (if any) is
    what the games will load.

    Without a token, or while MODEL_REPO_ID still contains the "你的用户名"
    ("your username") placeholder, nothing is downloaded.
    """
    if not (HF_TOKEN and "你的用户名" not in MODEL_REPO_ID):
        print("⚠️ 未配置有效 HF_TOKEN 或未修改 MODEL_REPO_ID,将尝试使用本地已存在的模型文件。")
        return

    print(f"☁️ 正在从模型仓库 [{MODEL_REPO_ID}] 拉取评估模型...")
    all_ok = True
    for label, filename in (("测试模型", TEST_MODEL), ("考官模型", EXAMINER_MODEL)):
        try:
            hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=filename,
                repo_type="model",
                local_dir=".",
                token=HF_TOKEN,
            )
            print(f"✅ 成功拉取{label}: {filename}")
        except Exception as e:  # best-effort: fall back to a local copy
            all_ok = False
            print(f"❌ 拉取模型 {filename} 失败,请检查文件名或仓库权限: {e}")
    if all_ok:
        print("🎉 模型环境准备完毕!")
def sync_data_from_hub():
    """Pull every node's result shard (``<REPORT_FILE_PREFIX>_*.txt``) from DATA_REPO_ID.

    Files are written into the working directory. The pattern also matches
    this node's own shard (REPORT_FILE). Without a token, or while
    DATA_REPO_ID still contains the "你的用户名" ("your username") placeholder,
    this is a no-op. Download errors are printed, not raised.
    """
    repo_configured = HF_TOKEN and "你的用户名" not in DATA_REPO_ID
    if not repo_configured:
        return
    try:
        print(f"🔄 正在从 Hub 拉取全局历史战绩数据 (前缀匹配: {REPORT_FILE_PREFIX})...")
        shard_pattern = REPORT_FILE_PREFIX + "_*.txt"
        snapshot_download(
            repo_id=DATA_REPO_ID,
            repo_type="dataset",
            local_dir=".",
            allow_patterns=shard_pattern,
            token=HF_TOKEN,
        )
        print("✅ 历史数据拉取完成。")
    except Exception as e:
        print(f"⚠️ 拉取历史战绩失败: {e}")
def sync_data_to_hub():
    """Upload this node's shard (REPORT_FILE) to DATA_REPO_ID under the same path.

    No-op without a token or while DATA_REPO_ID still contains the
    "你的用户名" ("your username") placeholder; upload errors are printed,
    not raised.
    """
    repo_configured = HF_TOKEN and "你的用户名" not in DATA_REPO_ID
    if not repo_configured:
        return
    try:
        api.upload_file(
            path_or_fileobj=REPORT_FILE,
            path_in_repo=REPORT_FILE,
            repo_id=DATA_REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        timestamp = time.strftime('%H:%M:%S')
        print(f"☁️ 节点 {WORKER_ID} 战绩已同步至 Hub: {timestamp}")
    except Exception as e:
        print(f"❌ 同步失败: {e}")
# ==========================================
# 1. 高频及模型加载逻辑
# ==========================================
def patch_event_fast(event_str):
    """Rewrite one mjai event string from the env into the form the model reads.

    * Every ``"kita"`` token becomes ``"nukidora"``.
    * If the text mentions ``"start_kyoku"`` or ``"deltas"``, the event is parsed
      and padded to four seats: for ``start_kyoku`` events, ``scores`` gets
      trailing ``0``s and ``tehais`` gets hands of 13 ``"?"`` tiles; a top-level
      ``deltas`` list gets trailing ``0``s. Such events are re-serialized with
      orjson (compact form).
    * Any other event is returned unchanged apart from the rename.
    """
    if '"kita"' in event_str:
        event_str = event_str.replace('"kita"', '"nukidora"')

    # Cheap substring test first: only these events need JSON parsing.
    if not ('"start_kyoku"' in event_str or '"deltas"' in event_str):
        return event_str

    event = orjson.loads(event_str)
    if event.get('type') == 'start_kyoku':
        seat_scores = event.setdefault('scores', [])
        seat_scores.extend([0] * (4 - len(seat_scores)))
        seat_hands = event.setdefault('tehais', [])
        seat_hands.extend([["?"] * 13 for _ in range(4 - len(seat_hands))])
    if 'deltas' in event:
        seat_deltas = event['deltas']
        seat_deltas.extend([0] * (4 - len(seat_deltas)))
    return orjson.dumps(event).decode('utf-8')
def patch_resp_fast(resp_str):
    """Translate the model's ``"nukidora"`` action name back to the env's ``"kita"``.

    Falsy responses (``None`` or ``""``) are returned as-is.
    """
    return resp_str.replace('"nukidora"', '"kita"') if resp_str else resp_str
# Module-level (hence per-process) cache: (player_id, model_file, arch_type) -> loaded model.
_MODEL_CACHE = {}
def get_cached_model(player_id: int, model_file: str, arch_type: str):
    """Return the model for this seat/file/architecture, loading it on first use.

    ``arch_type == 'new'`` loads via ``load_model_new``; any other value uses
    ``load_model_local``. Loaded models are memoized in ``_MODEL_CACHE``.
    """
    cache_key = (player_id, model_file, arch_type)
    try:
        return _MODEL_CACHE[cache_key]
    except KeyError:
        pass
    # Restrict torch to one intra-op thread before loading (presumably so each
    # pool worker stays on a single core).
    torch.set_num_threads(1)
    loader = load_model_new if arch_type == 'new' else load_model_local
    model = loader(player_id, model_file)
    _MODEL_CACHE[cache_key] = model
    return model
class MortalAgent:
    """Seat-bound agent that feeds mjai event strings to a cached model.

    Args:
        player_id: seat index handed to the model loader (``play_one_game``
            uses seats 0-2).
        model_file: checkpoint filename handed to the loader.
        arch_type: ``'new'`` selects ``load_model_new``; any other value
            selects ``load_model_local`` (see ``get_cached_model``).
    """

    def __init__(self, player_id: int, model_file: str, arch_type: str):
        self.player_id = player_id
        self.arch_type = arch_type
        # Shared across games within a process via get_cached_model.
        self.model = get_cached_model(player_id, model_file, arch_type)

    def act(self, obs):
        """Feed all not-yet-seen events to the model and return its chosen action.

        Every event from ``obs.new_events()`` is patched into the model's
        format and passed to ``react``; only the response to the LAST event
        decides the action (earlier responses are discarded).

        Raises:
            RuntimeError: if the response cannot be mapped to a legal action.
        """
        resp = None
        for event in obs.new_events():
            resp = patch_resp_fast(self.model.react(patch_event_fast(event)))
        action = obs.select_action_from_mjai(resp)
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let a ``None`` action reach env.step().
        if action is None:
            raise RuntimeError(
                f"Mortal must return a legal action (model response: {resp!r})"
            )
        return action
# ==========================================
# 2. 核心对局任务
# ==========================================
def play_one_game(game_index):
    """Play one 3-player game (mode "3p-red-half", Tenhou default rules).

    TEST_MODEL sits in a uniformly random seat and EXAMINER_MODEL fills the
    other two. ONE_NEW_VS_TWO_LOCAL decides which side uses which
    architecture.

    Args:
        game_index: sequence number supplied by the scheduler (unused here).

    Returns:
        tuple: ``(rank, score)`` of the TEST_MODEL seat at the end of the game.
    """
    env = RiichiEnv(game_mode="3p-red-half", rule=GameRule.default_tenhou())
    test_seat = random.randrange(3)
    test_arch, examiner_arch = ('new', 'local') if ONE_NEW_VS_TWO_LOCAL else ('local', 'new')

    # 🚀 challenger seat gets TEST_MODEL, examiner seats get EXAMINER_MODEL
    agents = {
        seat: (
            MortalAgent(seat, TEST_MODEL, test_arch)
            if seat == test_seat
            else MortalAgent(seat, EXAMINER_MODEL, examiner_arch)
        )
        for seat in range(3)
    }

    observations = env.reset()
    while not env.done():
        seat_actions = {seat: agents[seat].act(ob) for seat, ob in observations.items()}
        observations = env.step(seat_actions)

    final_scores = env.scores()
    final_ranks = env.ranks()
    return final_ranks[test_seat], final_scores[test_seat]
# ==========================================
# 3. 后台独立评估线程
# ==========================================
def background_eval_loop():
    """Run evaluation games continuously in a process pool and record results.

    On start it pulls both checkpoints (``sync_models_from_hub``) and all
    existing result shards (``sync_data_from_hub``), then makes sure this
    node's shard ``REPORT_FILE`` exists. Every finished game appends one
    ``"<rank> <score>"`` line to ``REPORT_FILE``. After every 100 recorded
    games the shard is pushed to the Hub and all shards are pulled again.

    A game that raises is logged and replaced by a new one. If the pool
    itself breaks (``concurrent.futures.BrokenExecutor``, e.g. a worker was
    killed or crashed in native code), the pool is rebuilt after a short
    pause. Otherwise this thread would die silently while the dashboard keeps
    showing stale numbers. Once ``EVAL_RUNNING`` is false, results not yet
    uploaded are pushed one last time.
    """
    sync_models_from_hub()  # 🚀 pull both competing checkpoints from MODEL_REPO_ID
    sync_data_from_hub()  # 🚀 pull historical result shards from DATA_REPO_ID
    NUM_WORKERS = 1
    POOL_RESTART_DELAY_S = 5
    mode_str = "1只 NEW 挑战 2只 LOCAL" if ONE_NEW_VS_TWO_LOCAL else "1只 LOCAL 挑战 2只 NEW"
    print(f"🚀 节点 [{WORKER_ID}] 后台对战线程已启动: 模式为 [{mode_str}]")
    if not os.path.exists(REPORT_FILE):
        open(REPORT_FILE, 'w').close()

    games_completed = 0
    games_since_last_sync = 0
    while EVAL_RUNNING:
        try:
            with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
                # Keep two games in flight per worker so a worker never idles.
                futures = {
                    executor.submit(play_one_game, games_completed + i)
                    for i in range(NUM_WORKERS * 2)
                }
                while EVAL_RUNNING and futures:
                    done, futures = concurrent.futures.wait(
                        futures, return_when=concurrent.futures.FIRST_COMPLETED
                    )
                    with open(REPORT_FILE, "a") as f:
                        for future in done:
                            try:
                                rank, score = future.result()
                            except Exception as e:
                                print(f"对局异常: {e}")
                                continue
                            f.write(f"{rank} {score}\n")
                            # Flush so the dashboard's file reader sees the line right away.
                            f.flush()
                            games_completed += 1
                            games_since_last_sync += 1
                            print(f"[节点 {WORKER_ID}] 完成 {games_completed} 局: 顺位 {rank}, 得点 {score}")
                    # Refill one slot per finished future only after the whole
                    # batch is recorded: if the pool is broken, submit() raises
                    # and results that already arrived must not be lost.
                    if EVAL_RUNNING:
                        for _ in done:
                            futures.add(executor.submit(play_one_game, games_completed))
                    if games_since_last_sync >= 100:
                        sync_data_to_hub()
                        sync_data_from_hub()
                        games_since_last_sync = 0
        except concurrent.futures.BrokenExecutor as e:
            print(f"⚠️ 进程池异常终止, {POOL_RESTART_DELAY_S} 秒后重建: {e}")
            time.sleep(POOL_RESTART_DELAY_S)

    # Push whatever was recorded since the last periodic sync.
    if games_since_last_sync:
        sync_data_to_hub()
# ==========================================
# 4. 前端 Gradio 实时展示面板 (全局汇总)
# ==========================================
def _load_results(paths):
    """Parse ``"<rank> <score>"`` lines from the given result shards.

    Lines without exactly two whitespace-separated fields, or whose fields
    are not finite numbers (e.g. a line truncated mid-write or mid-download),
    are skipped. That way one bad line cannot take down the whole dashboard.

    Returns:
        tuple[list[int], list[float]]: ranks and scores in file/line order.
    """
    ranks, scores = [], []
    for path in paths:
        with open(path, "r") as f:
            for line in f:
                parts = line.split()
                if len(parts) != 2:
                    continue
                try:
                    rank = int(float(parts[0]))
                    score = float(parts[1])
                except (ValueError, OverflowError):  # "nan"/"inf"/garbage
                    continue
                ranks.append(rank)
                scores.append(score)
    return ranks, scores


def _build_figure(rates, scores):
    """Build the two-panel dashboard figure.

    Left panel: bar chart of the 1st/2nd/3rd-place rates in percent.
    Right panel: raw score per recorded game, with a rolling mean over up to
    10 games.

    A standalone ``matplotlib.figure.Figure`` is used instead of
    ``plt.figure``. The figure is then never registered with pyplot, so
    figures built on every timer refresh do not pile up in pyplot's registry.
    It also does not depend on pyplot's global "current figure", which is
    not thread-safe.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 4))
    ax1 = fig.add_subplot(121)
    ax1.bar(['1st', '2nd', '3rd'], rates, color=['#FFD700', '#C0C0C0', '#CD7F32'])
    ax1.set_title(f'Rank Distribution for {TEST_MODEL}')
    # Headroom above the tallest bar so its percentage label stays visible.
    ax1.set_ylim(0, max(100, max(rates) + 10))
    for i, v in enumerate(rates):
        ax1.text(i, v + 2, f"{v:.1f}%", ha='center')

    ax2 = fig.add_subplot(122)
    df = pd.DataFrame({'score': scores})
    df['ma'] = df['score'].rolling(window=min(10, max(1, len(df))), min_periods=1).mean()
    ax2.plot(df['score'], alpha=0.3, color='gray', label='Raw Score')
    ax2.plot(df['ma'], color='crimson', linewidth=2, label='Moving Avg (10)')
    ax2.set_title('Score Trend')
    ax2.legend()
    fig.tight_layout()  # lay out THIS figure, not pyplot's current one
    return fig


def read_and_analyze():
    """Aggregate all nodes' result shards into a Markdown summary and a figure.

    Returns:
        tuple: ``(markdown_text, figure_or_None)`` for the Gradio outputs.
        While no shard exists or no game has been recorded, a waiting message
        and ``None`` are returned. Any error while reading or plotting is
        reported as an error message with ``None``.
    """
    all_files = glob.glob(f"{REPORT_FILE_PREFIX}_*.txt")
    main_arch = "NEW架构" if ONE_NEW_VS_TWO_LOCAL else "LOCAL架构"
    opp_arch = "LOCAL架构" if ONE_NEW_VS_TWO_LOCAL else "NEW架构"
    if not all_files:
        return f"⏳ 正在拉取模型并等待 [{main_arch}] `{TEST_MODEL}` VS [{opp_arch}] `{EXAMINER_MODEL}` 第一局完成...", None
    try:
        ranks, scores = _load_results(all_files)
        total = len(ranks)
        if total == 0:
            return "⏳ 模型已就绪,正在进行第一局对抗...", None
        avg_rank = sum(ranks) / total
        avg_score = sum(scores) / total
        rank1_rate = ranks.count(1) / total * 100
        rank2_rate = ranks.count(2) / total * 100
        rank3_rate = ranks.count(3) / total * 100
        last_update = time.strftime('%Y-%m-%d %H:%M:%S')
        md_text = f"""
### 📊 对战简报
- ⚔️ **对抗阵容:** 1只 `{TEST_MODEL}` ({main_arch}) **VS** 2只 `{EXAMINER_MODEL}` ({opp_arch})
- 🧮 **总对局数:** {total} 局 (跨节点全局汇集)
- 🏆 **平均顺位:** {avg_rank:.3f}
- 💰 **平均得点:** {avg_score:.0f}
---
- 🥇 **一位率:** {rank1_rate:.1f}%
- 🥈 **二位率:** {rank2_rate:.1f}%
- 🥉 **三位率:** {rank3_rate:.1f}%
---
- 🌐 **当前节点 ID:** `{WORKER_ID}`
- 🕒 **刷新时间:** {last_update}
"""
        fig = _build_figure([rank1_rate, rank2_rate, rank3_rate], scores)
        return md_text, fig
    except Exception as e:
        return f"❌ 数据解析出错: {e}", None
# ==========================================
# 5. 启动 Gradio 应用
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🀄 Mahjong AI 基准评估舱")
    # Architecture labels for the challenger (TEST_MODEL) and the examiners.
    header_main, header_opp = (
        ("NEW架构", "LOCAL架构") if ONE_NEW_VS_TWO_LOCAL else ("LOCAL架构", "NEW架构")
    )
    gr.Markdown(f"当前正在评估: 1名 **{TEST_MODEL} ({header_main})** 单挑 2名 **{EXAMINER_MODEL} ({header_opp})**。启动时会自动拉取权重。")
    with gr.Row():
        with gr.Column(scale=1):
            stats_output = gr.Markdown("🚀 正在初始化基准环境并连接模型仓库...")
            refresh_btn = gr.Button("🔄 手动刷新全局战绩")
        with gr.Column(scale=2):
            plot_output = gr.Plot()
    # The dashboard is refreshed on page load, every 15 s, and on button click.
    demo.load(fn=read_and_analyze, inputs=None, outputs=[stats_output, plot_output])
    timer = gr.Timer(15)
    timer.tick(fn=read_and_analyze, inputs=None, outputs=[stats_output, plot_output])
    refresh_btn.click(fn=read_and_analyze, inputs=None, outputs=[stats_output, plot_output])
if __name__ == "__main__":
    # Games run in a daemon thread so it does not keep the process alive after the server exits.
    eval_thread = threading.Thread(target=background_eval_loop, daemon=True)
    eval_thread.start()
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft())