"""Annotation persistence manager for a HuggingFace Spaces deployment.

Caches per-user annotation results in memory and periodically uploads them
as JSON files to a HuggingFace dataset repo, plus helpers that aggregate
all users' annotation files into a single summary file.
"""

import re
import os
import json
import time
import atexit
from datetime import datetime, timedelta
from huggingface_hub import HfApi, hf_hub_download
from collections import defaultdict
import config

HF_TOKEN = os.getenv("HF_TOKEN")
REPO_ID = config.SAVE_REPO_ID

api = HfApi()


class HFSpacesAnnotationManager:
    """Buffers annotations in memory and flushes them to the HF dataset repo.

    Spaces containers are ephemeral, so the in-memory cache is lost on
    restart; uploads happen every N annotations, after a time threshold,
    on cache eviction, and at interpreter exit.
    """

    def __init__(self):
        # In-memory cache; lost when the Space restarts.
        self.memory_cache = {}
        # username -> datetime of the last successful upload.
        self.last_upload_times = {}
        self.session_start = datetime.now()

        # Upload policy tuned for HF Spaces.
        self.upload_frequency = 5    # upload every 5 annotations
        self.time_threshold = 300    # force an upload after 5 minutes
        self.max_cache_users = 20    # cap on cached users to bound memory

        # Flush everything still cached when the process exits.
        atexit.register(self.save_all_on_exit)

        print(f"📍 运行在 HuggingFace Spaces")
        print(f"📋 缓存策略: 每{self.upload_frequency}次标注或{self.time_threshold}秒上传一次")

    def save_annotations(self, username, annotation_results, tasks):
        """Cache the user's annotations and upload to HF when policy triggers.

        Returns a human-readable status string; never raises.
        """
        try:
            # 1. Always refresh the in-memory cache first.
            self.update_memory_cache(username, annotation_results, tasks)

            # 2. Decide whether an upload is due.
            should_upload, reason = self.should_upload_now(username, annotation_results)

            if should_upload:
                try:
                    upload_result = self.upload_to_hf(username, annotation_results, tasks)
                    self.last_upload_times[username] = datetime.now()
                    return f"✅ 内存已保存 + 已上传HF ({reason})\n{upload_result}"
                except Exception as e:
                    # Upload failure is non-fatal: data is still cached.
                    return f"✅ 内存已保存,上传失败 ({reason}): {str(e)}"
            else:
                completed = len(annotation_results)
                total = len(tasks)
                next_upload = self.upload_frequency - (completed % self.upload_frequency)
                return f"✅ 已保存到内存 ({completed}/{total}),还需{next_upload}次标注触发上传"

        except Exception as e:
            return f"❌ 保存失败: {str(e)}"

    def update_memory_cache(self, username, annotation_results, tasks):
        """Store the user's latest state in the in-memory cache."""
        # Evict old entries first to keep memory bounded.
        self.cleanup_memory_cache()

        self.memory_cache[username] = {
            'annotation_results': annotation_results,
            'total_tasks': len(tasks),
            'completed_tasks': len(annotation_results),
            'last_updated': datetime.now(),
            'tasks_data': tasks  # kept so evicted/exit uploads can rebuild the payload
        }

    def cleanup_memory_cache(self):
        """Evict the least-recently-updated users when the cache is too big.

        Each evicted user's data is best-effort uploaded before deletion so
        nothing is silently lost.
        """
        if len(self.memory_cache) > self.max_cache_users:
            # Oldest first by last update time.
            sorted_users = sorted(
                self.memory_cache.items(),
                key=lambda x: x[1]['last_updated']
            )

            # +1 leaves headroom for the entry about to be inserted.
            oldest_users = sorted_users[:len(self.memory_cache) - self.max_cache_users + 1]
            for username, _ in oldest_users:
                # Try to persist before dropping the cache entry.
                try:
                    cache_data = self.memory_cache[username]
                    self.upload_to_hf(username, cache_data['annotation_results'], cache_data['tasks_data'])
                    print(f"🗑️ 清理缓存时已上传用户 {username} 的数据")
                except Exception as e:
                    print(f"⚠️ 清理缓存时上传失败 {username}: {e}")

                del self.memory_cache[username]

    def should_upload_now(self, username, annotation_results):
        """Return ``(should_upload, reason)`` per the upload policy."""
        completed_count = len(annotation_results)
        current_time = datetime.now()
        last_upload = self.last_upload_times.get(username, self.session_start)

        # Condition 1: every `upload_frequency` annotations.
        if completed_count > 0 and completed_count % self.upload_frequency == 0:
            return True, f"完成{self.upload_frequency}次标注"

        # Condition 2: too long since the last upload.
        if (current_time - last_upload).total_seconds() > self.time_threshold:
            return True, f"超过{self.time_threshold}秒"

        # Condition 3: very first annotation of this user in this session.
        if completed_count == 1 and username not in self.last_upload_times:
            return True, "首次标注"

        return False, "等待条件触发"

    def upload_to_hf(self, username, annotation_results, tasks):
        """Serialize the user's annotations and upload them to the HF repo."""
        save_data = self.prepare_save_data(username, annotation_results, tasks)
        save_str = json.dumps(save_data, ensure_ascii=False, indent=2)
        filename = get_user_annotation_filename(username)

        api.upload_file(
            path_or_fileobj=save_str.encode("utf-8"),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )

        return f"上传成功: {len(annotation_results)}/{len(tasks)} 项标注"

    def prepare_save_data(self, username, annotation_results, tasks):
        """Build the JSON-serializable payload for one user's annotations."""
        save_data = {
            "total_tasks": len(tasks),
            "completed_tasks": len(annotation_results),
            "username": username,
            "last_updated": datetime.now().isoformat(),
            "environment": "HuggingFace Spaces",
            "annotations": []
        }

        for task_id, choice in annotation_results.items():
            if task_id < len(tasks):  # guard against stale/out-of-range ids
                task = tasks[task_id]
                save_data["annotations"].append({
                    "task_id": task_id,
                    "text": task["text"],
                    "instruction": task["instruction"],
                    "comparison": f"{task['audioA_source']} vs {task['audioB_source']}",
                    "audioA_source": task["audioA_source"],
                    "audioB_source": task["audioB_source"],
                    "original_index": task["original_index"],
                    "choice": choice,
                    "username": username,
                    "timestamp": datetime.now().isoformat()
                })

        return save_data

    def load_annotations(self, username):
        """Load a user's annotations, preferring the in-memory cache.

        Returns ``{task_id: choice}``; empty dict when nothing is found.
        """
        # 1. Memory cache wins — it is the freshest copy.
        if username in self.memory_cache:
            cache_data = self.memory_cache[username]
            print(f"📋 从内存缓存加载用户 {username} 的标注")
            return cache_data['annotation_results']

        # 2. Fall back to the persisted file on HuggingFace.
        try:
            filename = get_user_annotation_filename(username)
            local_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=filename,
                repo_type="dataset",
                token=HF_TOKEN,
                force_download=True
            )

            with open(local_path, "r", encoding="utf-8") as f:
                save_data = json.load(f)

            annotation_results = {ann["task_id"]: ann["choice"]
                                  for ann in save_data.get("annotations", [])}
            print(f"📥 从HuggingFace加载用户 {username} 的标注")
            return annotation_results

        except Exception as e:
            # Missing file or network error: start the user from scratch.
            print(f"⚠️ 加载用户 {username} 标注失败: {e}")
            return {}

    def save_all_on_exit(self):
        """atexit hook: best-effort upload of every cached user's data."""
        if not self.memory_cache:
            return

        print(f"🔄 应用即将关闭,正在保存 {len(self.memory_cache)} 个用户的缓存数据...")

        success_count = 0
        for username, cache_data in self.memory_cache.items():
            try:
                self.upload_to_hf(
                    username,
                    cache_data['annotation_results'],
                    cache_data['tasks_data']
                )
                success_count += 1
                print(f"✅ 已保存用户 {username} 的数据")
            except Exception as e:
                print(f"❌ 保存用户 {username} 数据失败: {e}")

        print(f"🎯 退出保存完成: {success_count}/{len(self.memory_cache)} 成功")

    def get_cache_stats(self):
        """Return a dict of cache statistics for debugging."""
        total_annotations = sum(
            len(cache['annotation_results'])
            for cache in self.memory_cache.values()
        )

        return {
            "cached_users": len(self.memory_cache),
            "total_cached_annotations": total_annotations,
            "session_duration": str(datetime.now() - self.session_start),
            "environment": "HuggingFace Spaces"
        }


