Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| ATLAS提交处理 - OSS模式 | |
| 使用阿里云OSS替代git/http提交方式 | |
| """ | |
| import json | |
| import os | |
| import sys | |
| from datetime import datetime | |
| from typing import Dict, Any, Tuple | |
| from pathlib import Path | |
| # 导入OSS提交处理器 | |
| try: | |
| from src.oss.oss_submission_handler import OSSSubmissionHandler | |
| OSS_AVAILABLE = True | |
| except ImportError as e: | |
| print(f"⚠️ OSS module not available, using fallback mode: {e}") | |
| OSS_AVAILABLE = False | |
def format_error(msg):
    """Wrap *msg* in a red, 16px HTML paragraph used for error output."""
    template = "<p style='color: red; font-size: 16px;'>{}</p>"
    return template.format(msg)
def format_success(msg):
    """Wrap *msg* in a green, 16px HTML paragraph used for success output."""
    template = "<p style='color: green; font-size: 16px;'>{}</p>"
    return template.format(msg)
def format_warning(msg):
    """Wrap *msg* in an orange, 16px HTML paragraph used for warning output."""
    template = "<p style='color: orange; font-size: 16px;'>{}</p>"
    return template.format(msg)
def validate_sage_submission(submission_data: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate the structure of an ATLAS benchmark submission.

    The checks are purely structural: required top-level fields, a basic
    email sanity check, and the shape of every entry in ``predictions``.
    Malformed input of any type is reported through the return value; the
    function does not raise on wrongly-typed data.

    Args:
        submission_data: The parsed submission (expected to be a JSON object).

    Returns:
        ``(True, "Submission format is valid")`` when every check passes,
        otherwise ``(False, reason)`` describing the first failed check.
    """
    # A JSON document may be a list, string, number, ... -- only objects are valid.
    if not isinstance(submission_data, dict):
        return False, "Submission must be a JSON object"

    # Required top-level fields.
    for field in ("submission_org", "submission_email", "predictions"):
        if field not in submission_data:
            return False, f"Missing required field: {field}"

    # The organization name is later used to build a filename, so it must be text.
    if not isinstance(submission_data["submission_org"], str):
        return False, "submission_org must be a string"

    # Basic email check only: must be text containing both "@" and ".".
    email = submission_data["submission_email"]
    if not isinstance(email, str) or "@" not in email or "." not in email:
        return False, "Invalid email format"

    predictions = submission_data["predictions"]
    if not isinstance(predictions, list) or len(predictions) == 0:
        return False, "predictions must be a non-empty list"

    pred_required_fields = ("original_question_id", "content", "reasoning_content")
    for i, prediction in enumerate(predictions):
        # Guard against non-object entries, which would otherwise raise
        # TypeError (numbers) or be substring-matched (strings) below.
        if not isinstance(prediction, dict):
            return False, f"prediction {i} must be an object"

        for field in pred_required_fields:
            if field not in prediction:
                return False, f"Missing field in prediction {i}: {field}"

        content = prediction["content"]
        if not isinstance(content, list) or len(content) != 4:
            return False, f"content in prediction {i} must be a list with 4 items"

        # Only the type of reasoning_content is checked; its length is
        # intentionally left unconstrained.
        if not isinstance(prediction["reasoning_content"], list):
            return False, f"reasoning_content in prediction {i} must be a list"

        # bool is a subclass of int, so JSON true/false must be rejected explicitly.
        question_id = prediction["original_question_id"]
        if isinstance(question_id, bool) or not isinstance(question_id, int):
            return False, f"question ID in prediction {i} must be an integer"

    return True, "Submission format is valid"
