Spaces:
Sleeping
Sleeping
| import re | |
| import os | |
| import json | |
| import time | |
| import atexit | |
| from datetime import datetime, timedelta | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from collections import defaultdict | |
| import config | |
# Hub write token; None when the HF_TOKEN env var is unset (uploads will fail).
HF_TOKEN = os.getenv("HF_TOKEN")
# Target dataset repository for all annotation uploads (set in config.py).
REPO_ID = config.SAVE_REPO_ID
# Shared Hub API client used by every upload/list call in this module.
api = HfApi()
class HFSpacesAnnotationManager:
    """Annotation persistence manager tuned for HuggingFace Spaces.

    Spaces containers lose local state on restart, so annotations are held in
    an in-process cache and periodically flushed to a HF dataset repo —
    every `upload_frequency` annotations, after `time_threshold` seconds,
    on cache eviction, and on process exit.
    """

    def __init__(self):
        # In-memory cache of per-user annotation state; lost on restart.
        self.memory_cache = {}
        # Per-user timestamp of the last successful upload.
        self.last_upload_times = {}
        self.session_start = datetime.now()
        # Upload policy tuned for HF Spaces.
        self.upload_frequency = 5  # upload once every 5 annotations
        self.time_threshold = 300  # force an upload after 5 minutes
        self.max_cache_users = 20  # cache at most 20 users
        # Flush all cached data to the Hub when the process exits.
        atexit.register(self.save_all_on_exit)
        print(f"📍 运行在 HuggingFace Spaces")
        print(f"📋 缓存策略: 每{self.upload_frequency}次标注或{self.time_threshold}秒上传一次")

    def save_annotations(self, username, annotation_results, tasks):
        """Save annotation results: always cache, upload when the policy triggers.

        Returns a user-facing (Chinese) status string; never raises — all
        failures are reported through the return value.
        """
        try:
            # 1. Update the in-memory cache first so nothing is lost.
            self.update_memory_cache(username, annotation_results, tasks)
            # 2. Decide whether an upload should happen now.
            should_upload, reason = self.should_upload_now(username, annotation_results)
            if should_upload:
                try:
                    # Perform the upload.
                    upload_result = self.upload_to_hf(username, annotation_results, tasks)
                    self.last_upload_times[username] = datetime.now()
                    return f"✅ 内存已保存 + 已上传HF ({reason})\n{upload_result}"
                except Exception as e:
                    # Upload failure is non-fatal: the cache still holds the data.
                    return f"✅ 内存已保存,上传失败 ({reason}): {str(e)}"
            else:
                completed = len(annotation_results)
                total = len(tasks)
                # Annotations remaining until the next frequency-based upload.
                next_upload = self.upload_frequency - (completed % self.upload_frequency)
                return f"✅ 已保存到内存 ({completed}/{total}),还需{next_upload}次标注触发上传"
        except Exception as e:
            return f"❌ 保存失败: {str(e)}"

    def update_memory_cache(self, username, annotation_results, tasks):
        """Store/refresh a user's annotation state in the in-memory cache."""
        # Evict stale entries first to bound memory use.
        self.cleanup_memory_cache()
        self.memory_cache[username] = {
            'annotation_results': annotation_results,
            'total_tasks': len(tasks),
            'completed_tasks': len(annotation_results),
            'last_updated': datetime.now(),
            'tasks_data': tasks  # keep the tasks reference so later uploads can rebuild payloads
        }

    def cleanup_memory_cache(self):
        """Evict the oldest cached users once the cache exceeds its size cap."""
        # If too many users are cached, drop the least recently updated ones.
        if len(self.memory_cache) > self.max_cache_users:
            # Sort users by last-updated time, oldest first.
            sorted_users = sorted(
                self.memory_cache.items(),
                key=lambda x: x[1]['last_updated']
            )
            # Evict just enough of the oldest entries to get back under the cap.
            oldest_users = sorted_users[:len(self.memory_cache) - self.max_cache_users + 1]
            for username, _ in oldest_users:
                # Best-effort upload before dropping the data.
                try:
                    cache_data = self.memory_cache[username]
                    self.upload_to_hf(username, cache_data['annotation_results'], cache_data['tasks_data'])
                    print(f"🗑️ 清理缓存时已上传用户 {username} 的数据")
                except Exception as e:
                    print(f"⚠️ 清理缓存时上传失败 {username}: {e}")
                del self.memory_cache[username]

    def should_upload_now(self, username, annotation_results):
        """Return ``(should_upload, reason)`` based on count/time/first-save rules."""
        completed_count = len(annotation_results)
        current_time = datetime.now()
        # Fall back to session start when this user has never uploaded.
        last_upload = self.last_upload_times.get(username, self.session_start)
        # Rule 1: every `upload_frequency` completed annotations.
        if completed_count > 0 and completed_count % self.upload_frequency == 0:
            return True, f"完成{self.upload_frequency}次标注"
        # Rule 2: too long since the last upload.
        if (current_time - last_upload).total_seconds() > self.time_threshold:
            return True, f"超过{self.time_threshold}秒"
        # Rule 3: the user's very first annotation.
        if completed_count == 1 and username not in self.last_upload_times:
            return True, "首次标注"
        return False, "等待条件触发"

    def upload_to_hf(self, username, annotation_results, tasks):
        """Serialize the user's annotations and upload them to the dataset repo.

        Raises on any upload error; callers decide how to handle failure.
        """
        save_data = self.prepare_save_data(username, annotation_results, tasks)
        save_str = json.dumps(save_data, ensure_ascii=False, indent=2)
        filename = get_user_annotation_filename(username)
        api.upload_file(
            path_or_fileobj=save_str.encode("utf-8"),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        return f"上传成功: {len(annotation_results)}/{len(tasks)} 项标注"

    def prepare_save_data(self, username, annotation_results, tasks):
        """Build the JSON-serializable payload for a user's annotation file."""
        save_data = {
            "total_tasks": len(tasks),
            "completed_tasks": len(annotation_results),
            "username": username,
            "last_updated": datetime.now().isoformat(),
            "environment": "HuggingFace Spaces",
            "annotations": []
        }
        for task_id, choice in annotation_results.items():
            if task_id < len(tasks):  # guard against stale/out-of-range task ids
                task = tasks[task_id]
                save_data["annotations"].append({
                    "task_id": task_id,
                    "text": task["text"],
                    "instruction": task["instruction"],
                    "comparison": f"{task['audioA_source']} vs {task['audioB_source']}",
                    "audioA_source": task["audioA_source"],
                    "audioB_source": task["audioB_source"],
                    "original_index": task["original_index"],
                    "choice": choice,
                    "username": username,
                    "timestamp": datetime.now().isoformat()
                })
        return save_data

    def load_annotations(self, username):
        """Load a user's annotations, preferring the in-memory cache over the Hub.

        Returns a dict mapping task_id -> choice; empty dict when nothing
        can be loaded.
        """
        # 1. In-memory cache hit.
        if username in self.memory_cache:
            cache_data = self.memory_cache[username]
            print(f"📋 从内存缓存加载用户 {username} 的标注")
            return cache_data['annotation_results']
        # 2. Otherwise, download the user's file from HuggingFace.
        try:
            filename = get_user_annotation_filename(username)
            local_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=filename,
                repo_type="dataset",
                token=HF_TOKEN,
                force_download=True
            )
            with open(local_path, "r", encoding="utf-8") as f:
                save_data = json.load(f)
            annotation_results = {ann["task_id"]: ann["choice"] for ann in save_data.get("annotations", [])}
            print(f"📥 从HuggingFace加载用户 {username} 的标注")
            return annotation_results
        except Exception as e:
            print(f"⚠️ 加载用户 {username} 标注失败: {e}")
            return {}

    def save_all_on_exit(self):
        """atexit hook: best-effort upload of every cached user's data."""
        if not self.memory_cache:
            return
        print(f"🔄 应用即将关闭,正在保存 {len(self.memory_cache)} 个用户的缓存数据...")
        success_count = 0
        for username, cache_data in self.memory_cache.items():
            try:
                self.upload_to_hf(
                    username,
                    cache_data['annotation_results'],
                    cache_data['tasks_data']
                )
                success_count += 1
                print(f"✅ 已保存用户 {username} 的数据")
            except Exception as e:
                print(f"❌ 保存用户 {username} 数据失败: {e}")
        print(f"🎯 退出保存完成: {success_count}/{len(self.memory_cache)} 成功")

    def get_cache_stats(self):
        """Return summary statistics about the in-memory cache."""
        total_annotations = sum(
            len(cache['annotation_results'])
            for cache in self.memory_cache.values()
        )
        return {
            "cached_users": len(self.memory_cache),
            "total_cached_annotations": total_annotations,
            "session_duration": str(datetime.now() - self.session_start),
            "environment": "HuggingFace Spaces"
        }
