# AudioLabelingApp / annotation.py
# (file-page header residue preserved as comments: "sunnyzjx's picture",
#  "Update annotation.py", commit "adf842d verified")
import re
import os
import json
import time
import atexit
from datetime import datetime, timedelta
from huggingface_hub import HfApi, hf_hub_download
from collections import defaultdict
import config
# Hugging Face access token read from the environment (may be None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Target dataset repository where annotation files are uploaded.
REPO_ID = config.SAVE_REPO_ID
# Shared Hub API client used for uploads and repo listing.
api = HfApi()
class HFSpacesAnnotationManager:
    """Caches per-user annotation results in memory and periodically
    uploads them to a HuggingFace dataset repository.

    Built for HuggingFace Spaces, where the filesystem is ephemeral: the
    in-memory cache is lost on restart, so results are pushed to the Hub
    every few annotations or after a time threshold, and once more when
    the process exits (via atexit).
    """
    def __init__(self):
        # In-memory cache keyed by username; lost after a restart.
        self.memory_cache = {}
        # Per-user timestamp of the last successful upload.
        self.last_upload_times = {}
        self.session_start = datetime.now()
        # Upload policy tuned for HF Spaces.
        self.upload_frequency = 5  # upload once every 5 annotations
        self.time_threshold = 300  # force an upload after 5 minutes
        self.max_cache_users = 20  # cache at most 20 users
        # Flush all cached data when the process exits.
        atexit.register(self.save_all_on_exit)
        print(f"📍 运行在 HuggingFace Spaces")
        print(f"📋 缓存策略: 每{self.upload_frequency}次标注或{self.time_threshold}秒上传一次")
    def save_annotations(self, username, annotation_results, tasks):
        """Save annotation results; upload to the Hub when a trigger fires.

        Returns a human-readable status string (runtime strings are in
        Chinese, matching the app's UI).
        """
        try:
            # 1. Update the in-memory cache.
            self.update_memory_cache(username, annotation_results, tasks)
            # 2. Check whether an upload is due now.
            should_upload, reason = self.should_upload_now(username, annotation_results)
            if should_upload:
                try:
                    # Perform the upload.
                    upload_result = self.upload_to_hf(username, annotation_results, tasks)
                    self.last_upload_times[username] = datetime.now()
                    return f"✅ 内存已保存 + 已上传HF ({reason})\n{upload_result}"
                except Exception as e:
                    return f"✅ 内存已保存,上传失败 ({reason}): {str(e)}"
            else:
                completed = len(annotation_results)
                total = len(tasks)
                next_upload = self.upload_frequency - (completed % self.upload_frequency)
                return f"✅ 已保存到内存 ({completed}/{total}),还需{next_upload}次标注触发上传"
        except Exception as e:
            return f"❌ 保存失败: {str(e)}"
    def update_memory_cache(self, username, annotation_results, tasks):
        """Refresh this user's cache entry (stores live references)."""
        # Evict stale entries first to keep memory bounded.
        self.cleanup_memory_cache()
        self.memory_cache[username] = {
            'annotation_results': annotation_results,
            'total_tasks': len(tasks),
            'completed_tasks': len(annotation_results),
            'last_updated': datetime.now(),
            'tasks_data': tasks  # keep a reference to tasks for later uploads
        }
    def cleanup_memory_cache(self):
        """Evict the oldest cached users once the cache exceeds its cap."""
        # If too many users are cached, drop the least recently updated.
        if len(self.memory_cache) > self.max_cache_users:
            # Sort by last-updated time, oldest first.
            sorted_users = sorted(
                self.memory_cache.items(),
                key=lambda x: x[1]['last_updated']
            )
            # Entries to evict (enough to get back under the cap).
            oldest_users = sorted_users[:len(self.memory_cache) - self.max_cache_users + 1]
            for username, _ in oldest_users:
                # Best-effort upload before dropping the entry.
                try:
                    cache_data = self.memory_cache[username]
                    self.upload_to_hf(username, cache_data['annotation_results'], cache_data['tasks_data'])
                    print(f"🗑️ 清理缓存时已上传用户 {username} 的数据")
                except Exception as e:
                    print(f"⚠️ 清理缓存时上传失败 {username}: {e}")
                del self.memory_cache[username]
    def should_upload_now(self, username, annotation_results):
        """Return (should_upload, reason) for the current user state."""
        completed_count = len(annotation_results)
        current_time = datetime.now()
        last_upload = self.last_upload_times.get(username, self.session_start)
        # Trigger 1: completed-count reached a multiple of the frequency.
        if completed_count > 0 and completed_count % self.upload_frequency == 0:
            return True, f"完成{self.upload_frequency}次标注"
        # Trigger 2: time threshold exceeded since the last upload.
        if (current_time - last_upload).total_seconds() > self.time_threshold:
            return True, f"超过{self.time_threshold}秒"
        # Trigger 3: this user's very first annotation this session.
        if completed_count == 1 and username not in self.last_upload_times:
            return True, "首次标注"
        return False, "等待条件触发"
    def upload_to_hf(self, username, annotation_results, tasks):
        """Serialize this user's annotations and upload them to the Hub.

        Propagates whatever `HfApi.upload_file` raises on failure.
        """
        save_data = self.prepare_save_data(username, annotation_results, tasks)
        save_str = json.dumps(save_data, ensure_ascii=False, indent=2)
        filename = get_user_annotation_filename(username)
        api.upload_file(
            path_or_fileobj=save_str.encode("utf-8"),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        return f"上传成功: {len(annotation_results)}/{len(tasks)} 项标注"
    def prepare_save_data(self, username, annotation_results, tasks):
        """Build the JSON-serializable payload for one user's annotations."""
        save_data = {
            "total_tasks": len(tasks),
            "completed_tasks": len(annotation_results),
            "username": username,
            "last_updated": datetime.now().isoformat(),
            "environment": "HuggingFace Spaces",
            "annotations": []
        }
        for task_id, choice in annotation_results.items():
            if task_id < len(tasks):  # safety check: skip out-of-range ids
                task = tasks[task_id]
                save_data["annotations"].append({
                    "task_id": task_id,
                    "text": task["text"],
                    "instruction": task["instruction"],
                    "comparison": f"{task['audioA_source']} vs {task['audioB_source']}",
                    "audioA_source": task["audioA_source"],
                    "audioB_source": task["audioB_source"],
                    "original_index": task["original_index"],
                    "choice": choice,
                    "username": username,
                    "timestamp": datetime.now().isoformat()
                })
        return save_data
    def load_annotations(self, username):
        """Load a user's annotations, preferring the in-memory cache."""
        # 1. Check the in-memory cache first.
        if username in self.memory_cache:
            cache_data = self.memory_cache[username]
            print(f"📋 从内存缓存加载用户 {username} 的标注")
            return cache_data['annotation_results']
        # 2. Fall back to downloading the user's file from HuggingFace.
        try:
            filename = get_user_annotation_filename(username)
            local_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=filename,
                repo_type="dataset",
                token=HF_TOKEN,
                force_download=True
            )
            with open(local_path, "r", encoding="utf-8") as f:
                save_data = json.load(f)
            annotation_results = {ann["task_id"]: ann["choice"] for ann in save_data.get("annotations", [])}
            print(f"📥 从HuggingFace加载用户 {username} 的标注")
            return annotation_results
        except Exception as e:
            # Best-effort: a missing file or network error yields an empty result.
            print(f"⚠️ 加载用户 {username} 标注失败: {e}")
            return {}
    def save_all_on_exit(self):
        """Flush every cached user's data to the Hub (atexit hook)."""
        if not self.memory_cache:
            return
        print(f"🔄 应用即将关闭,正在保存 {len(self.memory_cache)} 个用户的缓存数据...")
        success_count = 0
        for username, cache_data in self.memory_cache.items():
            try:
                self.upload_to_hf(
                    username,
                    cache_data['annotation_results'],
                    cache_data['tasks_data']
                )
                success_count += 1
                print(f"✅ 已保存用户 {username} 的数据")
            except Exception as e:
                print(f"❌ 保存用户 {username} 数据失败: {e}")
        print(f"🎯 退出保存完成: {success_count}/{len(self.memory_cache)} 成功")
    def get_cache_stats(self):
        """Return cache statistics for debugging/monitoring."""
        total_annotations = sum(
            len(cache['annotation_results'])
            for cache in self.memory_cache.values()
        )
        return {
            "cached_users": len(self.memory_cache),
            "total_cached_annotations": total_annotations,
            "session_duration": str(datetime.now() - self.session_start),
            "environment": "HuggingFace Spaces"
        }
