Spaces:
Paused
Paused
| import logging | |
| from flask_login import current_user | |
| from flask_restful import Resource, fields, marshal_with, reqparse | |
| from flask_restful.inputs import int_range | |
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |
| from controllers.console import api | |
| from controllers.console.app.error import ( | |
| CompletionRequestError, | |
| ProviderModelCurrentlyNotSupportError, | |
| ProviderNotInitializeError, | |
| ProviderQuotaExceededError, | |
| ) | |
| from controllers.console.app.wraps import get_app_model | |
| from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError | |
| from controllers.console.wraps import ( | |
| account_initialization_required, | |
| cloud_edition_billing_resource_check, | |
| setup_required, | |
| ) | |
| from core.app.entities.app_invoke_entities import InvokeFrom | |
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |
| from core.model_runtime.errors.invoke import InvokeError | |
| from extensions.ext_database import db | |
| from fields.conversation_fields import annotation_fields, message_detail_fields | |
| from libs.helper import uuid_value | |
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |
| from libs.login import login_required | |
| from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback | |
| from services.annotation_service import AppAnnotationService | |
| from services.errors.conversation import ConversationNotExistsError | |
| from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError | |
| from services.message_service import MessageService | |
| class ChatMessageListApi(Resource): | |
| message_infinite_scroll_pagination_fields = { | |
| "limit": fields.Integer, | |
| "has_more": fields.Boolean, | |
| "data": fields.List(fields.Nested(message_detail_fields)), | |
| } | |
| def get(self, app_model): | |
| parser = reqparse.RequestParser() | |
| parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") | |
| parser.add_argument("first_id", type=uuid_value, location="args") | |
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |
| args = parser.parse_args() | |
| conversation = ( | |
| db.session.query(Conversation) | |
| .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) | |
| .first() | |
| ) | |
| if not conversation: | |
| raise NotFound("Conversation Not Exists.") | |
| if args["first_id"]: | |
| first_message = ( | |
| db.session.query(Message) | |
| .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) | |
| .first() | |
| ) | |
| if not first_message: | |
| raise NotFound("First message not found") | |
| history_messages = ( | |
| db.session.query(Message) | |
| .filter( | |
| Message.conversation_id == conversation.id, | |
| Message.created_at < first_message.created_at, | |
| Message.id != first_message.id, | |
| ) | |
| .order_by(Message.created_at.desc()) | |
| .limit(args["limit"]) | |
| .all() | |
| ) | |
| else: | |
| history_messages = ( | |
| db.session.query(Message) | |
| .filter(Message.conversation_id == conversation.id) | |
| .order_by(Message.created_at.desc()) | |
| .limit(args["limit"]) | |
| .all() | |
| ) | |
| has_more = False | |
| if len(history_messages) == args["limit"]: | |
| current_page_first_message = history_messages[-1] | |
| rest_count = ( | |
| db.session.query(Message) | |
| .filter( | |
| Message.conversation_id == conversation.id, | |
| Message.created_at < current_page_first_message.created_at, | |
| Message.id != current_page_first_message.id, | |
| ) | |
| .count() | |
| ) | |
| if rest_count > 0: | |
| has_more = True | |
| history_messages = list(reversed(history_messages)) | |
| return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) | |
| class MessageFeedbackApi(Resource): | |
| def post(self, app_model): | |
| parser = reqparse.RequestParser() | |
| parser.add_argument("message_id", required=True, type=uuid_value, location="json") | |
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |
| args = parser.parse_args() | |
| message_id = str(args["message_id"]) | |
| message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() | |
| if not message: | |
| raise NotFound("Message Not Exists.") | |
| feedback = message.admin_feedback | |
| if not args["rating"] and feedback: | |
| db.session.delete(feedback) | |
| elif args["rating"] and feedback: | |
| feedback.rating = args["rating"] | |
| elif not args["rating"] and not feedback: | |
| raise ValueError("rating cannot be None when feedback not exists") | |
| else: | |
| feedback = MessageFeedback( | |
| app_id=app_model.id, | |
| conversation_id=message.conversation_id, | |
| message_id=message.id, | |
| rating=args["rating"], | |
| from_source="admin", | |
| from_account_id=current_user.id, | |
| ) | |
| db.session.add(feedback) | |
| db.session.commit() | |
| return {"result": "success"} | |
| class MessageAnnotationApi(Resource): | |
| def post(self, app_model): | |
| if not current_user.is_editor: | |
| raise Forbidden() | |
| parser = reqparse.RequestParser() | |
| parser.add_argument("message_id", required=False, type=uuid_value, location="json") | |
| parser.add_argument("question", required=True, type=str, location="json") | |
| parser.add_argument("answer", required=True, type=str, location="json") | |
| parser.add_argument("annotation_reply", required=False, type=dict, location="json") | |
| args = parser.parse_args() | |
| annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) | |
| return annotation | |
| class MessageAnnotationCountApi(Resource): | |
| def get(self, app_model): | |
| count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() | |
| return {"count": count} | |
| class MessageSuggestedQuestionApi(Resource): | |
| def get(self, app_model, message_id): | |
| message_id = str(message_id) | |
| try: | |
| questions = MessageService.get_suggested_questions_after_answer( | |
| app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER | |
| ) | |
| except MessageNotExistsError: | |
| raise NotFound("Message not found") | |
| except ConversationNotExistsError: | |
| raise NotFound("Conversation not found") | |
| except ProviderTokenNotInitError as ex: | |
| raise ProviderNotInitializeError(ex.description) | |
| except QuotaExceededError: | |
| raise ProviderQuotaExceededError() | |
| except ModelCurrentlyNotSupportError: | |
| raise ProviderModelCurrentlyNotSupportError() | |
| except InvokeError as e: | |
| raise CompletionRequestError(e.description) | |
| except SuggestedQuestionsAfterAnswerDisabledError: | |
| raise AppSuggestedQuestionsAfterAnswerDisabledError() | |
| except Exception: | |
| logging.exception("internal server error.") | |
| raise InternalServerError() | |
| return {"data": questions} | |
| class MessageApi(Resource): | |
| def get(self, app_model, message_id): | |
| message_id = str(message_id) | |
| message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() | |
| if not message: | |
| raise NotFound("Message Not Exists.") | |
| return message | |
| api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions") | |
| api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages") | |
| api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks") | |
| api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations") | |
| api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count") | |
| api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message") | |