| """数据库 CRUD 操作接口""" |
|
|
| from datetime import datetime |
| from typing import List, Optional |
|
|
| from sqlalchemy import and_, select, update |
| from sqlalchemy.orm import Session |
|
|
| from qa_annotate.database.models import ( |
| AnnotationConfigModel, |
| AnnotationResultModel, |
| DatasetModel, |
| LlmAnalysisCacheModel, |
| ProjectModel, |
| QAPairModel, |
| QuestionTypeModel, |
| SeedQuestionModel, |
| SystemConfigModel, |
| UserModel, |
| ) |
| from qa_annotate.schema.annotation import AnnotationConfig, AnnotationResult |
| from qa_annotate.schema.dataset import Dataset, QAPair |
| from qa_annotate.schema.project import Project |
| from qa_annotate.schema.question_type import ( |
| QuestionType, |
| QuestionTypeCreate, |
| QuestionTypeUpdate, |
| ) |
| from qa_annotate.schema.seed_question import ( |
| SeedQuestion, |
| SeedQuestionCreate, |
| SeedQuestionUpdate, |
| SeedQuestionWithCreator, |
| ) |
| from qa_annotate.schema.system_config import SystemConfig, SystemConfigUpdate |
| from qa_annotate.schema.user import User, UserCreate, UserUpdate |
|
|
| |
|
|
|
|
class AnnotationConfigCRUD:
    """CRUD operations for annotation configurations (``AnnotationConfigModel``).

    Every method is static and takes an explicit SQLAlchemy ``Session``; the
    write operations commit the session before returning.
    """

    @staticmethod
    def _find_model(db: Session, config_id: int):
        """Return the ``AnnotationConfigModel`` row with id ``config_id``, or ``None``."""
        return (
            db.query(AnnotationConfigModel)
            .filter(AnnotationConfigModel.id == config_id)
            .first()
        )

    @staticmethod
    def create(db: Session, config: AnnotationConfig) -> AnnotationConfig:
        """Insert ``config`` and return the stored record as a schema object."""
        row = AnnotationConfigModel.from_pydantic(config)
        db.add(row)
        db.commit()
        # Re-read the committed row so database-generated values (e.g. the id)
        # are present on the returned object.
        db.refresh(row)
        return row.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, config_id: int) -> Optional[AnnotationConfig]:
        """Return the annotation config with id ``config_id``, or ``None`` if absent."""
        row = AnnotationConfigCRUD._find_model(db, config_id)
        if not row:
            return None
        return row.to_pydantic()

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        annotation_type: Optional[str] = None,
    ) -> List[AnnotationConfig]:
        """Return one page of annotation configs.

        Args:
            db: Database session.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            annotation_type: When truthy, only configs of this type are returned.

        Returns:
            The matching configs as schema objects.
        """
        query = db.query(AnnotationConfigModel)
        if annotation_type:
            type_matches = AnnotationConfigModel.annotation_type == annotation_type
            query = query.filter(type_matches)
        page = query.offset(skip).limit(limit).all()
        return [row.to_pydantic() for row in page]

    @staticmethod
    def update(
        db: Session, config_id: int, config: AnnotationConfig
    ) -> Optional[AnnotationConfig]:
        """Overwrite every editable field of config ``config_id`` with ``config``.

        Args:
            db: Database session.
            config_id: Id of the config to update.
            config: New values; all listed fields are replaced.

        Returns:
            The updated config, or ``None`` if no config has that id.
        """
        row = AnnotationConfigCRUD._find_model(db, config_id)
        if not row:
            return None

        # ``annotation_type`` may arrive as a plain string or as an enum-like
        # member; the column is assigned the plain value either way.
        kind = config.annotation_type
        kind_value = kind if isinstance(kind, str) else kind.value

        row.name = config.name
        row.description = config.description
        row.required = config.required
        row.show_reason = config.show_reason
        row.show_confidence = config.show_confidence
        row.annotation_type = kind_value
        row.config_json = config.config.model_dump()
        row.custom_fields_json = config.custom_fields
        row.updated_at = datetime.now()

        db.commit()
        db.refresh(row)
        return row.to_pydantic()

    @staticmethod
    def delete(db: Session, config_id: int) -> bool:
        """Hard-delete config ``config_id``.

        Returns:
            ``True`` if a row was deleted, ``False`` if no config had that id.
        """
        row = AnnotationConfigCRUD._find_model(db, config_id)
        if not row:
            return False
        db.delete(row)
        db.commit()
        return True

    @staticmethod
    def count(db: Session) -> int:
        """Return the total number of annotation configs."""
        total = db.query(AnnotationConfigModel).count()
        return total
|
|
|
|
| |
|
|
|
|
class AnnotationResultCRUD:
    """CRUD operations for annotation results (``AnnotationResultModel``).

    Every method is static and takes an explicit SQLAlchemy ``Session``; the
    write operations commit the session before returning.
    """

    @staticmethod
    def create(db: Session, result: AnnotationResult) -> AnnotationResult:
        """Insert ``result`` and return the stored record as a schema object."""
        db_model = AnnotationResultModel.from_pydantic(result)
        db.add(db_model)
        db.commit()
        # Re-read the committed row so database-generated values (e.g. the id)
        # are present on the returned object.
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, result_id: int) -> Optional[AnnotationResult]:
        """Return the annotation result with id ``result_id``, or ``None``."""
        db_model = (
            db.query(AnnotationResultModel)
            .filter(AnnotationResultModel.id == result_id)
            .first()
        )
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        dataset_id: Optional[int] = None,
        dataset_item_id: Optional[int] = None,
        annotation_config_id: Optional[int] = None,
        annotator_id: Optional[int] = None,
    ) -> List[AnnotationResult]:
        """Return one page of annotation results, optionally filtered.

        Each id filter is applied whenever its argument is not ``None``. An id
        of ``0`` is therefore a real filter rather than "no filter".

        Args:
            db: Database session.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            dataset_id: Restrict to this dataset.
            dataset_item_id: Restrict to this dataset item.
            annotation_config_id: Restrict to this annotation config.
            annotator_id: Restrict to this annotator.

        Returns:
            The matching results as schema objects.
        """
        query = db.query(AnnotationResultModel)
        optional_filters = (
            (AnnotationResultModel.dataset_id, dataset_id),
            (AnnotationResultModel.dataset_item_id, dataset_item_id),
            (AnnotationResultModel.annotation_config_id, annotation_config_id),
            (AnnotationResultModel.annotator_id, annotator_id),
        )
        for column, value in optional_filters:
            if value is not None:
                query = query.filter(column == value)

        results = query.offset(skip).limit(limit).all()
        return [model.to_pydantic() for model in results]

    @staticmethod
    def get_by_dataset_item(
        db: Session,
        dataset_id: int,
        dataset_item_id: int,
    ) -> List[AnnotationResult]:
        """Return every annotation result for one item of one dataset."""
        db_models = (
            db.query(AnnotationResultModel)
            .filter(
                and_(
                    AnnotationResultModel.dataset_id == dataset_id,
                    AnnotationResultModel.dataset_item_id == dataset_item_id,
                )
            )
            .all()
        )
        return [model.to_pydantic() for model in db_models]

    @staticmethod
    def update(
        db: Session, result_id: int, result: AnnotationResult
    ) -> Optional[AnnotationResult]:
        """Overwrite every editable field of result ``result_id`` with ``result``.

        Returns:
            The updated result, or ``None`` if no result has that id.
        """
        db_model = (
            db.query(AnnotationResultModel)
            .filter(AnnotationResultModel.id == result_id)
            .first()
        )
        if not db_model:
            return None

        db_model.dataset_id = result.dataset_id
        db_model.dataset_item_id = result.dataset_item_id
        db_model.annotation_config_id = result.annotation_config_id
        # None-valued keys of the annotation value are dropped before storage.
        db_model.value_json = result.value.model_dump(exclude_none=True)
        db_model.annotator_id = result.annotator_id
        db_model.annotator_name = result.annotator_name
        db_model.duration_seconds = result.duration_seconds
        db_model.confidence = result.confidence
        db_model.notes = result.notes
        db_model.custom_fields_json = result.custom_fields
        db_model.updated_at = datetime.now()

        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def delete(db: Session, result_id: int) -> bool:
        """Delete result ``result_id``; return ``False`` if it did not exist."""
        db_model = (
            db.query(AnnotationResultModel)
            .filter(AnnotationResultModel.id == result_id)
            .first()
        )
        if not db_model:
            return False

        db.delete(db_model)
        db.commit()
        return True

    @staticmethod
    def delete_by_dataset_item(
        db: Session,
        dataset_id: int,
        dataset_item_id: int,
    ) -> int:
        """Delete all results for one dataset item; return how many were deleted.

        This is a bulk ``Query.delete()``, so ORM-level cascades and per-object
        delete events are not run for the removed rows.
        """
        count = (
            db.query(AnnotationResultModel)
            .filter(
                and_(
                    AnnotationResultModel.dataset_id == dataset_id,
                    AnnotationResultModel.dataset_item_id == dataset_item_id,
                )
            )
            .delete()
        )
        db.commit()
        return count

    @staticmethod
    def delete_by_config(
        db: Session,
        annotation_config_id: int,
    ) -> int:
        """Delete all results for one annotation config; return how many were deleted.

        This is a bulk ``Query.delete()``, so ORM-level cascades and per-object
        delete events are not run for the removed rows.
        """
        count = (
            db.query(AnnotationResultModel)
            .filter(AnnotationResultModel.annotation_config_id == annotation_config_id)
            .delete()
        )
        db.commit()
        return count

    @staticmethod
    def count(
        db: Session,
        dataset_id: Optional[int] = None,
        annotation_config_id: Optional[int] = None,
    ) -> int:
        """Return the number of annotation results, optionally filtered.

        As in ``get_all``, a filter is applied whenever its argument is not
        ``None``.
        """
        query = db.query(AnnotationResultModel)
        if dataset_id is not None:
            query = query.filter(AnnotationResultModel.dataset_id == dataset_id)
        if annotation_config_id is not None:
            query = query.filter(
                AnnotationResultModel.annotation_config_id == annotation_config_id
            )
        return query.count()