# Global singleton instance used by the module-level entry points below.
annotation_manager = HFSpacesAnnotationManager()
def get_user_annotation_filename(username: str) -> str:
    """Build the per-user annotation file name, sanitizing the username.

    Characters that are illegal in file names on common platforms are
    replaced by underscores.
    """
    sanitized = re.sub(r'[\\/*?:"<>|]', "_", username)
    return "annotation_results_" + sanitized + ".json"
def save_annotations(username_state, annotation_results_state, tasks):
    """Entry point: delegate saving to the global annotation manager."""
    manager = annotation_manager
    return manager.save_annotations(username_state, annotation_results_state, tasks)
def load_annotations(username):
    """Entry point: delegate loading to the global annotation manager."""
    manager = annotation_manager
    return manager.load_annotations(username)
def force_upload_all():
    """Admin helper: flush every cached user's data to the Hub now.

    Reuses the exit-time flush path on the global manager.
    """
    manager = annotation_manager
    manager.save_all_on_exit()
    return "强制上传完成"
def get_cache_stats():
    """Debug helper: expose the global manager's cache statistics."""
    manager = annotation_manager
    return manager.get_cache_stats()
def get_aggregated_filename() -> str:
    """Return the fixed name of the aggregated-results file in the repo."""
    return "aggregated_annotations.json"
