package database import ( "encoding/json" "fmt" "time" "gorm.io/datatypes" "gorm.io/gorm" ) // User 用户表 type User struct { Username string `gorm:"primaryKey;column:username" json:"username"` CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at"` // 时间戳毫秒级 UpdatedAt int64 `gorm:"column:updated_at;not null" json:"updated_at"` // 时间戳毫秒级 } // Session 会话表(一个会话对应一个任务) type Session struct { ID string `gorm:"primaryKey;column:id" json:"id"` // 会话ID,也是任务ID Username string `gorm:"column:username;not null" json:"username"` Title string `gorm:"column:title" json:"title"` TaskType string `gorm:"column:task_type;not null" json:"task_type"` // 任务类型 Content string `gorm:"column:content;not null" json:"content"` // 任务内容 Params datatypes.JSON `gorm:"column:params" json:"params"` // 任务参数 Attachments datatypes.JSON `gorm:"column:attachments" json:"attachments"` // 附件 Status string `gorm:"column:status;not null;default:'todo'" json:"status"` // todo, doing, done, error AssignedAgent string `gorm:"column:assigned_agent" json:"assigned_agent"` // 分配的Agent CountryIsoCode string `gorm:"column:contry_iso_code" json:"countryIsoCode"` // 标识语言 StartedAt *int64 `gorm:"column:started_at" json:"started_at"` // 时间戳毫秒级 CompletedAt *int64 `gorm:"column:completed_at" json:"completed_at"` // 时间戳毫秒级 CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at"` // 时间戳毫秒级 UpdatedAt int64 `gorm:"column:updated_at;not null" json:"updated_at"` // 时间戳毫秒级 // 关联关系 User User `gorm:"foreignKey:Username" json:"user"` Messages []TaskMessage `gorm:"foreignKey:SessionID" json:"messages"` // 直接关联到Session Share bool `gorm:"column:share;not null;default:false" json:"share"` } // TaskMessage 任务消息表(存储所有类型的事件消息) type TaskMessage struct { ID string `gorm:"primaryKey;column:id" json:"id"` // 消息ID(前端生成的对话ID) SessionID string `gorm:"column:session_id;not null" json:"session_id"` // 会话ID(也是任务ID) Type string `gorm:"column:type;not null" json:"type"` // liveStatus, planUpdate, statusUpdate, toolUsed等 EventData datatypes.JSON `gorm:"column:event_data;not null" json:"event_data"` // 存储事件的具体数据 Timestamp int64 `gorm:"column:timestamp;not null" json:"timestamp"` CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at"` // 时间戳毫秒级 // 关联关系 Session Session `gorm:"foreignKey:SessionID" json:"session"` } // TaskStore 任务数据存储 type TaskStore struct { db *gorm.DB } // NewTaskStore 创建新的TaskStore实例 func NewTaskStore(db *gorm.DB) *TaskStore { return &TaskStore{db: db} } // Init 自动迁移任务相关表结构 func (s *TaskStore) Init() error { return s.db.AutoMigrate(&User{}, &Session{}, &TaskMessage{}) } // CreateUser 创建用户 func (s *TaskStore) CreateUser(user *User) error { now := time.Now().UnixMilli() user.CreatedAt = now user.UpdatedAt = now return s.db.Create(user).Error } // ResetRunningTasks 重置运行中的任务为失败 func (s *TaskStore) ResetRunningTasks() error { return s.db.Model(&Session{}).Where("status = 'doing' or status = 'failed'").Updates(map[string]interface{}{ "status": "error", "updated_at": time.Now().UnixMilli(), }).Error } // GetUser 获取用户信息 func (s *TaskStore) GetUser(username string) (*User, error) { var user User err := s.db.First(&user, "username = ?", username).Error if err != nil { return nil, err } return &user, nil } // CreateSession 创建会话(包含任务信息) func (s *TaskStore) CreateSession(session *Session) error { now := time.Now().UnixMilli() session.CreatedAt = now session.UpdatedAt = now return s.db.Create(session).Error } // GetSession 获取会话信息 func (s *TaskStore) GetSession(id string) (*Session, error) { var session Session err := s.db.Preload("User").Preload("Messages").First(&session, "id = ?", id).Error if err != nil { return nil, err } return &session, nil } // SetShare 设置会话共享 func (s *TaskStore) SetShare(sessionID string, share bool) error { return s.db.Model(&Session{}).Where("id = ?", sessionID).Update("share", share).Error } // UpdateSessionStatus 更新会话状态 func (s *TaskStore) UpdateSessionStatus(id string, status string) error { now := time.Now().UnixMilli() updates := map[string]interface{}{ "status": status, "updated_at": now, } if status == "doing" { updates["started_at"] = &now } else if status == "done" { updates["completed_at"] = &now } return s.db.Model(&Session{}).Where("id = ?", id).Updates(updates).Error } // UpdateSessionAssignedAgent 更新会话的分配Agent和开始时间 func (s *TaskStore) UpdateSessionAssignedAgent(sessionID string, agentID string) error { now := time.Now().UnixMilli() updates := map[string]interface{}{ "assigned_agent": agentID, "status": "doing", "started_at": &now, } return s.db.Model(&Session{}).Where("id = ?", sessionID).Updates(updates).Error } // UpdateSession 更新会话信息 func (s *TaskStore) UpdateSession(sessionID string, updates map[string]interface{}) error { // 添加更新时间 updates["updated_at"] = time.Now().UnixMilli() return s.db.Model(&Session{}).Where("id = ?", sessionID).Updates(updates).Error } // DeleteSession 删除会话 func (s *TaskStore) DeleteSession(sessionID string) error { return s.db.Delete(&Session{}, "id = ?", sessionID).Error } // DeleteSessionMessages 删除会话的所有消息 func (s *TaskStore) DeleteSessionMessages(sessionID string) error { return s.db.Where("session_id = ?", sessionID).Delete(&TaskMessage{}).Error } // DeleteSessionWithMessages 使用事务删除会话及其所有消息 func (s *TaskStore) DeleteSessionWithMessages(sessionID string) error { return s.db.Transaction(func(tx *gorm.DB) error { // 1. 删除会话的所有消息 if err := tx.Where("session_id = ?", sessionID).Delete(&TaskMessage{}).Error; err != nil { return fmt.Errorf("删除会话消息失败: %v", err) } // 2. 删除会话记录 if err := tx.Delete(&Session{}, "id = ?", sessionID).Error; err != nil { return fmt.Errorf("删除会话记录失败: %v", err) } return nil }) } // CreateTaskMessage 创建任务消息 func (s *TaskStore) CreateTaskMessage(message *TaskMessage) error { now := time.Now().UnixMilli() message.CreatedAt = now return s.db.Create(message).Error } // GetSessionMessages 获取会话的所有消息 func (s *TaskStore) GetSessionMessages(sessionID string) ([]*TaskMessage, error) { var messages []*TaskMessage err := s.db.Where("session_id = ?", sessionID).Order("timestamp ASC").Find(&messages).Error if err != nil { return nil, err } return messages, nil } // GetUserSessions 获取用户的所有会话 func (s *TaskStore) GetUserSessions(username string) ([]*Session, error) { var sessions []*Session err := s.db.Where("username = ?", username).Order("created_at DESC").Find(&sessions).Error if err != nil { return nil, err } return sessions, nil } // GetUserSessionsByType 获取用户的会话,支持可选的任务类型过滤 func (s *TaskStore) GetUserSessionsByType(username string, taskType string) ([]*Session, error) { query := s.db.Where("username = ?", username) // 如果指定了任务类型,添加类型过滤 if taskType != "" { query = query.Where("task_type = ?", taskType) } var sessions []*Session err := query.Order("created_at DESC").Find(&sessions).Error if err != nil { return nil, err } return sessions, nil } // StoreEvent 存储事件消息 func (s *TaskStore) StoreEvent(id string, sessionID string, eventType string, eventData interface{}, timestamp int64) error { // 将事件数据序列化为JSON eventJSON, err := json.Marshal(eventData) if err != nil { return err } message := &TaskMessage{ ID: id, SessionID: sessionID, Type: eventType, EventData: datatypes.JSON(eventJSON), Timestamp: timestamp, } return s.CreateTaskMessage(message) } // GetSessionEvents 获取会话的所有事件 func (s *TaskStore) GetSessionEvents(sessionID string) ([]*TaskMessage, error) { return s.GetSessionMessages(sessionID) } // GetSessionEventsByType 根据类型获取会话事件 func (s *TaskStore) GetSessionEventsByType(sessionID string, eventType string) ([]*TaskMessage, error) { var messages []*TaskMessage err := s.db.Where("session_id = ? AND type = ?", sessionID, eventType).Order("timestamp ASC").Find(&messages).Error if err != nil { return nil, err } return messages, nil } // SearchUserSessionsSimple 使用单个查询参数搜索用户的会话,支持在title、content、task_type字段中搜索 func (s *TaskStore) SearchUserSessionsSimple(username string, searchParams SimpleSearchParams) ([]*Session, int64, error) { query := s.db.Model(&Session{}).Where("username = ?", username) // 如果指定了任务类型,添加类型过滤 if searchParams.TaskType != "" { query = query.Where("task_type = ?", searchParams.TaskType) } // 如果有查询关键词,在多个字段中搜索 if searchParams.Query != "" { query = query.Where("title LIKE ? OR content LIKE ? OR task_type LIKE ?", "%"+searchParams.Query+"%", "%"+searchParams.Query+"%", "%"+searchParams.Query+"%") } // 获取总数 var total int64 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 应用分页和排序 var sessions []*Session err := query.Order("created_at DESC"). Offset((searchParams.Page - 1) * searchParams.PageSize). Limit(searchParams.PageSize). Find(&sessions).Error if err != nil { return nil, 0, err } return sessions, total, nil } // SimpleSearchParams 简化搜索参数结构 type SimpleSearchParams struct { Query string `json:"query"` // 查询关键词,将在title、content、task_type字段中搜索 TaskType string `json:"task_type"` // 任务类型过滤 Page int `json:"page"` // 页码 PageSize int `json:"page_size"` // 每页大小 } // generateMessageID 生成消息ID func generateMessageID() string { return time.Now().Format("20060102150405") + "_" + fmt.Sprintf("%d", time.Now().UnixNano()) }