|
|
|
|
| |
|
|
|
|
class DatasetCRUD:
    """CRUD operations for datasets (``DatasetModel``), including task claiming."""

    @staticmethod
    def _get_model(db: Session, dataset_id: int):
        """Return the ``DatasetModel`` row with id ``dataset_id``, or ``None``."""
        return db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first()

    @staticmethod
    def _reload(db: Session, dataset_id: int) -> Optional[Dataset]:
        """Re-read dataset ``dataset_id`` after a bulk UPDATE and convert it.

        The explicit ``refresh`` ensures the returned values come from the
        database rather than from a possibly stale in-session copy.
        """
        db_model = DatasetCRUD._get_model(db, dataset_id)
        if not db_model:
            return None
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def create(db: Session, dataset: Dataset) -> Dataset:
        """Insert ``dataset`` and return the stored record as a schema object."""
        db_model = DatasetModel.from_pydantic(dataset)
        db.add(db_model)
        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, dataset_id: int) -> Optional[Dataset]:
        """Return the dataset with id ``dataset_id``, or ``None``."""
        db_model = DatasetCRUD._get_model(db, dataset_id)
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_by_name(db: Session, name: str) -> Optional[Dataset]:
        """Return the first dataset whose name equals ``name``, or ``None``."""
        db_model = db.query(DatasetModel).filter(DatasetModel.name == name).first()
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        name_search: Optional[str] = None,
    ) -> List[Dataset]:
        """Return one page of datasets, optionally filtered by a name substring.

        Args:
            db: Database session.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            name_search: When truthy, only datasets whose name contains this
                text literally are returned.

        Returns:
            The matching datasets as schema objects.
        """
        query = db.query(DatasetModel)
        if name_search:
            # autoescape=True makes "%" and "_" in the user's text match
            # literally instead of acting as LIKE wildcards.
            query = query.filter(
                DatasetModel.name.contains(name_search, autoescape=True)
            )

        results = query.offset(skip).limit(limit).all()
        return [model.to_pydantic() for model in results]

    @staticmethod
    def update(db: Session, dataset_id: int, dataset: Dataset) -> Optional[Dataset]:
        """Overwrite every editable field of dataset ``dataset_id`` with ``dataset``.

        Returns:
            The updated dataset, or ``None`` if no dataset has that id.
        """
        db_model = DatasetCRUD._get_model(db, dataset_id)
        if not db_model:
            return None

        db_model.name = dataset.name
        db_model.description = dataset.description
        db_model.version = dataset.version
        db_model.status = dataset.status
        db_model.tags_json = dataset.tags
        db_model.category = dataset.category
        db_model.creator = dataset.creator
        db_model.creator_id = dataset.creator_id
        db_model.annotator_id = dataset.annotator_id
        db_model.annotator_name = dataset.annotator_name
        db_model.source = dataset.source
        db_model.source_url = dataset.source_url
        db_model.metadata_json = dataset.metadata
        db_model.display_extra_fields_json = dataset.display_extra_fields
        db_model.updated_at = datetime.now()

        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def delete(db: Session, dataset_id: int) -> bool:
        """Delete dataset ``dataset_id``; return ``False`` if it did not exist.

        NOTE(review): removal of the dataset's QA pairs is expected to come
        from a cascade configured on ``DatasetModel`` (not visible here); confirm.
        """
        db_model = DatasetCRUD._get_model(db, dataset_id)
        if not db_model:
            return False

        db.delete(db_model)
        db.commit()
        return True

    @staticmethod
    def claim_dataset(
        db: Session,
        dataset_id: int,
        annotator_id: int,
        annotator_name: str,
    ) -> Optional[Dataset]:
        """Assign dataset ``dataset_id`` to an annotator if it is still unclaimed.

        The "is it unclaimed?" check and the write happen in one UPDATE
        statement (``WHERE id = :id AND annotator_id IS NULL``), so two
        concurrent claims cannot both match the row.

        Args:
            db: Database session.
            dataset_id: Dataset to claim.
            annotator_id: User id of the claiming annotator.
            annotator_name: Username of the claiming annotator.

        Returns:
            The updated dataset, or ``None`` if it does not exist or is
            already claimed.
        """
        stmt = (
            update(DatasetModel)
            .where(
                and_(
                    DatasetModel.id == dataset_id,
                    DatasetModel.annotator_id.is_(None),
                )
            )
            .values(
                annotator_id=annotator_id,
                annotator_name=annotator_name,
                updated_at=datetime.now(),
            )
        )
        result = db.execute(stmt)
        db.commit()

        # No matched row: the dataset is missing or someone else holds it.
        if result.rowcount == 0:
            return None
        return DatasetCRUD._reload(db, dataset_id)

    @staticmethod
    def release_dataset(
        db: Session,
        dataset_id: int,
        annotator_id: int,
    ) -> Optional[Dataset]:
        """Unassign dataset ``dataset_id`` if it is held by ``annotator_id``.

        The ownership check and the write happen in one UPDATE statement
        (``WHERE id = :id AND annotator_id = :annotator_id``).

        Args:
            db: Database session.
            dataset_id: Dataset to release.
            annotator_id: User id that must currently hold the dataset.

        Returns:
            The updated dataset, or ``None`` if it does not exist or is not
            held by that user.
        """
        stmt = (
            update(DatasetModel)
            .where(
                and_(
                    DatasetModel.id == dataset_id,
                    DatasetModel.annotator_id == annotator_id,
                )
            )
            .values(
                annotator_id=None,
                annotator_name=None,
                updated_at=datetime.now(),
            )
        )
        result = db.execute(stmt)
        db.commit()

        # No matched row: the dataset is missing or not owned by this user.
        if result.rowcount == 0:
            return None
        return DatasetCRUD._reload(db, dataset_id)

    @staticmethod
    def count(db: Session) -> int:
        """Return the total number of datasets."""
        return db.query(DatasetModel).count()

    @staticmethod
    def get_items_count(db: Session, dataset_id: int) -> int:
        """Return the number of QA pairs in dataset ``dataset_id``."""
        return (
            db.query(QAPairModel).filter(QAPairModel.dataset_id == dataset_id).count()
        )

    @staticmethod
    def get_available_datasets(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        user_species: Optional[str] = None,
    ) -> List[Dataset]:
        """Return one page of datasets the user is allowed to claim.

        Rules:
            1. Only datasets with no annotator are returned.
            2. Datasets without a category are always eligible.
            3. If ``user_species`` is truthy, datasets whose category equals
               it are eligible as well; other categorized datasets are not.

        Args:
            db: Database session.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            user_species: The user's species tag, if any.

        Returns:
            The claimable datasets as schema objects.
        """
        from sqlalchemy import or_

        query = db.query(DatasetModel).filter(DatasetModel.annotator_id.is_(None))

        category_conditions = [DatasetModel.category.is_(None)]
        if user_species:
            category_conditions.append(DatasetModel.category == user_species)
        query = query.filter(or_(*category_conditions))

        results = query.offset(skip).limit(limit).all()
        return [model.to_pydantic() for model in results]

    @staticmethod
    def get_by_annotator(
        db: Session,
        annotator_id: int,
        skip: int = 0,
        limit: int = 100,
    ) -> List[Dataset]:
        """Return one page of datasets currently assigned to ``annotator_id``.

        Args:
            db: Database session.
            annotator_id: Annotator user id.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).

        Returns:
            The assigned datasets as schema objects.
        """
        query = db.query(DatasetModel).filter(DatasetModel.annotator_id == annotator_id)
        results = query.offset(skip).limit(limit).all()
        return [model.to_pydantic() for model in results]