def save_submission_file(submission_data: Dict[str, Any], submissions_dir: str = "./submissions") -> str:
    """Write a submission to ``submissions_dir`` as pretty-printed JSON.

    The file is named ``submission_<model>_<org>_<YYYYmmdd_HHMMSS>.json``.
    Spaces and path separators in the model and organization names are
    replaced with underscores (hyphens too, for the model name only).

    Args:
        submission_data: The submission to persist. Must contain
            ``"submission_org"``; ``"model_name"`` is optional.
        submissions_dir: Target directory, created if it does not exist.

    Returns:
        The path of the file that was written.

    Raises:
        KeyError: If ``"submission_org"`` is missing.
        OSError: If the directory or file cannot be created.
    """
    os.makedirs(submissions_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # ``.get`` only applies its default when the key is absent; an explicit
    # null (or a non-string value) would crash on ``.replace``, so fall back
    # and coerce to text before sanitizing.
    raw_model = submission_data.get("model_name")
    if raw_model is None:
        raw_model = "UnknownModel"
    model_name = str(raw_model).translate(str.maketrans({" ": "_", "/": "_", "\\": "_", "-": "_"}))

    # Hyphens are kept in the organization name.
    org_name = str(submission_data["submission_org"]).translate(
        str.maketrans({" ": "_", "/": "_", "\\": "_"})
    )

    # Format: submission_<model>_<org>_<timestamp>.json
    filename = f"submission_{model_name}_{org_name}_{timestamp}.json"
    file_path = os.path.join(submissions_dir, filename)

    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(submission_data, f, indent=2, ensure_ascii=False)

    return file_path
def process_sage_submission_simple(submission_file, model_name=None, org_name=None, email=None) -> str:
    """Handle an uploaded ATLAS benchmark submission (collection mode).

    The file is read, parsed as JSON, optionally overridden with form
    values, validated and saved locally. No evaluation is performed. When
    the OSS handler is available the submission is also handed to it; if
    that does not report success, the locally saved copy is acknowledged
    instead.

    Args:
        submission_file: Path of the uploaded JSON file, or ``None`` when
            nothing was uploaded.
        model_name: Optional model name from the form; overrides the file.
        org_name: Optional organization from the form. Applied only when
            ``email`` is given as well.
        email: Optional contact email from the form. Applied only when
            ``org_name`` is given as well.

    Returns:
        An HTML snippet describing success or the reason for failure. The
        function does not raise; every error is turned into an error snippet.
    """
    # Local import: only this function needs it, and it keeps the edit
    # self-contained within the block.
    import html

    try:
        if submission_file is None:
            return format_error("❌ No file uploaded. Please select a JSON file.")

        # submission_file is a filesystem path string.
        try:
            with open(submission_file, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            return format_error(f"❌ Error reading file: {html.escape(str(e))}")

        try:
            submission_data = json.loads(content)
        except json.JSONDecodeError as e:
            return format_error(f"❌ Invalid JSON format: {html.escape(str(e))}")

        # Valid JSON is not necessarily an object; the item assignments and
        # lookups below require a dict.
        if not isinstance(submission_data, dict):
            return format_error("❌ Submission validation failed: submission must be a JSON object")

        # Form values, when supplied, take precedence over the file contents.
        if model_name:
            submission_data["model_name"] = model_name.strip()
        if org_name and email:
            submission_data["submission_org"] = org_name.strip()
            submission_data["submission_email"] = email.strip()

        is_valid, message = validate_sage_submission(submission_data)
        if not is_valid:
            return format_error(f"❌ Submission validation failed: {message}")

        try:
            saved_path = save_submission_file(submission_data)
            print(f"✅ Submission file saved to: {saved_path}")

            if OSS_AVAILABLE:
                try:
                    oss_handler = OSSSubmissionHandler()
                    result = oss_handler.process_sage_submission(submission_data, org_name, email)

                    # A case-insensitive match on "successful" also covers
                    # the exact phrase "Submission successful".
                    if "successful" in result.lower():
                        return result
                    # OSS did not report success: fall through to the local acknowledgement.
                    print(f"⚠️ OSS submission failed, using fallback mode: {result}")
                except Exception as e:
                    print(f"⚠️ OSS submission exception, using fallback mode: {e}")

            # Fallback mode: acknowledge the locally saved copy.
            # These values originate from the uploaded file / form, so they
            # are escaped before being embedded in HTML.
            filename = html.escape(os.path.basename(saved_path))
            org = html.escape(str(submission_data["submission_org"]))
            email_addr = html.escape(str(submission_data["submission_email"]))
            num_predictions = len(submission_data["predictions"])

            success_msg = format_success(f"""
            🎉 <strong>Submission successful!</strong><br><br>
            📋 <strong>Submission Information:</strong><br>
            • Organization: {org}<br>
            • Email: {email_addr}<br>
            • Number of predictions: {num_predictions} questions<br>
            • Filename: {filename}<br><br>
            🧪 Thank you for participating in the ATLAS scientific reasoning benchmark!
            """)
            return success_msg

        except Exception as e:
            return format_error(f"❌ Error saving submission file: {html.escape(str(e))}")

    except Exception as e:
        return format_error(f"❌ Submission processing failed: {html.escape(str(e))}")
def get_submission_stats(submissions_dir: str = "./submissions") -> Dict[str, Any]:
    """Summarize the submission files stored in ``submissions_dir``.

    Only files named ``submission_*.json`` are considered. Files that cannot
    be read or parsed are skipped (best effort).

    Args:
        submissions_dir: Directory containing saved submission files.

    Returns:
        A dict with ``"total"`` (number of readable submissions) and
        ``"recent"`` (up to 10 entries, newest first). Each entry has the
        keys ``"org"``, ``"email"``, ``"time"`` and ``"predictions"``.
        ``"time"`` is ``"YYYY-MM-DD HH:MM"`` when the filename ends in a
        ``YYYYmmdd_HHMMSS`` timestamp, otherwise the raw last
        underscore-separated segment of the filename.
    """
    if not os.path.isdir(submissions_dir):
        return {"total": 0, "recent": []}

    keyed_entries = []
    for filename in os.listdir(submissions_dir):
        if not (filename.startswith("submission_") and filename.endswith(".json")):
            continue

        file_path = os.path.join(submissions_dir, filename)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # The timestamp format "%Y%m%d_%H%M%S" itself contains an
            # underscore, so it spans the LAST TWO "_"-separated segments.
            segments = filename[:-len(".json")].split("_")
            timestamp_str = "_".join(segments[-2:])
            try:
                timestamp = datetime.strptime(timestamp_str, "%Y%m%d_%H%M%S")
            except ValueError:
                # Unrecognized name: show the last segment, sort it last.
                timestamp = datetime.min
                formatted_time = segments[-1]
            else:
                formatted_time = timestamp.strftime("%Y-%m-%d %H:%M")

            entry = {
                "org": data.get("submission_org", "Unknown"),
                "email": data.get("submission_email", ""),
                "time": formatted_time,
                "predictions": len(data.get("predictions", [])),
            }
        except Exception:
            # Deliberate best effort: one bad file must not break the stats.
            continue

        keyed_entries.append((timestamp, entry))

    # Newest first, ordered by the full parsed timestamp rather than by the
    # minute-resolution display string.
    keyed_entries.sort(key=lambda pair: pair[0], reverse=True)
    submissions = [entry for _, entry in keyed_entries]

    return {
        "total": len(submissions),
        "recent": submissions[:10],  # ten most recent
    }
| # 移除了原有的HTTP推送函数,现在使用OSS模式 | |