# DistillEncoder / app.py
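"""Tenhou log miner for teacher/student distillation.

Replays downloaded Tenhou game logs through two feature engines side by
side (a "teacher" built on libriichi3p and a "student" built on
libriichiSanma, falling back to libriichi) and records, for every decision
point, both engines' observations plus the action the human actually took.
Records are pickled in chunks and uploaded to a Hugging Face dataset; a
FastAPI endpoint reports live mining status.
"""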
import os
import re
import json
import pickle
import threading
import time
import requests
import numpy as np
from datetime import datetime
# Web server and HF Hub dependencies
from fastapi import FastAPI
import uvicorn
from huggingface_hub import HfApi, hf_hub_download
# Low-level feature engine (Teacher)
from libriichi3p.mjai import Bot as RiichiBot
from libriichi3p.consts import ACTION_SPACE
# Low-level feature engine (Student)
try:
    from libriichiSanma import state as sanma_state
except ImportError:
    import libriichi as sanma_state
# ==========================================
# [Configuration and environment variables]
# ==========================================
HF_TOKEN = os.environ.get("HF_TOKEN", "")
DATASET_REPO = os.environ.get("DATASET_REPO", "AstraNASA/tenhou-scc")
URL_LIST_FILE = os.environ.get("URL_LIST_FILE", "urls_better.txt")
MASK_3P = [
    "1m", "2m", "3m", "4m", "5m", "6m", "7m", "8m", "9m",
    "1p", "2p", "3p", "4p", "5p", "6p", "7p", "8p", "9p",
    "1s", "2s", "3s", "4s", "5s", "6s", "7s", "8s", "9s",
    "E", "S", "W", "N", "P", "F", "C",
    "5mr", "5pr", "5sr",
    "reach", "pon", "kan", "nukidora", "hora", "ryukyoku", "none",
]
NONE_CODE = MASK_3P.index('none')
KAN_CODE = MASK_3P.index('kan')
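# 44-way label space: indices 0-36 are discard targets (34 tile kinds plus
# the three red fives), indices 37-43 the call / terminal actions, so
# KAN_CODE == 39 and NONE_CODE == 43 follow directly from the list above.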
_thread_local = threading.local()  # per-thread mailbox for intercepted teacher features
worker_status = {
    "status": "Starting up...",
    "urls_processed": 0,
    "total_chunks_uploaded": 0,
    "total_records_extracted": 0,
    "current_target": "",
    "errors": 0,
}
# ==========================================
# [Parser] unchanged
# ==========================================
class TenhouParser:
    @staticmethod
    def tile_name(x):
        if x in (51, 52, 53): return ['5mr', '5pr', '5sr'][x - 51]
        num, suit = x % 10, x // 10
        if suit in (1, 2, 3): return str(num) + 'mps'[suit - 1]
        if suit == 4: return 'ESWNPFC'[num - 1]
        return '?'
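    # Tenhou's numeric tile codes, as decoded above: 11-19 -> 1m-9m,
    # 21-29 -> 1p-9p, 31-39 -> 1s-9s, 41-47 -> honors E S W N P F C,
    # 51/52/53 -> red fives. e.g. tile_name(25) == '5p', tile_name(44) == 'N'.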
    @classmethod
    def get_meld_tiles(cls, actor, s):
        i, player = 0, 0
        result = {'pai': [], 'consumed': [], 'actor': actor}
        while i < len(s):
            player += 1
            tile_type = 'consumed'
            if s[i] in 'cpmakf':
                tile_type = 'pai'
                result['type'] = ['chi', 'pon', 'daiminkan', 'ankan', 'kakan', 'nukidora']['cpmakf'.index(s[i])]
                if s[i] in 'cpm':
                    result['target'] = (4 - player + actor) % 4
                i += 1
            result[tile_type].append(cls.tile_name(int(s[i:i+2])))
            i += 2
        result['pai'] = result['pai'][0]
        if result.get('type') == 'ankan': result['consumed'].append(result['pai'])
        return result
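    # Worked example (a pure trace of the parser above):
    #   get_meld_tiles(1, 'p252525') ->
    #   {'pai': '5p', 'consumed': ['5p', '5p'], 'actor': 1, 'type': 'pon', 'target': 0}
    # The position of the type letter inside the string encodes which seat
    # the tile was claimed from (letter first == the player to the left).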
    @classmethod
    def parse_events(cls, actor, income, outcome):
        incoming, outcoming = [], []
        for i, event in enumerate(income):
            if type(event) is str: incoming.append(cls.get_meld_tiles(actor, event))
            else: incoming.append({'type': 'tsumo', 'pai': cls.tile_name(event), 'actor': actor})
        for i, event in enumerate(outcome):
            if type(event) is str and event[0] != 'r':
                outcoming.append(cls.get_meld_tiles(actor, event))
            else:
                if event == 0:
                    outcoming.append({'type': 'empty'})
                    continue
                reach = False
                if type(event) is str and event[0] == 'r':
                    reach, event = True, int(event[1:])
                    outcoming.append({'type': 'reach', 'actor': actor})
                outcoming.append({'type': 'dahai', 'pai': cls.tile_name(event if event != 60 else income[i]), 'actor': actor, 'tsumogiri': event == 60})
                if reach: outcoming.append({'type': 'reach_accepted', 'actor': actor})
        return incoming, outcoming
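    # In the raw per-seat streams consumed above: integers are tile codes,
    # 60 means tsumogiri (discarding the tile just drawn), a leading 'r'
    # marks a riichi discard, 0 is a placeholder for a skipped turn, and
    # other strings are meld notations.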
    @classmethod
    def merge_events(cls, oya, events, dora_markers):
        current, result = oya, []
        def finished(x): return all(len(i[0]) == 0 and len(i[1]) == 0 for i in x)
        while not finished(events):
            income, outcome = events[current]
            nuki = False
            if len(income):
                result.append(income.pop(0))
                if result[-1]['type'] == 'daiminkan':
                    result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    outcome.pop(0)
                    continue
            if len(outcome):
                result.append(outcome.pop(0))
                pai, t = result[-1].get('pai'), result[-1]['type']
                if t == 'reach':
                    result.append(outcome.pop(0))
                    pai = result[-1].get('pai')
                    result.append(outcome.pop(0))
                nuki = False
                for actor, x in enumerate(events):
                    if actor == current or len(x[1]) == 0: continue
                    if x[0][0]['type'] != 'tsumo' and x[0][0].get('pai') == pai and not (x[0][0]['type'] == 'chi' and not (x[0][0]['actor'] + 3) % 4 == actor):
                        nuki, current = True, actor
                        break
                if t in ('ankan', 'kakan', 'nukidora'):
                    if t != 'nukidora' and len(dora_markers) > 0: result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    nuki = True
            if not nuki: current = (current + 1) % 4
        return result
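    # merge_events interleaves the four per-seat streams into one global
    # timeline: after each discard it peeks at the other seats' next income
    # event to detect a matching call, jumping the turn to the claimer;
    # ankan/kakan/nukidora keep the turn, and kans reveal a new dora marker.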
    @classmethod
    def parse_single_round(cls, data):
        round_info, scores, dora_markers, uradora, result_info = data[0], data[1], data[2], data[3], data[-1]
        oya = round_info[0] % 4
        patch = lambda arr: arr if len(arr) >= 13 else [0] * 13
        events = [{
            'type': 'start_kyoku', 'bakaze': 'ESWN'[round_info[0] // 4], 'kyoku': oya + 1,
            'honba': round_info[1], 'kyotaku': round_info[2], 'oya': oya,
            'dora_marker': cls.tile_name(dora_markers.pop(0)), 'scores': scores,
            'tehais': [[cls.tile_name(i) for i in patch(data[k])] for k in [4, 7, 10, 13]]
        }]
        e_list = [cls.parse_events(i, data[5 + i * 3], data[6 + i * 3]) for i in range(4)]
        events += cls.merge_events(oya, e_list, dora_markers)
        last_type = events[-1]['type']
        if last_type == 'tsumo' and result_info[0] == '和了': events.append({'type': 'hora', 'actor': events[-1]['actor'], 'target': events[-1]['actor']})
        elif result_info[0] == '和了':
            actor = next(i for i, x in enumerate(result_info[1]) if x > 0)
            events.append({'type': 'hora', 'actor': actor, 'target': actor})
        elif last_type == 'tsumo' or '九牌' in result_info[0]: events.append({'type': 'ryukyoku', 'actor': events[-1]['actor']})
        return events
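    # The last element of a round record is the result block: '和了' (a win)
    # produces a hora event for the winner; non-win results that end on a
    # self-draw, plus nine-terminal ('九牌') aborts, produce a ryukyoku event.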
    @classmethod
    def parse_log(cls, log):
        weights = [1.0, 1.0, 1.0]
        # '私' ("me") marks the viewer's own seat in personally downloaded logs.
        seat = log['name'].index('私') if '私' in log['name'] else -1
        parsed_rounds = []
        for i in log['log'][:]:
            round_events = [{"type": "start_game", "id": seat, "weight": weights}] + cls.parse_single_round(i)
            parsed_rounds.append(round_events)
        return parsed_rounds
# ==========================================
# [Feature-intercepting dummy engine (Teacher)]
# ==========================================
class DummyFeatureEngine:
    def __init__(self):
        self.engine_type = 'mortal'
        self.name = 'DataMiner'
        self.version = 4
        self.is_oracle = False
        self.enable_quick_eval = True
        self.enable_rule_based_agari_guard = True
    def react_batch(self, obs, masks, invisible_obs):
        _thread_local.interception = (obs, masks, invisible_obs)
        batch_size = len(obs)
        actions, q_outs, pure_masks = [], [], []
        for m in masks:
            m_list = m.tolist() if hasattr(m, 'tolist') else list(m)
            pure_masks.append(m_list)
            try: valid_action = m_list.index(True)
            except ValueError: valid_action = 0
            actions.append(valid_action)
            q_outs.append([0.0] * len(m_list))
        return actions, q_outs, pure_masks, [True] * batch_size
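# The dummy engine never evaluates anything. RiichiBot computes the
# observation/mask tensors and hands them to react_batch(), which stashes
# them in thread-local storage and answers with the first legal action so
# the replay keeps advancing; process_game() below reads the stash. The
# attributes set in __init__ mirror what the Mortal-style engine interface
# is presumably expected to expose.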
# ==========================================
# [Dual feature packing architecture (Distillation)]
# ==========================================
class FeatureEncoder:
    def __init__(self, chunk_size=2048, pool_size=8):
        self.chunk_size = chunk_size
        self.pool_size = pool_size
        self.inputs, self.outputs, self.weights = [], [], []
        self.chunk_count = 0
        self.hf_api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
        self.local_pool_dir = "local_chunks_pool"
        os.makedirs(self.local_pool_dir, exist_ok=True)
    @staticmethod
    def action_to_mask(who, action):
        if action is None: return NONE_CODE
        if type(action) is str: action = json.loads(action)
        if action.get('actor') != who or action.get('type') == 'tsumo': return NONE_CODE
        if action['type'] == 'dahai': return MASK_3P.index(action['pai'])
        if action['type'] in ('daiminkan', 'ankan', 'kakan'): return KAN_CODE
        if action['type'] in MASK_3P: return MASK_3P.index(action['type'])
        raise Exception(f"Unknown action map: {action}")
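    # Label examples (straight lookups into MASK_3P): discarding '5p' maps
    # to index 13, every kan variant collapses to KAN_CODE (39), and any
    # event not taken by `who` becomes NONE_CODE (43).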
    def save_and_check_upload(self):
        if not self.inputs: return
        filename = f"chunk_distill_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{self.chunk_count}.pkl"
        filepath = os.path.join(self.local_pool_dir, filename)
        with open(filepath, 'wb') as f:
            pickle.dump({'inputs': self.inputs, 'outputs': self.outputs, 'weights': self.weights}, f)
        print(f"📦 Distillation cache saved: {filename} ({len(self.inputs)} records).")
        self.chunk_count += 1
        self.inputs.clear()
        self.outputs.clear()
        self.weights.clear()
        current_files = os.listdir(self.local_pool_dir)
        if len(current_files) >= self.pool_size:
            self.upload_pool()
    def upload_pool(self):
        current_files = os.listdir(self.local_pool_dir)
        if not current_files or not self.hf_api or not DATASET_REPO: return
        print(f"🚀 Local pool is full; batch-uploading {len(current_files)} files...")
        for attempt in range(6):
            try:
                self.hf_api.upload_folder(
                    folder_path=self.local_pool_dir,
                    path_in_repo="distill_better",
                    repo_id=DATASET_REPO,
                    repo_type="dataset"
                )
                print(f"✅ Upload succeeded (attempt {attempt + 1}).")
                worker_status["total_chunks_uploaded"] += len(current_files)
                for f in current_files: os.remove(os.path.join(self.local_pool_dir, f))
                break
            except Exception as e:
                # Exponential backoff: 5s, 10s, 20s, 40s, 80s, 160s
                wait_time = 5 * (2 ** attempt)
                print(f"⚠️ Upload failed: {e}. Waiting {wait_time}s...")
                time.sleep(wait_time)
    def process_game(self, events):
        who = -1
        current_weight = 1.0
        ps_student = None
        bot_teacher = None
        for i, event in enumerate(events):
            if event.get('type') == 'start_game':
                who = event['id']
                weights_list = event.get('weight', [1.0, 1.0, 1.0])
                current_weight = weights_list[who]
                # Initialize both state machines (teacher and student)
                ps_student = sanma_state.PlayerState(who)
                bot_teacher = RiichiBot(DummyFeatureEngine(), who)
            if ps_student is None or bot_teacher is None:
                continue
            if event.get('type') == 'end_game':
                continue
            # The label is the next "real" event; dora flips and riichi
            # acceptances are not decisions, so they are skipped.
            next_event = None
            for j in range(i + 1, len(events)):
                if events[j].get('type') not in ('dora', 'reach_accepted'):
                    next_event = events[j]; break
            event_str = json.dumps(event, separators=(",", ":"))
            # --- 1. Teacher update and interception ---
            _thread_local.interception = None
            bot_teacher.react(event_str)
            intercepted = getattr(_thread_local, 'interception', None)
            # --- 2. Student update and feature generation ---
            cans = ps_student.update(event_str)
            if intercepted is None or not cans.can_act:
                continue
            obs_t, masks_t, _ = intercepted
            obs_s, mask_s = ps_student.encode_obs(4, False)  # presumably (version, at_kan_select)
            # Skip forced moves: nothing to learn when only one action is legal.
            valid_actions_count = int(np.count_nonzero(masks_t[0]))
            if valid_actions_count <= 1:
                continue
            try:
                output_code = self.action_to_mask(who, next_event)
                # Store as a dict to decouple the new and old data formats
                self.inputs.append({
                    "obs_student": obs_s,
                    "mask_student": mask_s,
                    "obs_teacher": obs_t[0],    # strip the batch dimension
                    "mask_teacher": masks_t[0]
                })
                self.outputs.append(output_code)
                self.weights.append(current_weight)
                worker_status["total_records_extracted"] += 1
            except Exception: pass  # actions outside the label space are dropped
            if len(self.inputs) >= self.chunk_size:
                self.save_and_check_upload()
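# Each stored record pairs both views of the same decision point:
#   inputs[i]  -> {'obs_student', 'mask_student', 'obs_teacher', 'mask_teacher'}
#   outputs[i] -> MASK_3P index of the move the human actually played next
#   weights[i] -> per-seat sample weight carried by the start_game event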
# ==========================================
# [Main data-mining pipeline]
# ==========================================
def worker_pipeline():
    if not HF_TOKEN or not DATASET_REPO:
        worker_status["status"] = "Error: HF_TOKEN or DATASET_REPO missing!"
        return
    worker_status["status"] = "Fetching target URL list..."
    try:
        url_file_path = hf_hub_download(repo_id=DATASET_REPO, filename=URL_LIST_FILE, repo_type="dataset", token=HF_TOKEN)
        with open(url_file_path, 'r') as f: target_urls = [line.strip() for line in f if line.strip()]
    except Exception as e:
        worker_status["status"] = f"Failed to fetch {URL_LIST_FILE}: {e}"
        return
    headers = {"User-Agent": "Mozilla/5.0"}
    encoder = FeatureEncoder(chunk_size=2048, pool_size=8)
    worker_status["status"] = "Mining..."
    for url in target_urls:
        worker_status["current_target"] = url
        log_match = re.search(r'log=([^&]+)', url)
        tw_match = re.search(r'tw=(\d+)', url)
        if not log_match: continue
        tw = int(tw_match.group(1)) if tw_match else -1
        log_id = log_match.group(1)
        try:
            res = requests.get(f"https://tenhou.net/5/mjlog2json.cgi?{log_id}", headers=headers, timeout=30)
            parsed_games = TenhouParser.parse_log(res.json())
            for game in parsed_games:
                for j in range(3):
                    if j == tw: continue
                    game[0]['id'] = j
                    encoder.process_game(game)
            worker_status["urls_processed"] += 1
        except Exception:
            worker_status["errors"] += 1
    encoder.save_and_check_upload()
    encoder.upload_pool()
    worker_status["status"] = "Finished! All URLs processed."
    worker_status["current_target"] = "Idle"
app = FastAPI()
@app.get("/")
def read_status(): return worker_status
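# The miner runs in a daemon thread; the web app's single endpoint (GET /)
# serves the live worker_status counters, which also keeps the hosting
# Space responsive.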
if __name__ == '__main__':
    thread = threading.Thread(target=worker_pipeline, daemon=True)
    thread.start()
    uvicorn.run(app, host="0.0.0.0", port=7860)