# Global singleton instance used by the module-level entry-point functions below.
annotation_manager = HFSpacesAnnotationManager()
def get_user_annotation_filename(username: str) -> str:
    """Build the per-user annotation file name, sanitizing unsafe characters."""
    # Replace characters that are invalid in file names with underscores.
    sanitized = re.sub(r'[\\/*?:"<>|]', "_", username)
    return "annotation_results_" + sanitized + ".json"
def save_annotations(username_state, annotation_results_state, tasks):
    """Entry point: delegate saving of annotation results to the global manager."""
    manager = annotation_manager
    return manager.save_annotations(username_state, annotation_results_state, tasks)
def load_annotations(username):
    """Entry point: delegate loading of a user's annotations to the global manager."""
    manager = annotation_manager
    return manager.load_annotations(username)
def force_upload_all():
    """Admin helper: flush every cached user's data to the Hub immediately."""
    # Reuse the exit handler, which uploads all cached annotation results.
    annotation_manager.save_all_on_exit()
    return "强制上传完成"
def get_cache_stats():
    """Debug helper: report cache statistics from the global manager."""
    stats = annotation_manager.get_cache_stats()
    return stats
def get_aggregated_filename() -> str:
    """Return the fixed repo path of the aggregated-results JSON file."""
    return "aggregated_annotations.json"