# Module-level singleton used by the entry-point functions below.
annotation_manager = HFSpacesAnnotationManager()


def get_user_annotation_filename(username: str) -> str:
    """Build a filesystem/repo-safe annotation filename for *username*."""
    # Replace characters that are illegal in file names on common platforms.
    safe_username = re.sub(r'[\\/*?:"<>|]', "_", username)
    return f"annotation_results_{safe_username}.json"


def save_annotations(username_state, annotation_results_state, tasks):
    """Entry point: save one user's annotation results."""
    return annotation_manager.save_annotations(username_state, annotation_results_state, tasks)


def load_annotations(username):
    """Entry point: load one user's annotation results."""
    return annotation_manager.load_annotations(username)


def force_upload_all():
    """Admin helper: force-upload all cached data immediately."""
    annotation_manager.save_all_on_exit()
    return "强制上传完成"


def get_cache_stats():
    """Debug helper: expose the manager's cache statistics."""
    return annotation_manager.get_cache_stats()


def get_aggregated_filename() -> str:
    """Name of the repo file holding the cross-user aggregated results."""
    return "aggregated_annotations.json"


def update_aggregated_annotations(tasks):
    """Rebuild the aggregated results from all user files and upload them."""
    try:
        all_annotations = collect_all_annotations()
        aggregated_data = build_aggregated_results(all_annotations, tasks)

        save_str = json.dumps(aggregated_data, ensure_ascii=False, indent=2)
        filename = get_aggregated_filename()

        api.upload_file(
            path_or_fileobj=save_str.encode("utf-8"),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )

        # BUGFIX: the message previously contained a literal placeholder
        # instead of the uploaded filename.
        return f"✅ 聚合结果已更新: {filename}"

    except Exception as e:
        return f"❌ 聚合结果更新失败: {str(e)}"


