ATLAS / src /submission /submit.py
“pangjh3”
modified: src/about.py
f652754
#!/usr/bin/env python3
"""
ATLAS submission handling - OSS mode.
Uses Alibaba Cloud OSS in place of the git/HTTP submission methods.
"""
import html
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Tuple
# Optional dependency: the OSS submission handler. OSS_AVAILABLE records
# whether the import succeeded so the rest of the module can branch on it.
try:
    from src.oss.oss_submission_handler import OSSSubmissionHandler
except ImportError as e:
    print(f"⚠️ OSS module not available, using fallback mode: {e}")
    OSS_AVAILABLE = False
else:
    OSS_AVAILABLE = True
def format_error(msg):
    """Wrap *msg* in a red, 16px HTML paragraph and return the markup."""
    return "<p style='color: red; font-size: 16px;'>{}</p>".format(msg)
def format_success(msg):
    """Wrap *msg* in a green, 16px HTML paragraph and return the markup."""
    return "<p style='color: green; font-size: 16px;'>{}</p>".format(msg)
def format_warning(msg):
    """Wrap *msg* in an orange, 16px HTML paragraph and return the markup."""
    return "<p style='color: orange; font-size: 16px;'>{}</p>".format(msg)
def validate_sage_submission(submission_data: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate the structure of an ATLAS benchmark submission.

    The checks run in this order and stop at the first failure:

    1. ``submission_data`` is a dict containing ``submission_org``,
       ``submission_email`` and ``predictions``.
    2. ``submission_org`` is a string and ``submission_email`` is a string
       containing both ``"@"`` and ``"."`` (a basic sanity check only).
    3. ``predictions`` is a non-empty list of dicts, each with
       ``original_question_id`` (an integer), ``content`` (a list of exactly
       4 items) and ``reasoning_content`` (a list of any length).

    Args:
        submission_data: The parsed submission (normally decoded JSON).

    Returns:
        ``(True, "Submission format is valid")`` on success, otherwise
        ``(False, reason)`` describing the first problem found. The function
        never raises for malformed input.
    """
    # Decoded JSON may be a list, number, string, ... rather than an object.
    if not isinstance(submission_data, dict):
        return False, "Submission must be a JSON object"

    # Required top-level fields.
    for field in ("submission_org", "submission_email", "predictions"):
        if field not in submission_data:
            return False, f"Missing required field: {field}"

    if not isinstance(submission_data["submission_org"], str):
        return False, "submission_org must be a string"

    # Basic email check; the type test guards against null/number values,
    # which would otherwise raise TypeError on the ``in`` tests.
    email = submission_data["submission_email"]
    if not isinstance(email, str) or "@" not in email or "." not in email:
        return False, "Invalid email format"

    predictions = submission_data["predictions"]
    if not isinstance(predictions, list) or not predictions:
        return False, "predictions must be a non-empty list"

    pred_required_fields = ("original_question_id", "content", "reasoning_content")
    for i, prediction in enumerate(predictions):
        # A string entry would turn ``field not in prediction`` into a
        # substring test and a number would raise, so require a dict first.
        if not isinstance(prediction, dict):
            return False, f"prediction {i} must be an object"

        for field in pred_required_fields:
            if field not in prediction:
                return False, f"Missing field in prediction {i}: {field}"

        content = prediction["content"]
        if not isinstance(content, list) or len(content) != 4:
            return False, f"content in prediction {i} must be a list with 4 items"

        # The length of reasoning_content is intentionally not checked.
        if not isinstance(prediction["reasoning_content"], list):
            return False, f"reasoning_content in prediction {i} must be a list"

        # bool is a subclass of int, so JSON true/false must be excluded
        # explicitly.
        question_id = prediction["original_question_id"]
        if isinstance(question_id, bool) or not isinstance(question_id, int):
            return False, f"question ID in prediction {i} must be an integer"

    return True, "Submission format is valid"
def save_submission_file(submission_data: Dict[str, Any], submissions_dir: str = "./submissions") -> str:
    """Write a submission to ``submissions_dir`` as pretty-printed JSON.

    The file is named ``submission_<model>_<org>_<YYYYmmdd_HHMMSS>.json``.
    Spaces, ``/`` and ``\\`` in both names (and additionally ``-`` in the
    model name) are replaced by ``_`` so that neither name can introduce a
    path separator.

    Args:
        submission_data: Submission to store. ``submission_org`` is required;
            ``model_name`` is optional and defaults to ``"UnknownModel"``
            when missing or ``None``.
        submissions_dir: Target directory; created if it does not exist.

    Returns:
        The path of the written file (``submissions_dir`` joined with the
        generated file name).

    Raises:
        KeyError: If ``submission_org`` is missing.
        OSError: If the directory or file cannot be created.
    """
    os.makedirs(submissions_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def _to_filename_part(value: Any, replaced: str) -> str:
        # str() tolerates non-string values (e.g. numbers from JSON) that
        # would otherwise fail on ``.replace``.
        return str(value).translate(str.maketrans({ch: "_" for ch in replaced}))

    # A JSON ``null`` model name is treated like a missing one.
    model_name = submission_data.get("model_name")
    if model_name is None:
        model_name = "UnknownModel"
    model_part = _to_filename_part(model_name, " /\\-")
    org_part = _to_filename_part(submission_data["submission_org"], " /\\")

    # Format: submission_<model>_<org>_<timestamp>.json
    filename = f"submission_{model_part}_{org_part}_{timestamp}.json"
    file_path = os.path.join(submissions_dir, filename)

    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(submission_data, f, indent=2, ensure_ascii=False)

    return file_path
def process_sage_submission_simple(submission_file, model_name=None, org_name=None, email=None) -> str:
    """Validate an uploaded ATLAS submission and store it (collection mode).

    This function only validates and saves the submission; it performs no
    evaluation. Steps:

    1. Read ``submission_file`` (a file path) and parse it as JSON.
    2. Apply form overrides: ``model_name`` on its own; ``org_name`` and
       ``email`` only when both are provided.
    3. Validate with ``validate_sage_submission``.
    4. Save a local copy with ``save_submission_file``.
    5. If ``OSS_AVAILABLE``, hand the data to
       ``OSSSubmissionHandler.process_sage_submission`` and return its result
       when that result contains "successful" (case-insensitive). Any other
       result or an exception falls through to the local-only success
       message.

    Args:
        submission_file: Path of the uploaded JSON file, or ``None`` when
            nothing was uploaded.
        model_name: Optional model name from the form.
        org_name: Optional organization name from the form.
        email: Optional contact email from the form.

    Returns:
        An HTML snippet (success or error paragraph). Errors are reported in
        the returned markup; the function does not raise.
    """
    try:
        if submission_file is None:
            return format_error("❌ No file uploaded. Please select a JSON file.")

        # submission_file is used as a path string.
        try:
            with open(submission_file, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            return format_error(f"❌ Error reading file: {html.escape(str(e))}")

        try:
            submission_data = json.loads(content)
        except json.JSONDecodeError as e:
            return format_error(f"❌ Invalid JSON format: {html.escape(str(e))}")

        # Form values take precedence over what the file contains.
        # Organization and email are only overridden together.
        if model_name:
            submission_data["model_name"] = model_name.strip()
        if org_name and email:
            submission_data["submission_org"] = org_name.strip()
            submission_data["submission_email"] = email.strip()

        is_valid, message = validate_sage_submission(submission_data)
        if not is_valid:
            return format_error(f"❌ Submission validation failed: {html.escape(str(message))}")

        try:
            saved_path = save_submission_file(submission_data)
            print(f"✅ Submission file saved to: {saved_path}")

            if OSS_AVAILABLE:
                try:
                    oss_handler = OSSSubmissionHandler()
                    result = oss_handler.process_sage_submission(submission_data, org_name, email)

                    # A case-insensitive match on "successful" also covers
                    # the literal "Submission successful".
                    if "successful" in result.lower():
                        return result
                    # OSS reported a failure: continue with the fallback.
                    print(f"⚠️ OSS submission failed, using fallback mode: {result}")
                except Exception as e:
                    print(f"⚠️ OSS submission exception, using fallback mode: {e}")

            # Fallback mode: the local copy is the submission of record.
            # Everything interpolated below comes from the upload or the
            # form, so escape it before embedding it in HTML.
            filename = html.escape(os.path.basename(saved_path))
            org = html.escape(str(submission_data["submission_org"]))
            email_addr = html.escape(str(submission_data["submission_email"]))
            num_predictions = len(submission_data["predictions"])

            success_msg = format_success(f"""
            🎉 <strong>Submission successful!</strong><br><br>
            📋 <strong>Submission Information:</strong><br>
            • Organization: {org}<br>
            • Email: {email_addr}<br>
            • Number of predictions: {num_predictions} questions<br>
            • Filename: {filename}<br><br>
            🧪 Thank you for participating in the ATLAS scientific reasoning benchmark!
            """)
            return success_msg

        except Exception as e:
            return format_error(f"❌ Error saving submission file: {html.escape(str(e))}")

    except Exception as e:
        return format_error(f"❌ Submission processing failed: {html.escape(str(e))}")
def get_submission_stats(submissions_dir: str = "./submissions") -> Dict[str, Any]:
    """Summarize the submission files stored in ``submissions_dir``.

    Only files named ``submission_*.json`` are considered. The submission
    time is taken from the trailing ``YYYYmmdd_HHMMSS`` part of the file
    name; files that cannot be read or parsed are skipped.

    Args:
        submissions_dir: Directory holding saved submission files.

    Returns:
        ``{"total": <count>, "recent": [...]}`` where ``recent`` holds at
        most 10 entries, newest first. Each entry has the keys ``org``,
        ``email``, ``time`` (``"YYYY-mm-dd HH:MM"``, or the last ``_``
        separated part of the file name when no timestamp can be parsed)
        and ``predictions`` (number of predictions). A missing directory
        yields ``{"total": 0, "recent": []}``.
    """
    if not os.path.isdir(submissions_dir):
        return {"total": 0, "recent": []}

    timestamp_format = "%Y%m%d_%H%M%S"
    timestamp_length = len("YYYYmmdd_HHMMSS")
    suffix = ".json"

    # (sort key, record) pairs; sorting on the parsed datetime avoids
    # comparing display strings.
    entries = []
    for filename in os.listdir(submissions_dir):
        if not (filename.startswith("submission_") and filename.endswith(suffix)):
            continue

        file_path = os.path.join(submissions_dir, filename)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # The timestamp itself contains an underscore, so it cannot be
            # recovered by taking the last "_"-separated part; use the
            # fixed-width tail of the name instead.
            stem = filename[:-len(suffix)]
            try:
                timestamp = datetime.strptime(stem[-timestamp_length:], timestamp_format)
            except ValueError:
                # No parsable timestamp: show the raw tail, sort it last.
                sort_key = datetime.min
                formatted_time = filename.split("_")[-1].replace(".json", "")
            else:
                sort_key = timestamp
                formatted_time = timestamp.strftime("%Y-%m-%d %H:%M")

            entries.append((sort_key, {
                "org": data.get("submission_org", "Unknown"),
                "email": data.get("submission_email", ""),
                "time": formatted_time,
                "predictions": len(data.get("predictions", [])),
            }))
        except Exception:
            # Best effort: unreadable or malformed files are left out of
            # the statistics rather than failing the whole listing.
            continue

    # Newest first.
    entries.sort(key=lambda entry: entry[0], reverse=True)
    submissions = [record for _, record in entries]

    return {
        "total": len(submissions),
        "recent": submissions[:10],  # the 10 most recent
    }
# The former HTTP push function has been removed; submissions now use OSS mode.