Spaces:
Paused
Paused
import os
import re
import json
import pickle
import threading
import time
import traceback
from datetime import datetime
from typing import *

import numpy as np
import requests
# Web service & HF Hub dependencies
import uvicorn
from fastapi import FastAPI
from huggingface_hub import HfApi, hf_hub_download
# Low-level feature engine (Teacher)
from libriichi3p.mjai import Bot as RiichiBot
from libriichi3p.consts import ACTION_SPACE
# Low-level feature engine (Student)
try:
    from libriichiSanma import state as sanma_state
except ImportError:
    import libriichi as sanma_state
# ==========================================
# [Configuration & environment variables]
# ==========================================
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # empty token disables Hub uploads
DATASET_REPO = os.environ.get("DATASET_REPO", "AstraNASA/tenhou-scc")
URL_LIST_FILE = os.environ.get("URL_LIST_FILE", "urls_3.txt")
# Output action vocabulary for the 3-player student model: 34 tile names,
# the red fives, then the non-discard action labels. An action's index in
# this list is its training label (see FeatureEncoder.action_to_mask).
MASK_3P = [
    "1m", "2m", "3m", "4m", "5m", "6m", "7m", "8m", "9m",
    "1p", "2p", "3p", "4p", "5p", "6p", "7p", "8p", "9p",
    "1s", "2s", "3s", "4s", "5s", "6s", "7s", "8s", "9s",
    "E", "S", "W", "N", "P", "F", "C",
    '5mr', '5pr', '5sr',
    'reach', 'pon', 'kan', 'nukidora', 'hora', 'ryukyoku', 'none'
]
NONE_CODE = MASK_3P.index('none')  # label meaning "no action by this seat"
KAN_CODE = MASK_3P.index('kan')    # daiminkan/ankan/kakan all collapse to this label
# Per-thread mailbox: DummyFeatureEngine.react_batch stores the intercepted
# feature batch here for the caller (FeatureEncoder.process_game) to read back.
_thread_local = threading.local()
# Global progress snapshot served by the FastAPI status endpoint.
worker_status = {
    "status": "Starting up...",
    "urls_processed": 0,
    "total_chunks_uploaded": 0,
    "total_records_extracted": 0,
    "current_target": "",
    "errors": 0
}
# ==========================================
# [Parser] tenhou.net JSON logs -> mjai-style event streams.
# ==========================================
class TenhouParser:
    """Stateless namespace of parsing helpers.

    All methods are declared as class/static methods so call sites such as
    `TenhouParser.parse_log(...)` and the internal `cls.tile_name(...)`
    calls bind correctly. (The original declared them as plain functions
    with a `cls` first parameter, which made `TenhouParser.parse_log(log)`
    pass the log dict as `cls` and fail with a TypeError.)
    """

    @staticmethod
    def tile_name(x):
        """Map a tenhou tile code to its mjai name.

        51/52/53 -> red fives ('5mr'/'5pr'/'5sr'); 11-19/21-29/31-39 ->
        man/pin/sou suits; 41-47 -> honors 'ESWNPFC'; anything else -> '?'.
        """
        if x in (51, 52, 53): return ['5mr', '5pr', '5sr'][x - 51]
        num, suit = x % 10, x // 10
        if suit in (1, 2, 3): return str(num) + 'mps'[suit - 1]
        if suit == 4: return 'ESWNPFC'[num - 1]
        return '?'

    @classmethod
    def get_meld_tiles(cls, actor, s):
        """Decode a tenhou meld string (e.g. 'p111111') into an mjai call event.

        The marker letter ('c'/'p'/'m'/'a'/'k'/'f') selects the call type;
        for open calls the marker's position within the string encodes the
        seat the tile was claimed from.
        """
        i, player = 0, 0
        result = {'pai': [], 'consumed': [], 'actor': actor}
        while i < len(s):
            player += 1
            tile_type = 'consumed'
            if s[i] in 'cpmakf':
                tile_type = 'pai'
                result['type'] = ['chi', 'pon', 'daiminkan', 'ankan', 'kakan', 'nukidora']['cpmakf'.index(s[i])]
                if s[i] in 'cpm':
                    # Open call: derive the target seat from the marker position.
                    result['target'] = (4 - player + actor) % 4
                i += 1
            result[tile_type].append(cls.tile_name(int(s[i:i+2])))
            i += 2
        result['pai'] = result['pai'][0]
        # The ankan's marked tile is also part of the concealed set.
        if result.get('type') == 'ankan': result['consumed'].append(result['pai'])
        return result

    @classmethod
    def parse_events(cls, actor, income, outcome):
        """Expand one seat's draw (`income`) and discard (`outcome`) arrays
        into parallel lists of mjai events."""
        incoming, outcoming = [], []
        for i, event in enumerate(income):
            if type(event) is str: incoming.append(cls.get_meld_tiles(actor, event))
            else: incoming.append({'type': 'tsumo', 'pai': cls.tile_name(event), 'actor': actor})
        for i, event in enumerate(outcome):
            if type(event) is str and event[0] != 'r':
                outcoming.append(cls.get_meld_tiles(actor, event))
            else:
                if event == 0:
                    # Placeholder slot with no real discard.
                    outcoming.append({'type': 'empty'})
                    continue
                reach = False
                if type(event) is str and event[0] == 'r':
                    # 'r<tile>' marks the riichi declaration discard.
                    reach, event = True, int(event[1:])
                    outcoming.append({'type': 'reach', 'actor': actor})
                # Code 60 = tsumogiri (discarding the tile just drawn at slot i).
                outcoming.append({'type': 'dahai', 'pai': cls.tile_name(event if event != 60 else income[i]), 'actor': actor, 'tsumogiri': event == 60})
                if reach: outcoming.append({'type': 'reach_accepted', 'actor': actor})
        return incoming, outcoming

    @classmethod
    def merge_events(cls, oya, events, dora_markers):
        """Interleave the four per-seat (income, outcome) lists into one
        chronological stream, starting from the dealer and following calls."""
        current, result = oya, []
        def finished(x): return all(len(i[0]) == 0 and len(i[1]) == 0 for i in x)
        while not finished(events):
            income, outcome = events[current]
            nuki = False
            if len(income):
                result.append(income.pop(0))
                if result[-1]['type'] == 'daiminkan':
                    # Daiminkan reveals a new dora marker and consumes the
                    # seat's placeholder discard slot with no real discard.
                    result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    outcome.pop(0)
                    continue
            if len(outcome):
                result.append(outcome.pop(0))
                pai, t = result[-1].get('pai'), result[-1]['type']
                if t == 'reach':
                    # A reach event is immediately followed by its dahai and
                    # reach_accepted events (see parse_events).
                    result.append(outcome.pop(0))
                    pai = result[-1].get('pai')
                    result.append(outcome.pop(0))
                nuki = False
                # Did another seat call this discard? If so, play jumps there.
                for actor, x in enumerate(events):
                    if actor == current or len(x[1]) == 0: continue
                    if x[0][0]['type'] != 'tsumo' and x[0][0].get('pai') == pai and not (x[0][0]['type'] == 'chi' and not (x[0][0]['actor'] + 3) % 4 == actor):
                        nuki, current = True, actor
                        break
                if t in ('ankan', 'kakan', 'nukidora'):
                    # The same seat acts again after a closed/added kan or nuki.
                    if t != 'nukidora' and len(dora_markers) > 0: result.append({'type': 'dora', 'dora_marker': cls.tile_name(dora_markers.pop(0))})
                    nuki = True
            if not nuki: current = (current + 1) % 4
        return result

    @classmethod
    def parse_single_round(cls, data):
        """Parse one round record into a start_kyoku event, the merged event
        stream, and (when applicable) the terminal hora/ryukyoku event."""
        round_info, scores, dora_markers, uradora, result_info = data[0], data[1], data[2], data[3], data[-1]
        oya = round_info[0] % 4
        # Guard against truncated starting hands.
        patch = lambda arr: arr if len(arr) >= 13 else [0] * 13
        events = [{
            'type': 'start_kyoku', 'bakaze': 'ESWN'[round_info[0] // 4], 'kyoku': oya + 1,
            'honba': round_info[1], 'kyotaku': round_info[2], 'oya': oya,
            'dora_marker': cls.tile_name(dora_markers.pop(0)), 'scores': scores,
            'tehais': [[cls.tile_name(i) for i in patch(data[k])] for k in [4, 7, 10, 13]]
        }]
        # Per seat k: data[4+3k] = starting hand, data[5+3k] = draws, data[6+3k] = discards.
        e_list = [cls.parse_events(i, data[5 + i * 3], data[6 + i * 3]) for i in range(4)]
        events += cls.merge_events(oya, e_list, dora_markers)
        last_type = events[-1]['type']
        if last_type == 'tsumo' and result_info[0] == '和了': events.append({'type': 'hora', 'actor': events[-1]['actor'], 'target': events[-1]['actor']})
        elif result_info[0] == '和了':
            # Ron: the winner is the seat with a positive score delta.
            actor = next(i for i, x in enumerate(result_info[1]) if x > 0)
            events.append({'type': 'hora', 'actor': actor, 'target': actor})
        elif last_type == 'tsumo' or '九牌' in result_info[0]: events.append({'type': 'ryukyoku', 'actor': events[-1]['actor']})
        return events

    @classmethod
    def parse_log(cls, log):
        """Parse a full game log into a list of per-round event streams,
        each prefixed with a start_game event carrying the owner's seat."""
        weights = [1.0, 1.0, 1.0]
        # Seat of the log owner ('私' = "me"); -1 when not present.
        seat = log['name'].index('私') if '私' in log['name'] else -1
        parsed_rounds = []
        for round_data in log['log']:
            round_events = [{"type": "start_game", "id": seat, "weight": weights}] + cls.parse_single_round(round_data)
            parsed_rounds.append(round_events)
        return parsed_rounds
# ==========================================
# [Feature-intercepting dummy engine (Teacher)]
# ==========================================
class DummyFeatureEngine:
    """Stand-in model handed to RiichiBot that never really plays.

    Its only job is to capture the (obs, masks, invisible_obs) batch that
    the bot passes in, stashing it in thread-local storage so the caller
    can read the teacher features back out after `react`.
    """

    def __init__(self):
        # Attributes RiichiBot expects on its model object.
        self.engine_type = 'mortal'
        self.name = 'DataMiner'
        self.version = 4
        self.is_oracle = False
        self.enable_quick_eval = True
        self.enable_rule_based_agari_guard = True

    def react_batch(self, obs, masks, invisible_obs):
        # Stash the raw feature batch for the caller to intercept.
        _thread_local.interception = (obs, masks, invisible_obs)
        actions = []
        q_outs = []
        pure_masks = []
        for mask in masks:
            as_list = mask.tolist() if hasattr(mask, 'tolist') else list(mask)
            pure_masks.append(as_list)
            # Always pick the first legal action; 0 if none are marked legal.
            try:
                chosen = as_list.index(True)
            except ValueError:
                chosen = 0
            actions.append(chosen)
            q_outs.append([0.0] * len(as_list))
        return actions, q_outs, pure_masks, [True] * len(obs)
# ==========================================
# [Dual feature-packing architecture (Distillation)]
# ==========================================
class FeatureEncoder:
    """Accumulates paired (student, teacher) features per decision point,
    pickles them into chunks, and bulk-uploads pooled chunks to the Hub."""

    def __init__(self, chunk_size=2048, pool_size=8):
        # chunk_size: records per pickle file; pool_size: files per upload batch.
        self.chunk_size = chunk_size
        self.pool_size = pool_size
        self.inputs, self.outputs, self.weights = [], [], []
        self.chunk_count = 0
        self.hf_api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
        self.local_pool_dir = "local_chunks_pool"
        os.makedirs(self.local_pool_dir, exist_ok=True)

    @staticmethod
    def action_to_mask(who, action):
        """Encode the next game event as a MASK_3P label for seat `who`.

        Returns NONE_CODE for events that are not actions by `who`; all kan
        variants collapse to KAN_CODE. Raises on unmappable events.
        Declared @staticmethod so the `self.action_to_mask(who, next_event)`
        call sites bind correctly (the original bare two-parameter function
        received `self` as `who` and raised TypeError on every call).
        """
        if action is None: return NONE_CODE
        if type(action) is str: action = json.loads(action)
        if action.get('actor') != who or action.get('type') == 'tsumo': return NONE_CODE
        if action['type'] == 'dahai': return MASK_3P.index(action['pai'])
        if action['type'] in ('daiminkan', 'ankan', 'kakan'): return KAN_CODE
        if action['type'] in MASK_3P: return MASK_3P.index(action['type'])
        raise Exception(f"Unknown action map: {action}")

    def save_and_check_upload(self):
        """Flush the in-memory buffers into a pickle chunk; trigger a bulk
        upload once the local pool reaches `pool_size` files."""
        if not self.inputs: return
        filename = f"chunk_distill_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{self.chunk_count}.pkl"
        filepath = os.path.join(self.local_pool_dir, filename)
        with open(filepath, 'wb') as f:
            pickle.dump({'inputs': self.inputs, 'outputs': self.outputs, 'weights': self.weights}, f)
        # Report the actual chunk path (the original printed a literal
        # "(unknown)" placeholder instead of the file name).
        print(f"📦 已生成蒸馏缓存: {filepath} ({len(self.inputs)} records).")
        self.chunk_count += 1
        self.inputs.clear()
        self.outputs.clear()
        self.weights.clear()
        if len(os.listdir(self.local_pool_dir)) >= self.pool_size:
            self.upload_pool()

    def upload_pool(self):
        """Upload every pooled chunk to the dataset repo, retrying with
        exponential backoff; delete local files only after success."""
        current_files = os.listdir(self.local_pool_dir)
        if not current_files or not self.hf_api or not DATASET_REPO: return
        print(f"🚀 本地池满,正在批量上传 {len(current_files)} 个文件...")
        for attempt in range(6):
            try:
                self.hf_api.upload_folder(
                    folder_path=self.local_pool_dir,
                    path_in_repo="distill_chunks",
                    repo_id=DATASET_REPO,
                    repo_type="dataset"
                )
                print(f"✅ 上传成功 (Attempt {attempt + 1}).")
                worker_status["total_chunks_uploaded"] += len(current_files)
                for f in current_files: os.remove(os.path.join(self.local_pool_dir, f))
                break
            except Exception as e:
                wait_time = 5 * (2 ** attempt)  # 5s, 10s, ... up to 160s
                print(f"⚠️ Upload failed: {e}. Waiting {wait_time}s...")
                time.sleep(wait_time)

    def process_game(self, events):
        """Replay one round's events through both the student state machine
        and the teacher bot, harvesting paired features at decision points."""
        who = -1
        current_weight = 1.0
        ps_student = None
        bot_teacher = None
        for i, event in enumerate(events):
            if event.get('type') == 'start_game':
                who = event['id']
                weights_list = event.get('weight', [1.0, 1.0, 1.0])
                current_weight = weights_list[who]
                # Initialize the dual model state machines.
                ps_student = sanma_state.PlayerState(who)
                bot_teacher = RiichiBot(DummyFeatureEngine(), who)
            if ps_student is None or bot_teacher is None:
                continue
            if event.get('type') == 'end_game':
                continue
            # Look ahead to the next non-bookkeeping event: the supervision target.
            next_event = None
            for j in range(i + 1, len(events)):
                if events[j].get('type') not in ('dora', 'reach_accepted'):
                    next_event = events[j]
                    break
            event_str = json.dumps(event, separators=(",", ":"))
            # --- 1. Teacher update & feature interception ---
            _thread_local.interception = None
            bot_teacher.react(event_str)
            intercepted = getattr(_thread_local, 'interception', None)
            # --- 2. Student update & feature generation ---
            cans = ps_student.update(event_str)
            if intercepted is None or not cans.can_act:
                continue
            obs_t, masks_t, _ = intercepted
            obs_s, mask_s = ps_student.encode_obs(4, False)
            # Forced moves (<= 1 legal action) carry no training signal.
            if int(np.count_nonzero(masks_t[0])) <= 1:
                continue
            try:
                output_code = self.action_to_mask(who, next_event)
                # Store as a dict to decouple new/old data formats.
                self.inputs.append({
                    "obs_student": obs_s,
                    "mask_student": mask_s,
                    "obs_teacher": obs_t[0],  # drop the batch dimension
                    "mask_teacher": masks_t[0]
                })
                self.outputs.append(output_code)
                self.weights.append(current_weight)
                worker_status["total_records_extracted"] += 1
            except Exception:
                # Best-effort: events action_to_mask cannot map are skipped.
                pass
            if len(self.inputs) >= self.chunk_size:
                self.save_and_check_upload()
# ==========================================
# [Main data-mining pipeline]
# ==========================================
def worker_pipeline():
    """Download tenhou logs listed in URL_LIST_FILE, extract distillation
    features per seat, and upload chunked pickles.

    Progress and failures are reported through the module-global
    `worker_status` dict served by the FastAPI endpoint.
    """
    if not HF_TOKEN or not DATASET_REPO:
        worker_status["status"] = "Error: HF_TOKEN or DATASET_REPO missing!"
        return
    worker_status["status"] = "Fetching target URL list..."
    try:
        url_file_path = hf_hub_download(repo_id=DATASET_REPO, filename=URL_LIST_FILE, repo_type="dataset", token=HF_TOKEN)
        with open(url_file_path, 'r') as f:
            target_urls = [line.strip() for line in f if line.strip()]
    except Exception as e:
        worker_status["status"] = f"Failed to fetch {URL_LIST_FILE}: {e}"
        return
    headers = {"User-Agent": "Mozilla/5.0"}
    encoder = FeatureEncoder(chunk_size=2048, pool_size=8)
    worker_status["status"] = "Mining..."
    for url in target_urls:
        worker_status["current_target"] = url
        log_match = re.search(r'log=([^&]+)', url)
        tw_match = re.search(r'tw=(\d+)', url)
        if not log_match:
            continue
        tw = int(tw_match.group(1)) if tw_match else -1
        log_id = log_match.group(1)
        try:
            res = requests.get(f"https://tenhou.net/5/mjlog2json.cgi?{log_id}", headers=headers, timeout=30)
            parsed_games = TenhouParser.parse_log(res.json())
            for game in parsed_games:
                # Mine seats 0-2, skipping seat `tw` (presumably the
                # viewpoint seat from the URL — TODO confirm against data).
                for j in range(3):
                    if j == tw:
                        continue
                    game[0]['id'] = j
                    encoder.process_game(game)
            worker_status["urls_processed"] += 1
        except Exception:
            # Count the failure but keep mining; log the traceback for
            # diagnosis (the original discarded the exception silently).
            worker_status["errors"] += 1
            traceback.print_exc()
    encoder.save_and_check_upload()
    encoder.upload_pool()
    worker_status["status"] = "Finished! All URLs processed."
    worker_status["current_target"] = "Idle"
app = FastAPI()

@app.get("/")
def read_status():
    """Return the live mining status dict.

    Registered on "/" via @app.get — the original defined this handler
    without attaching it to any route, leaving the app with no endpoints.
    """
    return worker_status
if __name__ == '__main__':
    # Run the miner in the background; serve the status API in the foreground.
    threading.Thread(target=worker_pipeline, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=7860)