def update_aggregated_annotations(tasks):
    """Rebuild the aggregated annotation file and upload it to the Hub.

    Collects every user's annotation file from the dataset repo, aggregates
    the votes per task, and uploads the result as one JSON file.

    Args:
        tasks: Full task list, forwarded to the aggregation step.

    Returns:
        A human-readable status string (success with the uploaded filename,
        or failure with the error).
    """
    try:
        all_annotations = collect_all_annotations()
        aggregated_data = build_aggregated_results(all_annotations, tasks)
        save_str = json.dumps(aggregated_data, ensure_ascii=False, indent=2)
        filename = get_aggregated_filename()
        api.upload_file(
            path_or_fileobj=save_str.encode("utf-8"),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        # Fix: the success message contained a broken "(unknown)" placeholder
        # instead of the uploaded file name.
        return f"✅ 聚合结果已更新: {filename}"
    except Exception as e:
        return f"❌ 聚合结果更新失败: {str(e)}"
def collect_all_annotations():
    """Download and merge every per-user annotation file from the dataset repo.

    Scans the repo for files named ``annotation_results_*.json`` and
    concatenates their ``"annotations"`` lists. Individual file failures are
    logged and skipped.

    Returns:
        A flat list of annotation dicts; empty list on any top-level failure.
    """
    try:
        files_info = api.list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN)
        all_annotations = []
        for filename in files_info:
            if filename.startswith("annotation_results_") and filename.endswith(".json"):
                try:
                    local_path = hf_hub_download(
                        repo_id=REPO_ID,
                        filename=filename,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        force_download=True
                    )
                    with open(local_path, "r", encoding="utf-8") as f:
                        user_data = json.load(f)
                    all_annotations.extend(user_data.get("annotations", []))
                except Exception as e:
                    # Fix: the log message contained a broken "(unknown)"
                    # placeholder instead of the failing file name.
                    print(f"加载文件 {filename} 失败: {e}")
                    continue
        return all_annotations
    except Exception as e:
        print(f"收集标注失败: {e}")
        return []