def update_aggregated_annotations(tasks):
    """Rebuild and upload the aggregated annotation-results file.

    Collects every user's annotation file from the dataset repo,
    aggregates votes per comparison, and uploads the merged JSON.
    Returns a (Chinese) status message string; never raises.
    """
    try:
        all_annotations = collect_all_annotations()
        aggregated_data = build_aggregated_results(all_annotations, tasks)
        save_str = json.dumps(aggregated_data, ensure_ascii=False, indent=2)
        filename = get_aggregated_filename()
        api.upload_file(
            path_or_fileobj=save_str.encode("utf-8"),
            path_in_repo=filename,
            repo_id=REPO_ID,
            repo_type="dataset",
            token=HF_TOKEN
        )
        # Fix: the status message previously contained the literal
        # placeholder "(unknown)" instead of the uploaded file name.
        return f"✅ 聚合结果已更新: {filename}"
    except Exception as e:
        return f"❌ 聚合结果更新失败: {str(e)}"
def collect_all_annotations():
    """Download every per-user annotation file and merge their entries.

    Lists the dataset repo, downloads each `annotation_results_*.json`
    file, and concatenates their "annotations" lists. Individual file
    failures are logged and skipped; a listing failure returns [].
    """
    try:
        files_info = api.list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=HF_TOKEN)
        all_annotations = []
        for filename in files_info:
            if filename.startswith("annotation_results_") and filename.endswith(".json"):
                try:
                    local_path = hf_hub_download(
                        repo_id=REPO_ID,
                        filename=filename,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        force_download=True
                    )
                    with open(local_path, "r", encoding="utf-8") as f:
                        user_data = json.load(f)
                    all_annotations.extend(user_data.get("annotations", []))
                except Exception as e:
                    # Fix: the log line previously contained the literal
                    # placeholder "(unknown)" instead of the failing file name.
                    print(f"加载文件 {filename} 失败: {e}")
                    continue
        return all_annotations
    except Exception as e:
        print(f"收集标注失败: {e}")
        return []