def collect_all_annotations():
    """Download every per-user annotation file and return all annotations.

    Returns a flat list of annotation dicts; empty list on failure.
    """
    try:
        files_info = api.list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN)
        all_annotations = []

        for filename in files_info:
            # Only files produced by get_user_annotation_filename().
            if filename.startswith("annotation_results_") and filename.endswith(".json"):
                try:
                    local_path = hf_hub_download(
                        repo_id=REPO_ID,
                        filename=filename,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        force_download=True
                    )

                    with open(local_path, "r", encoding="utf-8") as f:
                        user_data = json.load(f)
                        all_annotations.extend(user_data.get("annotations", []))

                except Exception as e:
                    # BUGFIX: message previously printed a literal placeholder
                    # instead of the failing filename.
                    print(f"加载文件 {filename} 失败: {e}")
                    continue

        return all_annotations

    except Exception as e:
        print(f"收集标注失败: {e}")
        return []


def build_aggregated_results(all_annotations, tasks):
    """Group annotations by original_index/comparison and tally the votes.

    Returns ``{"total_groups", "total_annotations", "results"}`` where each
    result carries per-comparison win/tie/lose counts and annotator lists.
    """
    groups = defaultdict(lambda: {
        "text": "",
        "instruction": "",
        "comparisons": defaultdict(lambda: {"win": 0, "tie": 0, "lose": 0, "annotators": []})
    })

    for ann in all_annotations:
        original_index = ann.get("original_index")
        comparison = ann.get("comparison")
        choice = ann.get("choice")
        username = ann.get("username")
        text = ann.get("text", "")
        instruction = ann.get("instruction", "")

        # Skip malformed entries missing any of the grouping keys.
        if original_index is not None and comparison and choice:
            key = original_index
            groups[key]["text"] = text
            groups[key]["instruction"] = instruction

            if choice in ["win", "tie", "lose"]:
                groups[key]["comparisons"][comparison][choice] += 1
                # Count each annotator once per comparison.
                if username not in groups[key]["comparisons"][comparison]["annotators"]:
                    groups[key]["comparisons"][comparison]["annotators"].append(username)

    aggregated_results = []
    for original_index, group_data in groups.items():
        result_item = {
            "original_index": original_index,
            "text": group_data["text"],
            "instruction": group_data["instruction"],
            "comparisons": {}
        }

        for comparison, votes in group_data["comparisons"].items():
            result_item["comparisons"][comparison] = {
                "votes(win tie lose)": [votes["win"], votes["tie"], votes["lose"]],
                "total_annotators": len(votes["annotators"]),
                "annotators": votes["annotators"]
            }

        aggregated_results.append(result_item)

    aggregated_results.sort(key=lambda x: x["original_index"])

    return {
        "total_groups": len(aggregated_results),
        "total_annotations": len(all_annotations),
        "results": aggregated_results
    }


def load_aggregated_annotations():
    """Load the existing aggregated results file; empty structure on failure."""
    try:
        filename = get_aggregated_filename()
        local_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            repo_type="dataset",
            token=HF_TOKEN,
            force_download=True
        )

        with open(local_path, "r", encoding="utf-8") as f:
            return json.load(f)

    except Exception:
        return {"total_groups": 0, "total_annotations": 0, "results": []}


def get_aggregated_stats():
    """Summarize aggregated results into per-comparison vote totals."""
    try:
        aggregated_data = load_aggregated_annotations()

        stats = {
            "total_groups": aggregated_data.get("total_groups", 0),
            "total_annotations": aggregated_data.get("total_annotations", 0),
            "comparison_summary": {}
        }

        for result in aggregated_data.get("results", []):
            for comparison, data in result.get("comparisons", {}).items():
                if comparison not in stats["comparison_summary"]:
                    stats["comparison_summary"][comparison] = {
                        "total_votes": 0, "win": 0, "tie": 0, "lose": 0
                    }

                # BUGFIX: build_aggregated_results stores the tally under the
                # key "votes(win tie lose)", but this previously read "votes",
                # so every summary was all zeros. Read the real key, keeping
                # "votes" as a fallback for any older files.
                votes = data.get("votes(win tie lose)", data.get("votes", [0, 0, 0]))
                stats["comparison_summary"][comparison]["win"] += votes[0]
                stats["comparison_summary"][comparison]["tie"] += votes[1]
                stats["comparison_summary"][comparison]["lose"] += votes[2]
                stats["comparison_summary"][comparison]["total_votes"] += sum(votes)

        return stats

    except Exception as e:
        return {"error": str(e)}