|
|
|
|
| |
|
|
|
|
class QAPairCRUD:
    """CRUD operations for QA pairs (``QAPairModel``)."""

    @staticmethod
    def create(db: Session, qa_pair: QAPair) -> QAPair:
        """Insert ``qa_pair`` and return the stored record as a schema object."""
        db_model = QAPairModel.from_pydantic(qa_pair)
        db.add(db_model)
        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, qa_pair_id: int) -> Optional[QAPair]:
        """Return the QA pair with id ``qa_pair_id``, or ``None``."""
        db_model = db.query(QAPairModel).filter(QAPairModel.id == qa_pair_id).first()
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_by_dataset(
        db: Session,
        dataset_id: int,
        skip: int = 0,
        limit: int = 100,
    ) -> List[QAPair]:
        """Return one page of the QA pairs belonging to dataset ``dataset_id``."""
        db_models = (
            db.query(QAPairModel)
            .filter(QAPairModel.dataset_id == dataset_id)
            .offset(skip)
            .limit(limit)
            .all()
        )
        return [model.to_pydantic() for model in db_models]

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        dataset_id: Optional[int] = None,
        question_search: Optional[str] = None,
    ) -> List[QAPair]:
        """Return one page of QA pairs, optionally filtered.

        Args:
            db: Database session.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            dataset_id: Restrict to this dataset when not ``None``.
            question_search: When truthy, only pairs whose question contains
                this text literally are returned.

        Returns:
            The matching QA pairs as schema objects.
        """
        query = db.query(QAPairModel)
        if dataset_id is not None:
            query = query.filter(QAPairModel.dataset_id == dataset_id)
        if question_search:
            # autoescape=True makes "%" and "_" in the user's text match
            # literally instead of acting as LIKE wildcards.
            query = query.filter(
                QAPairModel.question.contains(question_search, autoescape=True)
            )

        results = query.offset(skip).limit(limit).all()
        return [model.to_pydantic() for model in results]

    @staticmethod
    def update(db: Session, qa_pair_id: int, qa_pair: QAPair) -> Optional[QAPair]:
        """Overwrite QA pair ``qa_pair_id`` with the values in ``qa_pair``.

        Returns:
            The updated QA pair, or ``None`` if no pair has that id.
        """
        db_model = db.query(QAPairModel).filter(QAPairModel.id == qa_pair_id).first()
        if not db_model:
            return None

        db_model.dataset_id = qa_pair.dataset_id
        db_model.question = qa_pair.question
        db_model.answer = qa_pair.answer
        db_model.updated_at = datetime.now()

        # Every dumped field other than the core columns is stored in
        # ``extra_fields_json``; store None when there are none.
        # NOTE(review): this includes any other declared QAPair field (e.g.
        # timestamps, if the schema has them) — confirm that is intended.
        core_fields = {"id", "dataset_id", "question", "answer"}
        extra_fields = {
            key: value
            for key, value in qa_pair.model_dump().items()
            if key not in core_fields
        }
        db_model.extra_fields_json = extra_fields or None

        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def delete(db: Session, qa_pair_id: int) -> bool:
        """Delete QA pair ``qa_pair_id``; return ``False`` if it did not exist."""
        db_model = db.query(QAPairModel).filter(QAPairModel.id == qa_pair_id).first()
        if not db_model:
            return False

        db.delete(db_model)
        db.commit()
        return True

    @staticmethod
    def delete_by_dataset(db: Session, dataset_id: int) -> int:
        """Bulk-delete all QA pairs of dataset ``dataset_id``; return the count deleted."""
        count = (
            db.query(QAPairModel).filter(QAPairModel.dataset_id == dataset_id).delete()
        )
        db.commit()
        return count

    @staticmethod
    def count(
        db: Session,
        dataset_id: Optional[int] = None,
    ) -> int:
        """Return the number of QA pairs, restricted to ``dataset_id`` when not ``None``."""
        query = db.query(QAPairModel)
        if dataset_id is not None:
            query = query.filter(QAPairModel.dataset_id == dataset_id)
        return query.count()
|
|
|
|
| |
|
|
|
|
class DatasetAnnotationConfigCRUD:
    """Manage the many-to-many link between datasets and annotation configs.

    Links are manipulated through the ``DatasetModel.annotation_configs`` and
    ``AnnotationConfigModel.datasets`` relationships.
    """

    @staticmethod
    def _get_dataset(db: Session, dataset_id: int):
        """Return the ``DatasetModel`` row with id ``dataset_id``, or ``None``."""
        return db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first()

    @staticmethod
    def _get_config(db: Session, annotation_config_id: int):
        """Return the ``AnnotationConfigModel`` row with the given id, or ``None``."""
        return (
            db.query(AnnotationConfigModel)
            .filter(AnnotationConfigModel.id == annotation_config_id)
            .first()
        )

    @staticmethod
    def associate(
        db: Session,
        dataset_id: int,
        annotation_config_id: int,
    ) -> bool:
        """Link a dataset to an annotation config.

        Returns:
            ``True`` if the link exists afterwards (including when it already
            existed), ``False`` if the dataset or config does not exist.
        """
        dataset = DatasetAnnotationConfigCRUD._get_dataset(db, dataset_id)
        if not dataset:
            return False

        config = DatasetAnnotationConfigCRUD._get_config(db, annotation_config_id)
        if not config:
            return False

        # Already linked: report success without writing anything.
        if config in dataset.annotation_configs:
            return True

        dataset.annotation_configs.append(config)
        db.commit()
        return True

    @staticmethod
    def disassociate(
        db: Session,
        dataset_id: int,
        annotation_config_id: int,
    ) -> bool:
        """Remove the link between a dataset and an annotation config.

        Returns:
            ``True`` if a link was removed; ``False`` if either side does not
            exist or they were not linked.
        """
        dataset = DatasetAnnotationConfigCRUD._get_dataset(db, dataset_id)
        if not dataset:
            return False

        config = DatasetAnnotationConfigCRUD._get_config(db, annotation_config_id)
        if not config:
            return False

        if config in dataset.annotation_configs:
            dataset.annotation_configs.remove(config)
            db.commit()
            return True

        return False

    @staticmethod
    def get_datasets_by_config(
        db: Session,
        annotation_config_id: int,
    ) -> List[Dataset]:
        """Return all datasets linked to the given annotation config ([] if it is missing)."""
        config = DatasetAnnotationConfigCRUD._get_config(db, annotation_config_id)
        if not config:
            return []
        return [dataset.to_pydantic() for dataset in config.datasets]

    @staticmethod
    def get_configs_by_dataset(
        db: Session,
        dataset_id: int,
    ) -> List[AnnotationConfig]:
        """Return all annotation configs linked to the given dataset ([] if it is missing)."""
        dataset = DatasetAnnotationConfigCRUD._get_dataset(db, dataset_id)
        if not dataset:
            return []
        return [config.to_pydantic() for config in dataset.annotation_configs]

    @staticmethod
    def set_dataset_configs(
        db: Session,
        dataset_id: int,
        annotation_config_ids: List[int],
    ) -> bool:
        """Replace the dataset's linked configs with exactly ``annotation_config_ids``.

        Duplicate ids are ignored. An empty list removes every link.

        Returns:
            ``True`` on success; ``False`` if the dataset does not exist or any
            requested id has no matching config (links are then left unchanged).
        """
        dataset = DatasetAnnotationConfigCRUD._get_dataset(db, dataset_id)
        if not dataset:
            return False

        # De-duplicate (keeping first-seen order). Otherwise repeated ids would
        # make the length check below fail even though every id exists.
        unique_ids = list(dict.fromkeys(annotation_config_ids))
        configs = (
            db.query(AnnotationConfigModel)
            .filter(AnnotationConfigModel.id.in_(unique_ids))
            .all()
        )

        # Every requested id must resolve to a config.
        if len(configs) != len(unique_ids):
            return False

        dataset.annotation_configs = configs
        db.commit()
        return True

    @staticmethod
    def is_associated(
        db: Session,
        dataset_id: int,
        annotation_config_id: int,
    ) -> bool:
        """Return whether the dataset and config exist and are linked."""
        dataset = DatasetAnnotationConfigCRUD._get_dataset(db, dataset_id)
        if not dataset:
            return False

        config = DatasetAnnotationConfigCRUD._get_config(db, annotation_config_id)
        if not config:
            return False

        return config in dataset.annotation_configs

    @staticmethod
    def count_configs_by_dataset(db: Session, dataset_id: int) -> int:
        """Return how many configs are linked to the dataset (0 if it is missing).

        Loads the whole ``annotation_configs`` collection to count it.
        """
        dataset = DatasetAnnotationConfigCRUD._get_dataset(db, dataset_id)
        if not dataset:
            return 0
        return len(dataset.annotation_configs)

    @staticmethod
    def count_datasets_by_config(db: Session, annotation_config_id: int) -> int:
        """Return how many datasets are linked to the config (0 if it is missing).

        Loads the whole ``datasets`` collection to count it.
        """
        config = DatasetAnnotationConfigCRUD._get_config(db, annotation_config_id)
        if not config:
            return 0
        return len(config.datasets)