def build_aggregated_results(all_annotations, tasks):
    """Aggregate raw per-user annotations into per-task vote tallies.

    Groups annotations by ``original_index`` and, within each group, tallies
    win/tie/lose votes per comparison pair plus the distinct annotators who
    voted. Entries missing an index, comparison, or choice are skipped.

    Args:
        all_annotations: Flat list of annotation dicts as stored on the Hub.
        tasks: Full task list (not used by the aggregation itself; kept for
            interface stability).

    Returns:
        Dict with "total_groups", "total_annotations" and a "results" list
        sorted by original index.
    """
    groups = defaultdict(lambda: {
        "text": "",
        "instruction": "",
        "comparisons": defaultdict(lambda: {"win": 0, "tie": 0, "lose": 0, "annotators": []})
    })
    for entry in all_annotations:
        idx = entry.get("original_index")
        pair = entry.get("comparison")
        vote = entry.get("choice")
        voter = entry.get("username")
        # Skip malformed entries lacking an index, comparison, or choice.
        if idx is None or not pair or not vote:
            continue
        bucket = groups[idx]
        # Last writer wins for the shared text/instruction fields.
        bucket["text"] = entry.get("text", "")
        bucket["instruction"] = entry.get("instruction", "")
        if vote in ("win", "tie", "lose"):
            tally = bucket["comparisons"][pair]
            tally[vote] += 1
            if voter not in tally["annotators"]:
                tally["annotators"].append(voter)
    aggregated_results = [
        {
            "original_index": idx,
            "text": data["text"],
            "instruction": data["instruction"],
            "comparisons": {
                pair: {
                    "votes(win tie lose)": [tally["win"], tally["tie"], tally["lose"]],
                    "total_annotators": len(tally["annotators"]),
                    "annotators": tally["annotators"],
                }
                for pair, tally in data["comparisons"].items()
            },
        }
        for idx, data in groups.items()
    ]
    aggregated_results.sort(key=lambda item: item["original_index"])
    return {
        "total_groups": len(aggregated_results),
        "total_annotations": len(all_annotations),
        "results": aggregated_results
    }
def load_aggregated_annotations():
    """Fetch and parse the current aggregated-results file from the Hub.

    Returns:
        The parsed JSON dict, or an empty skeleton when the file is missing
        or cannot be downloaded/parsed.
    """
    try:
        path = hf_hub_download(
            repo_id=REPO_ID,
            filename=get_aggregated_filename(),
            repo_type="dataset",
            token=HF_TOKEN,
            force_download=True
        )
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        # Best effort: fall back to an empty aggregate on any failure.
        return {"total_groups": 0, "total_annotations": 0, "results": []}
def get_aggregated_stats():
    """Summarize aggregated results into per-comparison vote totals.

    Returns:
        Dict with "total_groups", "total_annotations", and a
        "comparison_summary" mapping each comparison pair to accumulated
        win/tie/lose counts and total votes; ``{"error": ...}`` on failure.
    """
    try:
        aggregated_data = load_aggregated_annotations()
        stats = {
            "total_groups": aggregated_data.get("total_groups", 0),
            "total_annotations": aggregated_data.get("total_annotations", 0),
            "comparison_summary": {}
        }
        summary = stats["comparison_summary"]
        for result in aggregated_data.get("results", []):
            for comparison, data in result.get("comparisons", {}).items():
                if comparison not in summary:
                    summary[comparison] = {
                        "total_votes": 0,
                        "win": 0, "tie": 0, "lose": 0
                    }
                # Fix: build_aggregated_results stores the tally under the
                # key "votes(win tie lose)"; the old lookup of "votes" never
                # matched, so every summary was zero. Keep "votes" as a
                # fallback for any older aggregate files.
                votes = data.get("votes(win tie lose)", data.get("votes", [0, 0, 0]))
                summary[comparison]["win"] += votes[0]
                summary[comparison]["tie"] += votes[1]
                summary[comparison]["lose"] += votes[2]
                summary[comparison]["total_votes"] += sum(votes)
        return stats
    except Exception as e:
        return {"error": str(e)}