def build_aggregated_results(all_annotations, tasks):
    """Aggregate per-user annotations into vote counts per comparison.

    Groups annotations by their original sample index, tallies
    win/tie/lose votes for each comparison pair, and records which
    annotators voted. Entries missing an index, comparison, or choice
    are skipped. Returns overall totals plus per-group results sorted
    by original index. (`tasks` is accepted for interface compatibility
    but not used.)
    """
    groups = defaultdict(lambda: {
        "text": "",
        "instruction": "",
        "comparisons": defaultdict(lambda: {"win": 0, "tie": 0, "lose": 0, "annotators": []})
    })
    for entry in all_annotations:
        idx = entry.get("original_index")
        pair = entry.get("comparison")
        vote = entry.get("choice")
        voter = entry.get("username")
        if idx is None or not pair or not vote:
            continue
        group = groups[idx]
        group["text"] = entry.get("text", "")
        group["instruction"] = entry.get("instruction", "")
        if vote in ("win", "tie", "lose"):
            tally = group["comparisons"][pair]
            tally[vote] += 1
            if voter not in tally["annotators"]:
                tally["annotators"].append(voter)
    aggregated_results = []
    for idx in sorted(groups):
        group = groups[idx]
        aggregated_results.append({
            "original_index": idx,
            "text": group["text"],
            "instruction": group["instruction"],
            "comparisons": {
                pair: {
                    "votes(win tie lose)": [tally["win"], tally["tie"], tally["lose"]],
                    "total_annotators": len(tally["annotators"]),
                    "annotators": tally["annotators"]
                }
                for pair, tally in group["comparisons"].items()
            }
        })
    return {
        "total_groups": len(aggregated_results),
        "total_annotations": len(all_annotations),
        "results": aggregated_results
    }
def load_aggregated_annotations():
    """Download and parse the aggregated-results file from the HF repo.

    Any failure (missing file, network error, bad JSON) falls back to an
    empty aggregate structure.
    """
    try:
        local_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=get_aggregated_filename(),
            repo_type="dataset",
            token=HF_TOKEN,
            force_download=True
        )
        with open(local_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        return {"total_groups": 0, "total_annotations": 0, "results": []}
def get_aggregated_stats():
    """Summarize the aggregated results into per-comparison vote totals.

    Returns a dict with overall group/annotation counts and, for each
    comparison pair, the summed win/tie/lose counts and total votes.
    On any failure returns {"error": message}.
    """
    try:
        aggregated_data = load_aggregated_annotations()
        stats = {
            "total_groups": aggregated_data.get("total_groups", 0),
            "total_annotations": aggregated_data.get("total_annotations", 0),
            "comparison_summary": {}
        }
        summary = stats["comparison_summary"]
        for result in aggregated_data.get("results", []):
            for comparison, data in result.get("comparisons", {}).items():
                if comparison not in summary:
                    summary[comparison] = {
                        "total_votes": 0,
                        "win": 0, "tie": 0, "lose": 0
                    }
                # Bug fix: build_aggregated_results stores the counts under
                # the key "votes(win tie lose)", not "votes"; the old lookup
                # always hit the [0, 0, 0] default, so every summary was zero.
                # A "votes" fallback is kept for older aggregate files.
                votes = data.get("votes(win tie lose)", data.get("votes", [0, 0, 0]))
                summary[comparison]["win"] += votes[0]
                summary[comparison]["tie"] += votes[1]
                summary[comparison]["lose"] += votes[2]
                summary[comparison]["total_votes"] += sum(votes)
        return stats
    except Exception as e:
        return {"error": str(e)}