|
|
|
|
| |
|
|
|
|
class UserCRUD:
    """CRUD operations for users (``UserModel``)."""

    @staticmethod
    def _first_row(db: Session, criterion):
        """Return the first ``UserModel`` row matching ``criterion``, or ``None``."""
        return db.query(UserModel).filter(criterion).first()

    @staticmethod
    def _active_query(db: Session, is_active: Optional[bool]):
        """Build a ``UserModel`` query, filtered on ``is_active`` unless it is ``None``."""
        base = db.query(UserModel)
        if is_active is None:
            return base
        return base.filter(UserModel.is_active == is_active)

    @staticmethod
    def create(db: Session, user: UserCreate) -> User:
        """Insert a new user built from ``user`` and return it as a schema object."""
        row = UserModel.from_pydantic(user)
        db.add(row)
        db.commit()
        db.refresh(row)
        return row.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, user_id: int) -> Optional[User]:
        """Return the user with id ``user_id``, or ``None``."""
        row = UserCRUD._first_row(db, UserModel.id == user_id)
        if not row:
            return None
        return row.to_pydantic()

    @staticmethod
    def get_by_username(db: Session, username: str) -> Optional[User]:
        """Return the user named ``username``, or ``None``."""
        row = UserCRUD._first_row(db, UserModel.username == username)
        if not row:
            return None
        return row.to_pydantic()

    @staticmethod
    def authenticate_user(
        db: Session, username: str, password_hash: str, timestamp: int
    ) -> Optional[User]:
        """Check a login attempt; does NOT check whether the user is active.

        Verification is delegated to ``verify_password_with_timestamp``, which
        receives the client-supplied hash, the stored hash and ``timestamp``.
        (The original design notes a SHA-256 + timestamp scheme.)

        Returns:
            The user on success, ``None`` if the user is unknown or the check fails.
        """
        from qa_annotate.utils.password import verify_password_with_timestamp

        row = UserCRUD._first_row(db, UserModel.username == username)
        if not row:
            return None

        password_ok = verify_password_with_timestamp(
            password_hash, row.hashed_password, timestamp
        )
        return row.to_pydantic() if password_ok else None

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
    ) -> List[User]:
        """Return one page of users, filtered on ``is_active`` when it is not ``None``."""
        page = UserCRUD._active_query(db, is_active).offset(skip).limit(limit).all()
        return [row.to_pydantic() for row in page]

    @staticmethod
    def update(db: Session, user_id: int, user_update: UserUpdate) -> Optional[User]:
        """Apply ``user_update`` to user ``user_id`` via ``update_from_pydantic``.

        Returns:
            The updated user, or ``None`` if no user has that id.
        """
        row = UserCRUD._first_row(db, UserModel.id == user_id)
        if not row:
            return None

        row.update_from_pydantic(user_update)
        row.updated_at = datetime.now()

        db.commit()
        db.refresh(row)
        return row.to_pydantic()

    @staticmethod
    def delete(db: Session, user_id: int) -> bool:
        """Delete user ``user_id``; return ``False`` if it did not exist."""
        row = UserCRUD._first_row(db, UserModel.id == user_id)
        if not row:
            return False
        db.delete(row)
        db.commit()
        return True

    @staticmethod
    def count(db: Session, is_active: Optional[bool] = None) -> int:
        """Return the number of users, filtered on ``is_active`` when it is not ``None``."""
        return UserCRUD._active_query(db, is_active).count()
|
|
|
|
| |
|
|
|
|
class ProjectCRUD:
    """CRUD operations for projects (``ProjectModel``) and their datasets."""

    @staticmethod
    def _get_model(db: Session, project_id: int):
        """Return the ``ProjectModel`` row with id ``project_id``, or ``None``."""
        return db.query(ProjectModel).filter(ProjectModel.id == project_id).first()

    @staticmethod
    def create(db: Session, project: Project) -> Project:
        """Insert ``project`` and return the stored record as a schema object."""
        db_model = ProjectModel.from_pydantic(project)
        db.add(db_model)
        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, project_id: int) -> Optional[Project]:
        """Return the project with id ``project_id``, or ``None``."""
        db_model = ProjectCRUD._get_model(db, project_id)
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_by_name(db: Session, name: str) -> Optional[Project]:
        """Return the first project whose name equals ``name``, or ``None``."""
        db_model = db.query(ProjectModel).filter(ProjectModel.name == name).first()
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        name_search: Optional[str] = None,
        category: Optional[str] = None,
        status: Optional[str] = None,
        order_by: Optional[str] = "created_at",
        order: Optional[str] = "desc",
    ) -> List[Project]:
        """Return one page of projects, with optional filtering and sorting.

        Args:
            db: Database session.
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            name_search: When truthy, only projects whose name contains this
                text literally are returned.
            category: When truthy, only projects in this category are returned.
            status: When truthy, only projects with this status are returned.
            order_by: Sort column name (case-insensitive), one of ``id``,
                ``name``, ``created_at``, ``updated_at``, ``version``,
                ``status``, ``category``. Missing or unknown names fall back
                to ``created_at`` descending.
            order: ``"asc"`` (case-insensitive) for ascending; anything else
                sorts descending.

        Returns:
            The matching projects as schema objects.
        """
        query = db.query(ProjectModel)
        if name_search:
            # autoescape=True makes "%" and "_" in the user's text match
            # literally instead of acting as LIKE wildcards.
            query = query.filter(
                ProjectModel.name.contains(name_search, autoescape=True)
            )
        if category:
            query = query.filter(ProjectModel.category == category)
        if status:
            query = query.filter(ProjectModel.status == status)

        # Whitelist of sortable columns; arbitrary user input never reaches
        # ORDER BY directly.
        sortable_columns = {
            "id": ProjectModel.id,
            "name": ProjectModel.name,
            "created_at": ProjectModel.created_at,
            "updated_at": ProjectModel.updated_at,
            "version": ProjectModel.version,
            "status": ProjectModel.status,
            "category": ProjectModel.category,
        }
        sort_column = sortable_columns.get(order_by.lower()) if order_by else None

        if sort_column is None:
            query = query.order_by(ProjectModel.created_at.desc())
        elif order and order.lower() == "asc":
            query = query.order_by(sort_column.asc())
        else:
            query = query.order_by(sort_column.desc())

        results = query.offset(skip).limit(limit).all()
        return [model.to_pydantic() for model in results]

    @staticmethod
    def update(db: Session, project_id: int, project: Project) -> Optional[Project]:
        """Overwrite every editable field of project ``project_id`` with ``project``.

        Returns:
            The updated project, or ``None`` if no project has that id.
        """
        db_model = ProjectCRUD._get_model(db, project_id)
        if not db_model:
            return None

        db_model.name = project.name
        db_model.description = project.description
        db_model.version = project.version
        db_model.status = project.status
        db_model.tags_json = project.tags
        db_model.category = project.category
        db_model.source = project.source
        db_model.source_url = project.source_url
        db_model.metadata_json = project.metadata
        db_model.display_extra_fields_json = project.display_extra_fields
        db_model.updated_at = datetime.now()

        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def delete(db: Session, project_id: int) -> bool:
        """Delete project ``project_id``; return ``False`` if it did not exist.

        NOTE(review): nulling the ``project_id`` of its datasets is expected
        to come from the FK/relationship configuration on the models (not
        visible here); confirm.
        """
        db_model = ProjectCRUD._get_model(db, project_id)
        if not db_model:
            return False

        db.delete(db_model)
        db.commit()
        return True

    @staticmethod
    def get_datasets_by_project(
        db: Session,
        project_id: int,
        skip: int = 0,
        limit: int = 100,
    ) -> List[Dataset]:
        """Return one page of the project's datasets ([] if the project is missing).

        The full ``project.datasets`` collection is loaded and sliced in Python.
        """
        project = ProjectCRUD._get_model(db, project_id)
        if not project:
            return []

        datasets = project.datasets[skip : skip + limit]
        return [dataset.to_pydantic() for dataset in datasets]

    @staticmethod
    def add_dataset_to_project(
        db: Session,
        project_id: int,
        dataset_id: int,
    ) -> bool:
        """Move dataset ``dataset_id`` into project ``project_id``.

        Any previous project assignment of the dataset is overwritten.

        Returns:
            ``True`` on success, ``False`` if the project or dataset does not exist.
        """
        project = ProjectCRUD._get_model(db, project_id)
        if not project:
            return False

        dataset = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first()
        if not dataset:
            return False

        dataset.project_id = project_id
        db.commit()
        return True

    @staticmethod
    def remove_dataset_from_project(
        db: Session,
        project_id: int,
        dataset_id: int,
    ) -> bool:
        """Detach dataset ``dataset_id`` from project ``project_id``.

        Returns:
            ``True`` on success; ``False`` if either does not exist or the
            dataset does not belong to that project.
        """
        project = ProjectCRUD._get_model(db, project_id)
        if not project:
            return False

        dataset = db.query(DatasetModel).filter(DatasetModel.id == dataset_id).first()
        if not dataset:
            return False

        if dataset.project_id != project_id:
            return False

        dataset.project_id = None
        db.commit()
        return True

    @staticmethod
    def count(db: Session) -> int:
        """Return the total number of projects."""
        return db.query(ProjectModel).count()

    @staticmethod
    def count_datasets_by_project(db: Session, project_id: int) -> int:
        """Return how many datasets belong to the project (0 if it is missing).

        Loads the whole ``datasets`` collection to count it.
        """
        project = ProjectCRUD._get_model(db, project_id)
        if not project:
            return 0
        return len(project.datasets)
|
|
|
|
| |
|
|
|
|
| class ProjectAnnotationConfigCRUD: |
| """项目与标注配置关联 CRUD 操作""" |
|
|
| @staticmethod |
| def associate( |
| db: Session, |
| project_id: int, |
| annotation_config_id: int, |
| ) -> bool: |
| """关联项目和标注配置""" |
| from qa_annotate.database.models import project_annotation_config_association |
|
|
| |
| project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first() |
| if not project: |
| return False |
|
|
| |
| config = ( |
| db.query(AnnotationConfigModel) |
| .filter(AnnotationConfigModel.id == annotation_config_id) |
| .first() |
| ) |
| if not config: |
| return False |
|
|
| |
| if config in project.annotation_configs: |
| return True |
|
|
| |
| stmt = ( |
| select(project_annotation_config_association.c.order) |
| .where(project_annotation_config_association.c.project_id == project_id) |
| .order_by(project_annotation_config_association.c.order.desc()) |
| .limit(1) |
| ) |
| result = db.execute(stmt).first() |
| next_order = (result[0] + 1) if result else 0 |
|
|
| |
| stmt = project_annotation_config_association.insert().values( |
| project_id=project_id, |
| annotation_config_id=annotation_config_id, |
| order=next_order, |
| ) |
| db.execute(stmt) |
| db.commit() |
| return True |
|
|
| @staticmethod |
| def disassociate( |
| db: Session, |
| project_id: int, |
| annotation_config_id: int, |
| ) -> bool: |
| """取消项目和标注配置的关联""" |
| project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first() |
| if not project: |
| return False |
|
|
| config = ( |
| db.query(AnnotationConfigModel) |
| .filter(AnnotationConfigModel.id == annotation_config_id) |
| .first() |
| ) |
| if not config: |
| return False |
|
|
| |
| if config in project.annotation_configs: |
| project.annotation_configs.remove(config) |
| db.commit() |
| return True |
|
|
| return False |
|
|
| @staticmethod |
| def get_projects_by_config( |
| db: Session, |
| annotation_config_id: int, |
| ) -> List[Project]: |
| """获取使用指定标注配置的所有项目""" |
| config = ( |
| db.query(AnnotationConfigModel) |
| .filter(AnnotationConfigModel.id == annotation_config_id) |
| .first() |
| ) |
|
|
| if not config: |
| return [] |
|
|
| return [project.to_pydantic() for project in config.projects] |
|
|
| @staticmethod |
| def get_configs_by_project( |
| db: Session, |
| project_id: int, |
| ) -> List[AnnotationConfig]: |
| """获取指定项目关联的所有标注配置(按order排序)""" |
| from qa_annotate.database.models import project_annotation_config_association |
|
|
| project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first() |
|
|
| if not project: |
| return [] |
|
|
| |
| configs_with_order = [] |
| for config in project.annotation_configs: |
| |
| stmt = select(project_annotation_config_association.c.order).where( |
| and_( |
| project_annotation_config_association.c.project_id == project_id, |
| project_annotation_config_association.c.annotation_config_id |
| == config.id, |
| ) |
| ) |
| result = db.execute(stmt).first() |
| order = result[0] if result else 0 |
| configs_with_order.append((order, config)) |
|
|
| |
| configs_with_order.sort(key=lambda x: x[0]) |
|
|
| return [config.to_pydantic() for _, config in configs_with_order] |
|
|
| @staticmethod |
| def set_project_configs( |
| db: Session, |
| project_id: int, |
| annotation_config_ids: List[int], |
| ) -> bool: |
| """设置项目关联的标注配置(会替换现有关联)""" |
| project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first() |
| if not project: |
| return False |
|
|
| |
| configs = ( |
| db.query(AnnotationConfigModel) |
| .filter(AnnotationConfigModel.id.in_(annotation_config_ids)) |
| .all() |
| ) |
|
|
| if len(configs) != len(annotation_config_ids): |
| return False |
|
|
| |
| project.annotation_configs = configs |
| db.commit() |
| return True |
|
|
| @staticmethod |
| def is_associated( |
| db: Session, |
| project_id: int, |
| annotation_config_id: int, |
| ) -> bool: |
| """检查项目和标注配置是否已关联""" |
| project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first() |
| if not project: |
| return False |
|
|
| config = ( |
| db.query(AnnotationConfigModel) |
| .filter(AnnotationConfigModel.id == annotation_config_id) |
| .first() |
| ) |
| if not config: |
| return False |
|
|
| return config in project.annotation_configs |
|
|
| @staticmethod |
| def count_configs_by_project(db: Session, project_id: int) -> int: |
| """统计指定项目关联的标注配置数量""" |
| project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first() |
| if not project: |
| return 0 |
|
|
| return len(project.annotation_configs) |
|
|
| @staticmethod |
| def count_projects_by_config(db: Session, annotation_config_id: int) -> int: |
| """统计使用指定标注配置的项目数量""" |
| config = ( |
| db.query(AnnotationConfigModel) |
| .filter(AnnotationConfigModel.id == annotation_config_id) |
| .first() |
| ) |
| if not config: |
| return 0 |
|
|
| return len(config.projects) |
|
|
| @staticmethod |
| def swap_config_order( |
| db: Session, |
| project_id: int, |
| config_id1: int, |
| config_id2: int, |
| ) -> bool: |
| """交换两个配置的顺序""" |
| from qa_annotate.database.models import project_annotation_config_association |
|
|
| |
| stmt1 = select(project_annotation_config_association.c.order).where( |
| and_( |
| project_annotation_config_association.c.project_id == project_id, |
| project_annotation_config_association.c.annotation_config_id |
| == config_id1, |
| ) |
| ) |
| result1 = db.execute(stmt1).first() |
|
|
| stmt2 = select(project_annotation_config_association.c.order).where( |
| and_( |
| project_annotation_config_association.c.project_id == project_id, |
| project_annotation_config_association.c.annotation_config_id |
| == config_id2, |
| ) |
| ) |
| result2 = db.execute(stmt2).first() |
|
|
| if not result1 or not result2: |
| return False |
|
|
| |
| order1 = result1[0] |
| order2 = result2[0] |
|
|
| stmt_update1 = ( |
| update(project_annotation_config_association) |
| .where( |
| and_( |
| project_annotation_config_association.c.project_id == project_id, |
| project_annotation_config_association.c.annotation_config_id |
| == config_id1, |
| ) |
| ) |
| .values(order=order2) |
| ) |
|
|
| stmt_update2 = ( |
| update(project_annotation_config_association) |
| .where( |
| and_( |
| project_annotation_config_association.c.project_id == project_id, |
| project_annotation_config_association.c.annotation_config_id |
| == config_id2, |
| ) |
| ) |
| .values(order=order1) |
| ) |
|
|
| db.execute(stmt_update1) |
| db.execute(stmt_update2) |
| db.commit() |
| return True |
|
|
|
|
| |
|
|
|
|
class QuestionTypeCRUD:
    """CRUD operations for question types (type / subtype pairs)."""

    @staticmethod
    def create(db: Session, question_type: QuestionTypeCreate) -> QuestionType:
        """Insert a question type, commit, and return it as a schema object."""
        db_model = QuestionTypeModel.from_pydantic(question_type)
        db.add(db_model)
        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, type_id: int) -> Optional[QuestionType]:
        """Return the question type with ``type_id``, or None if it does not exist."""
        db_model = (
            db.query(QuestionTypeModel).filter(QuestionTypeModel.id == type_id).first()
        )
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_by_type_subtype(
        db: Session, type: str, subtype: str
    ) -> Optional[QuestionType]:
        """Return the question type matching both ``type`` and ``subtype``, or None."""
        db_model = (
            db.query(QuestionTypeModel)
            .filter(
                QuestionTypeModel.type == type,
                QuestionTypeModel.subtype == subtype,
            )
            .first()
        )
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_all(db: Session, skip: int = 0, limit: int = 1000) -> List[QuestionType]:
        """Return question types sorted by (type, order, subtype), paginated by skip/limit."""
        results = (
            db.query(QuestionTypeModel)
            .order_by(
                QuestionTypeModel.type,
                QuestionTypeModel.order,
                QuestionTypeModel.subtype,
            )
            .offset(skip)
            .limit(limit)
            .all()
        )
        return [model.to_pydantic() for model in results]

    @staticmethod
    def get_all_grouped(db: Session) -> dict:
        """Return ``{type: [subtype, ...]}`` for all question types.

        Subtypes within each type appear in (order, subtype) order.
        """
        results = (
            db.query(QuestionTypeModel)
            .order_by(
                QuestionTypeModel.type,
                QuestionTypeModel.order,
                QuestionTypeModel.subtype,
            )
            .all()
        )
        grouped: dict = {}
        for model in results:
            grouped.setdefault(model.type, []).append(model.subtype)
        return grouped

    @staticmethod
    def update(
        db: Session, type_id: int, question_type_update: QuestionTypeUpdate
    ) -> Optional[QuestionType]:
        """Apply the non-None fields of ``question_type_update`` to a question type.

        Returns:
            The updated schema object. None if ``type_id`` does not exist, or if
            the resulting (type, subtype) pair would collide with another row.
        """
        db_model = (
            db.query(QuestionTypeModel).filter(QuestionTypeModel.id == type_id).first()
        )
        if not db_model:
            return None

        # When type or subtype changes, make sure the new pair stays unique.
        if (
            question_type_update.type is not None
            or question_type_update.subtype is not None
        ):
            new_type = (
                question_type_update.type
                if question_type_update.type is not None
                else db_model.type
            )
            new_subtype = (
                question_type_update.subtype
                if question_type_update.subtype is not None
                else db_model.subtype
            )
            existing = (
                db.query(QuestionTypeModel)
                .filter(
                    QuestionTypeModel.type == new_type,
                    QuestionTypeModel.subtype == new_subtype,
                    QuestionTypeModel.id != type_id,
                )
                .first()
            )
            if existing:
                return None

        if question_type_update.type is not None:
            db_model.type = question_type_update.type
        if question_type_update.subtype is not None:
            db_model.subtype = question_type_update.subtype
        if question_type_update.order is not None:
            db_model.order = question_type_update.order

        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def delete(db: Session, type_id: int) -> bool:
        """Delete a question type; return False if it does not exist."""
        db_model = (
            db.query(QuestionTypeModel).filter(QuestionTypeModel.id == type_id).first()
        )
        if not db_model:
            return False

        db.delete(db_model)
        db.commit()
        return True

    @staticmethod
    def import_from_csv(db: Session, csv_path: str) -> dict:
        """Import type/subtype pairs from a CSV file.

        The CSV must have the header columns ``类型`` (type) and ``亚类``
        (subtype). A blank ``类型`` cell inherits the most recent non-blank
        type, which supports sheets with merged type cells. Three kinds of
        rows are skipped:

        - rows with no subtype;
        - rows that appear before any type has been seen;
        - rows whose (type, subtype) pair already exists.

        The file is decoded as ``utf-8-sig``, so a leading BOM (as written by
        Excel) does not corrupt the first header name. Plain UTF-8 files read
        the same way.

        Returns:
            ``{"imported_count": int, "skipped_count": int, "errors": List[str]}``.
            Each error message is prefixed with the record's row number, where
            the header is row 1.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
        """
        import csv
        from pathlib import Path

        csv_file = Path(csv_path)
        if not csv_file.exists():
            raise FileNotFoundError(f"CSV文件不存在: {csv_path}")

        imported_count = 0
        skipped_count = 0
        errors = []
        current_type = None

        # The csv module requires newline="" so that quoted fields containing
        # line breaks are parsed correctly.
        with open(csv_file, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            for row_num, row in enumerate(reader, start=2):
                try:
                    # DictReader fills missing trailing cells with None, so
                    # guard before calling .strip().
                    type_name = (row.get("类型") or "").strip()
                    subtype_name = (row.get("亚类") or "").strip()

                    # A non-empty type cell starts a new group; an empty one
                    # continues the previous group.
                    if type_name:
                        current_type = type_name

                    if not current_type or not subtype_name:
                        skipped_count += 1
                        continue

                    if QuestionTypeCRUD.get_by_type_subtype(
                        db, current_type, subtype_name
                    ):
                        skipped_count += 1
                        continue

                    question_type = QuestionTypeCreate(
                        type=current_type, subtype=subtype_name, order=0
                    )
                    QuestionTypeCRUD.create(db, question_type)
                    imported_count += 1

                except Exception as e:
                    # A failed commit leaves the session unusable until it is
                    # rolled back; without this, every later row would fail too.
                    db.rollback()
                    errors.append(f"第{row_num}行: {str(e)}")

        return {
            "imported_count": imported_count,
            "skipped_count": skipped_count,
            "errors": errors,
        }

    @staticmethod
    def count(db: Session) -> int:
        """Return the total number of question types."""
        return db.query(QuestionTypeModel).count()
|
|
|
|
| |
|
|
|
|
class SeedQuestionCRUD:
    """CRUD operations for seed questions."""

    @staticmethod
    def _apply_filters(
        query,
        creator_id: Optional[int] = None,
        question_type: Optional[str] = None,
        subtype: Optional[str] = None,
        search: Optional[str] = None,
    ):
        """Narrow a query over ``SeedQuestionModel`` with the shared list/count filters.

        ``creator_id`` is applied whenever it is not None, so 0 is a real
        filter. The string filters are skipped when they are empty or None.
        """
        if creator_id is not None:
            query = query.filter(SeedQuestionModel.creator_id == creator_id)
        if question_type:
            query = query.filter(SeedQuestionModel.type == question_type)
        if subtype:
            query = query.filter(SeedQuestionModel.subtype == subtype)
        if search:
            # autoescape=True makes "%" and "_" in user input match literally
            # instead of acting as LIKE wildcards.
            query = query.filter(
                SeedQuestionModel.question.contains(search, autoescape=True)
            )
        return query

    @staticmethod
    def _with_creator(model, full_name: Optional[str]) -> SeedQuestionWithCreator:
        """Build the creator-enriched schema from a model row and the joined full name."""
        return SeedQuestionWithCreator(
            id=model.id,
            question=model.question,
            type=model.type,
            subtype=model.subtype,
            species_or_domain=model.species_or_domain,
            model=model.model,
            date=model.date,
            is_verified=model.is_verified,
            creator_id=model.creator_id,
            creator_full_name=full_name,
            created_at=model.created_at,
            updated_at=model.updated_at,
        )

    @staticmethod
    def create(
        db: Session, seed_question: SeedQuestionCreate, creator_id: int
    ) -> SeedQuestion:
        """Insert a seed question owned by ``creator_id``, commit, and return it."""
        db_model = SeedQuestionModel.from_pydantic(seed_question, creator_id)
        db.add(db_model)
        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def get_by_id(db: Session, question_id: int) -> Optional[SeedQuestion]:
        """Return the seed question with ``question_id``, or None if it does not exist."""
        db_model = (
            db.query(SeedQuestionModel)
            .filter(SeedQuestionModel.id == question_id)
            .first()
        )
        return db_model.to_pydantic() if db_model else None

    @staticmethod
    def get_all(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        creator_id: Optional[int] = None,
        type: Optional[str] = None,
        subtype: Optional[str] = None,
        search: Optional[str] = None,
    ) -> List[SeedQuestion]:
        """Return seed questions, newest first, filtered and paginated.

        ``search`` is a literal substring match on the question text.
        """
        query = SeedQuestionCRUD._apply_filters(
            db.query(SeedQuestionModel), creator_id, type, subtype, search
        )
        results = (
            query.order_by(SeedQuestionModel.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return [model.to_pydantic() for model in results]

    @staticmethod
    def get_all_with_creator(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        creator_id: Optional[int] = None,
        type: Optional[str] = None,
        subtype: Optional[str] = None,
        search: Optional[str] = None,
    ) -> List[SeedQuestionWithCreator]:
        """Like :meth:`get_all`, but each item also carries the creator's full name.

        The user table is outer-joined, so questions whose creator row is
        missing get ``creator_full_name=None``.
        """
        query = db.query(SeedQuestionModel, UserModel.full_name).outerjoin(
            UserModel, SeedQuestionModel.creator_id == UserModel.id
        )
        query = SeedQuestionCRUD._apply_filters(
            query, creator_id, type, subtype, search
        )
        results = (
            query.order_by(SeedQuestionModel.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return [
            SeedQuestionCRUD._with_creator(model, full_name)
            for model, full_name in results
        ]

    @staticmethod
    def update(
        db: Session, question_id: int, seed_question: SeedQuestionUpdate
    ) -> Optional[SeedQuestion]:
        """Apply the non-None fields of ``seed_question`` and bump ``updated_at``.

        Returns None if ``question_id`` does not exist.
        """
        db_model = (
            db.query(SeedQuestionModel)
            .filter(SeedQuestionModel.id == question_id)
            .first()
        )
        if not db_model:
            return None

        # None means "leave unchanged" for every updatable field.
        for field in (
            "question",
            "type",
            "subtype",
            "species_or_domain",
            "model",
            "date",
            "is_verified",
        ):
            value = getattr(seed_question, field)
            if value is not None:
                setattr(db_model, field, value)

        db_model.updated_at = datetime.now()
        db.commit()
        db.refresh(db_model)
        return db_model.to_pydantic()

    @staticmethod
    def delete(db: Session, question_id: int) -> bool:
        """Delete a seed question; return False if it does not exist."""
        db_model = (
            db.query(SeedQuestionModel)
            .filter(SeedQuestionModel.id == question_id)
            .first()
        )
        if not db_model:
            return False

        db.delete(db_model)
        db.commit()
        return True

    @staticmethod
    def count(
        db: Session,
        creator_id: Optional[int] = None,
        type: Optional[str] = None,
        subtype: Optional[str] = None,
        search: Optional[str] = None,
    ) -> int:
        """Count seed questions matching the same filters accepted by :meth:`get_all`.

        ``search`` was added (defaulting to None) so that pagination totals
        agree with a searched listing.
        """
        query = SeedQuestionCRUD._apply_filters(
            db.query(SeedQuestionModel), creator_id, type, subtype, search
        )
        return query.count()

    @staticmethod
    def export_all(db: Session) -> List[SeedQuestionWithCreator]:
        """Return every seed question, newest first, with its creator's full name (admin export)."""
        results = (
            db.query(SeedQuestionModel, UserModel.full_name)
            .outerjoin(UserModel, SeedQuestionModel.creator_id == UserModel.id)
            .order_by(SeedQuestionModel.created_at.desc())
            .all()
        )
        return [
            SeedQuestionCRUD._with_creator(model, full_name)
            for model, full_name in results
        ]

    @staticmethod
    def create_batch(
        db: Session, seed_questions: List[SeedQuestionCreate], creator_id: int
    ) -> List[SeedQuestion]:
        """Insert several seed questions in one commit and return them."""
        db_models = [
            SeedQuestionModel.from_pydantic(seed_question, creator_id)
            for seed_question in seed_questions
        ]
        db.add_all(db_models)
        db.commit()
        # Reload server-generated fields (ids, timestamps) after the commit.
        for db_model in db_models:
            db.refresh(db_model)
        return [model.to_pydantic() for model in db_models]
|
|
|
|
| |
|
|
|
|
class SystemConfigCRUD:
    """CRUD operations for key/value system configuration entries."""

    @staticmethod
    def _to_schema(db_model) -> SystemConfig:
        """Convert a ``SystemConfigModel`` row into the ``SystemConfig`` schema."""
        return SystemConfig(
            id=db_model.id,
            key=db_model.key,
            value=db_model.value,
            description=db_model.description,
            created_at=db_model.created_at,
            updated_at=db_model.updated_at,
        )

    @staticmethod
    def get_by_key(db: Session, key: str) -> Optional[SystemConfig]:
        """Return the config entry for ``key``, or None if it does not exist."""
        db_model = (
            db.query(SystemConfigModel).filter(SystemConfigModel.key == key).first()
        )
        return SystemConfigCRUD._to_schema(db_model) if db_model else None

    @staticmethod
    def get_value(
        db: Session, key: str, default: Optional[str] = None
    ) -> Optional[str]:
        """Return the value stored under ``key``, or ``default`` if no entry exists."""
        config = SystemConfigCRUD.get_by_key(db, key)
        if config:
            return config.value
        return default

    @staticmethod
    def set_value(
        db: Session,
        key: str,
        value: Optional[str],
        description: Optional[str] = None,
    ) -> SystemConfig:
        """Create the entry for ``key`` or update it (upsert), then commit.

        On update, a None ``value`` or ``description`` leaves that field
        unchanged and ``updated_at`` is bumped. On create, both are stored
        as given.
        """
        db_model = (
            db.query(SystemConfigModel).filter(SystemConfigModel.key == key).first()
        )
        if db_model:
            if value is not None:
                db_model.value = value
            if description is not None:
                db_model.description = description
            db_model.updated_at = datetime.now()
        else:
            db_model = SystemConfigModel(
                key=key,
                value=value,
                description=description,
            )
            db.add(db_model)
        db.commit()
        db.refresh(db_model)
        return SystemConfigCRUD._to_schema(db_model)

    @staticmethod
    def update(
        db: Session, key: str, config_update: SystemConfigUpdate
    ) -> Optional[SystemConfig]:
        """Apply the non-None fields of ``config_update`` to an existing entry.

        Returns None (and writes nothing) if ``key`` does not exist.
        """
        db_model = (
            db.query(SystemConfigModel).filter(SystemConfigModel.key == key).first()
        )
        if not db_model:
            return None

        if config_update.value is not None:
            db_model.value = config_update.value
        if config_update.description is not None:
            db_model.description = config_update.description
        db_model.updated_at = datetime.now()

        db.commit()
        db.refresh(db_model)
        return SystemConfigCRUD._to_schema(db_model)

    @staticmethod
    def get_all(db: Session) -> List[SystemConfig]:
        """Return all system config entries."""
        return [
            SystemConfigCRUD._to_schema(model)
            for model in db.query(SystemConfigModel).all()
        ]
|
|
|
|
| |
|
|
|
|
class AnnotationResultAnalysisCRUD:
    """Aggregate statistics over a project's annotation results."""

    @staticmethod
    def get_project_annotation_stats(
        db: Session, project_id: int
    ) -> Optional[dict]:
        """Compute annotation statistics for a project.

        Returns None if the project does not exist. Otherwise returns::

            {
                "total_datasets": int,
                "total_items": int,            # QA pairs across the project's datasets
                "total_annotations": int,      # results summed over configs_stats
                "annotated_items_count": int,  # QA pairs annotated under >= 1 config
                "fully_annotated_count": int,  # QA pairs annotated under every config
                "completion_rate": float,      # fully_annotated_count / total_items
                "configs_stats": [
                    {
                        "config_id": int,
                        "config_name": str,
                        "annotation_type": str,
                        "total_annotations": int,
                        "coverage": float,     # annotated QA pairs / total_items
                        "stats": dict,         # per-type breakdown (_analyze_by_type)
                    },
                ],
                "notes_summary": [
                    {"config_name": str, "notes": List[str], "count": int},
                ],
            }

        A config with no results in this project's datasets gets no entry in
        ``configs_stats``. It still counts toward ``fully_annotated_count``,
        where it contributes an empty set, so no item is fully annotated.
        """
        project = db.query(ProjectModel).filter(ProjectModel.id == project_id).first()
        if not project:
            return None

        datasets = project.datasets
        # A set gives O(1) membership tests in the per-result filter below.
        dataset_ids = {d.id for d in datasets}

        configs = ProjectAnnotationConfigCRUD.get_configs_by_project(db, project_id)

        total_items = sum(QAPairCRUD.count(db, dataset_id=d.id) for d in datasets)

        configs_stats = []
        all_notes = []
        # One set of annotated item keys per config, in config order; used
        # for the union/intersection counts at the end.
        config_annotated_items = []

        for config in configs:
            results = AnnotationResultCRUD.get_all(
                db=db,
                skip=0,
                limit=1000000,
                annotation_config_id=config.id,
            )
            # The same config can be shared by several projects; keep only
            # results that belong to this project's datasets.
            filtered_results = [r for r in results if r.dataset_id in dataset_ids]

            # Key items by (dataset_id, dataset_item_id) in case dataset_item_id
            # is only unique within a dataset; with globally unique ids the
            # counts are unchanged.
            annotated_items = {
                (r.dataset_id, r.dataset_item_id) for r in filtered_results
            }
            config_annotated_items.append(annotated_items)
            if not filtered_results:
                continue

            coverage = len(annotated_items) / total_items if total_items > 0 else 0
            stats = AnnotationResultAnalysisCRUD._analyze_by_type(
                filtered_results, config.annotation_type, config
            )

            notes_list = [r.notes for r in filtered_results if r.notes]
            if notes_list:
                all_notes.append({
                    "config_name": config.name,
                    "notes": notes_list,
                    "count": len(notes_list),
                })

            configs_stats.append({
                "config_id": config.id,
                "config_name": config.name,
                "annotation_type": config.annotation_type,
                "total_annotations": len(filtered_results),
                "coverage": coverage,
                "stats": stats,
            })

        total_annotations = sum(s["total_annotations"] for s in configs_stats)

        if config_annotated_items and total_items > 0:
            annotated_count = len(set.union(*config_annotated_items))
            # An item is "fully annotated" only if every linked config has it.
            fully_annotated_count = len(set.intersection(*config_annotated_items))
            completion_rate = fully_annotated_count / total_items
        else:
            annotated_count = 0
            fully_annotated_count = 0
            completion_rate = 0

        return {
            "total_datasets": len(datasets),
            "total_items": total_items,
            "total_annotations": total_annotations,
            "annotated_items_count": annotated_count,
            "fully_annotated_count": fully_annotated_count,
            "completion_rate": completion_rate,
            "configs_stats": configs_stats,
            "notes_summary": all_notes,
        }

    @staticmethod
    def _analyze_by_type(results: List, annotation_type: str, config) -> dict:
        """Dispatch ``results`` to the analyzer matching ``annotation_type``.

        Unknown types fall back to ``{"type": annotation_type, "count": len(results)}``.
        """
        if annotation_type == "score":
            return AnnotationResultAnalysisCRUD._analyze_score(results, config)
        elif annotation_type in ["single_choice", "multi_choice"]:
            return AnnotationResultAnalysisCRUD._analyze_choice(results, config)
        elif annotation_type == "category":
            return AnnotationResultAnalysisCRUD._analyze_category(results, config)
        elif annotation_type == "binary":
            return AnnotationResultAnalysisCRUD._analyze_binary(results, config)
        elif annotation_type == "text":
            return AnnotationResultAnalysisCRUD._analyze_text(results, config)
        else:
            return {"type": annotation_type, "count": len(results)}

    @staticmethod
    def _analyze_score(results: List, config) -> dict:
        """Summarize score annotations: count, average, min, max, and distribution.

        Results without a ``value.score`` are ignored. If none remain, only
        ``{"type": "score", "count": 0}`` is returned.
        """
        scores = []
        for r in results:
            if r.value.score:
                scores.append(r.value.score.score)

        if not scores:
            return {"type": "score", "count": 0}

        return {
            "type": "score",
            "count": len(scores),
            "average": sum(scores) / len(scores),
            "min": min(scores),
            "max": max(scores),
            "distribution": AnnotationResultAnalysisCRUD._get_score_distribution(
                scores, config.config.min_score, config.config.max_score
            ),
        }

    @staticmethod
    def _get_score_distribution(scores: List, min_score: int, max_score: int) -> dict:
        """Count scores per integer bucket, keyed by ``str(int(score))``.

        ``int()`` truncates, so e.g. 3.5 falls in bucket "3". ``min_score`` and
        ``max_score`` are accepted but currently unused; buckets exist only for
        observed scores.
        """
        distribution = {}
        for score in scores:
            key = str(int(score))
            distribution[key] = distribution.get(key, 0) + 1
        return distribution

    @staticmethod
    def _analyze_choice(results: List, config) -> dict:
        """Count how often each option id was selected, and map option ids to labels."""
        option_counts = {}
        for r in results:
            if r.value.choice:
                for option_id in r.value.choice.selected_options:
                    option_counts[option_id] = option_counts.get(option_id, 0) + 1

        # Label lookup so consumers can display readable option names.
        option_labels = {}
        if config.config.options:
            for opt in config.config.options:
                option_labels[opt.option_id] = opt.label

        return {
            "type": "choice",
            "count": len(results),
            "option_distribution": option_counts,
            "option_labels": option_labels,
        }

    @staticmethod
    def _analyze_category(results: List, config) -> dict:
        """Count results per assigned category."""
        category_counts = {}
        for r in results:
            if r.value.category:
                cat = r.value.category.category
                category_counts[cat] = category_counts.get(cat, 0) + 1

        return {
            "type": "category",
            "count": len(results),
            "category_distribution": category_counts,
        }

    @staticmethod
    def _analyze_binary(results: List, config) -> dict:
        """Count true/false binary annotations.

        ``true_ratio`` divides by ``len(results)``, so results that have no
        binary value count in the denominator.
        """
        true_count = 0
        false_count = 0
        for r in results:
            if r.value.binary:
                if r.value.binary.value:
                    true_count += 1
                else:
                    false_count += 1

        return {
            "type": "binary",
            "count": len(results),
            "true_count": true_count,
            "false_count": false_count,
            "true_ratio": true_count / len(results) if results else 0,
        }

    @staticmethod
    def _analyze_text(results: List, config) -> dict:
        """Summarize the lengths of non-empty text annotations.

        Words are counted with ``str.split()`` (whitespace-separated), so text
        without spaces, such as most Chinese text, counts as one word.
        """
        lengths = []
        word_counts = []
        for r in results:
            if r.value.text and r.value.text.text:
                text = r.value.text.text
                lengths.append(len(text))
                word_counts.append(len(text.split()))

        if not lengths:
            return {"type": "text", "count": 0}

        return {
            "type": "text",
            "count": len(results),
            "avg_length": sum(lengths) / len(lengths),
            "max_length": max(lengths),
            "min_length": min(lengths),
            "avg_words": sum(word_counts) / len(word_counts) if word_counts else 0,
        }
|
|
|
|
| |
|
|
|
|
class LlmAnalysisCacheCRUD:
    """Persistence for cached LLM-generated project analysis reports."""

    @staticmethod
    def get_by_project(
        db: Session, project_id: int, language: str | None = None
    ) -> Optional[dict]:
        """Return the most recently updated cached report for a project.

        When ``language`` is truthy, only reports in that language are
        considered. Returns None when no matching cache row exists.
        """
        conditions = [LlmAnalysisCacheModel.project_id == project_id]
        if language:
            conditions.append(LlmAnalysisCacheModel.language == language)

        latest = (
            db.query(LlmAnalysisCacheModel)
            .filter(*conditions)
            .order_by(LlmAnalysisCacheModel.updated_at.desc())
            .first()
        )
        if latest is None:
            return None

        return {
            "analysis": latest.analysis_text,
            "model_name": latest.model_name,
            "notes_count": latest.notes_count,
            "language": latest.language,
            "created_at": latest.created_at,
            "updated_at": latest.updated_at,
        }

    @staticmethod
    def save(
        db: Session,
        project_id: int,
        analysis_text: str,
        model_name: str,
        notes_count: int,
        language: str = "zh",
    ) -> LlmAnalysisCacheModel:
        """Upsert the cached report for (project_id, language), commit, and return the row."""
        cached = (
            db.query(LlmAnalysisCacheModel)
            .filter(
                LlmAnalysisCacheModel.project_id == project_id,
                LlmAnalysisCacheModel.language == language,
            )
            .first()
        )
        if cached is None:
            cached = LlmAnalysisCacheModel(
                project_id=project_id,
                analysis_text=analysis_text,
                model_name=model_name,
                notes_count=notes_count,
                language=language,
            )
            db.add(cached)
        else:
            # Overwrite the existing report in place and bump its timestamp.
            cached.analysis_text = analysis_text
            cached.model_name = model_name
            cached.notes_count = notes_count
            cached.language = language
            cached.updated_at = datetime.now()
        db.commit()
        db.refresh(cached)
        return cached
|
|