curry tang commited on
Commit ·
b6ec8b9
1
Parent(s): 99a9a6e
update
Browse files
app.py
CHANGED
|
@@ -2,7 +2,10 @@ import gradio as gr
|
|
| 2 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 3 |
from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
|
| 4 |
from config import settings
|
| 5 |
-
from prompts import
|
|
|
|
|
|
|
|
|
|
| 6 |
from langchain_core.prompts import PromptTemplate
|
| 7 |
from log import logging
|
| 8 |
from utils import convert_image_to_base64
|
|
@@ -21,6 +24,12 @@ provider_model_map = dict(
|
|
| 21 |
Tongyi=tongyi_llm,
|
| 22 |
)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
support_vision_models = [
|
| 25 |
'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'google/gemini-pro-1.5-exp',
|
| 26 |
'openai/gpt-4o', 'google/gemini-flash-1.5', 'liuhaotian/llava-yi-34b', 'anthropic/claude-3-haiku',
|
|
@@ -33,29 +42,39 @@ def get_default_chat():
|
|
| 33 |
return _llm.get_chat_engine()
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def predict(message, history, _chat, _current_assistant: str):
|
| 37 |
logger.info(f"chat predict: {message}, {history}, {_chat}, {_current_assistant}")
|
| 38 |
files_len = len(message.files)
|
| 39 |
-
|
| 40 |
-
_chat = get_default_chat()
|
| 41 |
if files_len > 0:
|
| 42 |
if _chat.model_name not in support_vision_models:
|
| 43 |
raise gr.Error("当前模型不支持图片,请更换模型。")
|
| 44 |
|
| 45 |
_lc_history = []
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
assistant_prompt = backend_developer_prompt
|
| 49 |
-
if _current_assistant == '数据分析师':
|
| 50 |
-
assistant_prompt = analyst_prompt
|
| 51 |
-
_lc_history.append(SystemMessage(content=assistant_prompt))
|
| 52 |
-
|
| 53 |
-
for his_msg in history:
|
| 54 |
-
if his_msg['role'] == 'user':
|
| 55 |
-
if not hasattr(his_msg['content'], 'file'):
|
| 56 |
-
_lc_history.append(HumanMessage(content=his_msg['content']))
|
| 57 |
-
if his_msg['role'] == 'assistant':
|
| 58 |
-
_lc_history.append(AIMessage(content=his_msg['content']))
|
| 59 |
|
| 60 |
if files_len == 0:
|
| 61 |
_lc_history.append(HumanMessage(content=message.text))
|
|
@@ -81,8 +100,7 @@ def update_chat(_provider: str, _model: str, _temperature: float, _max_tokens: i
|
|
| 81 |
|
| 82 |
|
| 83 |
def explain_code(_code_type: str, _code: str, _chat):
|
| 84 |
-
|
| 85 |
-
_chat = get_default_chat()
|
| 86 |
chat_messages = [
|
| 87 |
SystemMessage(content=explain_code_template),
|
| 88 |
HumanMessage(content=_code),
|
|
@@ -94,8 +112,7 @@ def explain_code(_code_type: str, _code: str, _chat):
|
|
| 94 |
|
| 95 |
|
| 96 |
def optimize_code(_code_type: str, _code: str, _chat):
|
| 97 |
-
|
| 98 |
-
_chat = get_default_chat()
|
| 99 |
prompt = PromptTemplate.from_template(optimize_code_template)
|
| 100 |
prompt = prompt.format(code_type=_code_type)
|
| 101 |
chat_messages = [
|
|
@@ -109,8 +126,7 @@ def optimize_code(_code_type: str, _code: str, _chat):
|
|
| 109 |
|
| 110 |
|
| 111 |
def debug_code(_code_type: str, _code: str, _chat):
|
| 112 |
-
|
| 113 |
-
_chat = get_default_chat()
|
| 114 |
prompt = PromptTemplate.from_template(debug_code_template)
|
| 115 |
prompt = prompt.format(code_type=_code_type)
|
| 116 |
chat_messages = [
|
|
@@ -124,8 +140,7 @@ def debug_code(_code_type: str, _code: str, _chat):
|
|
| 124 |
|
| 125 |
|
| 126 |
def function_gen(_code_type: str, _code: str, _chat):
|
| 127 |
-
|
| 128 |
-
_chat = get_default_chat()
|
| 129 |
prompt = PromptTemplate.from_template(function_gen_template)
|
| 130 |
prompt = prompt.format(code_type=_code_type)
|
| 131 |
chat_messages = [
|
|
@@ -139,8 +154,7 @@ def function_gen(_code_type: str, _code: str, _chat):
|
|
| 139 |
|
| 140 |
|
| 141 |
def translate_doc(_language_input, _language_output, _doc, _chat):
|
| 142 |
-
|
| 143 |
-
_chat = get_default_chat()
|
| 144 |
prompt = PromptTemplate.from_template(translate_doc_template)
|
| 145 |
prompt = prompt.format(language_input=_language_input, language_output=_language_output)
|
| 146 |
chat_messages = [
|
|
|
|
| 2 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 3 |
from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
|
| 4 |
from config import settings
|
| 5 |
+
from prompts import (
|
| 6 |
+
web_prompt, explain_code_template, optimize_code_template, debug_code_template,
|
| 7 |
+
function_gen_template, translate_doc_template, backend_developer_prompt, analyst_prompt
|
| 8 |
+
)
|
| 9 |
from langchain_core.prompts import PromptTemplate
|
| 10 |
from log import logging
|
| 11 |
from utils import convert_image_to_base64
|
|
|
|
| 24 |
Tongyi=tongyi_llm,
|
| 25 |
)
|
| 26 |
|
| 27 |
+
system_prompt_map = {
|
| 28 |
+
"前端开发助手": web_prompt,
|
| 29 |
+
"后端开发助手": backend_developer_prompt,
|
| 30 |
+
"数据分析师": analyst_prompt,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
support_vision_models = [
|
| 34 |
'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'google/gemini-pro-1.5-exp',
|
| 35 |
'openai/gpt-4o', 'google/gemini-flash-1.5', 'liuhaotian/llava-yi-34b', 'anthropic/claude-3-haiku',
|
|
|
|
| 42 |
return _llm.get_chat_engine()
|
| 43 |
|
| 44 |
|
| 45 |
+
def get_chat_or_default(chat):
|
| 46 |
+
if chat is None:
|
| 47 |
+
chat = get_default_chat()
|
| 48 |
+
return chat
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def convert_history_to_langchain_history(history, lc_history):
|
| 52 |
+
for his_msg in history:
|
| 53 |
+
if his_msg['role'] == 'user':
|
| 54 |
+
if not hasattr(his_msg['content'], 'file'):
|
| 55 |
+
lc_history.append(HumanMessage(content=his_msg['content']))
|
| 56 |
+
if his_msg['role'] == 'assistant':
|
| 57 |
+
lc_history.append(AIMessage(content=his_msg['content']))
|
| 58 |
+
return lc_history
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def append_system_prompt(key: str, lc_history):
|
| 62 |
+
prompt = system_prompt_map[key]
|
| 63 |
+
lc_history.append(SystemMessage(content=prompt))
|
| 64 |
+
return lc_history
|
| 65 |
+
|
| 66 |
+
|
| 67 |
def predict(message, history, _chat, _current_assistant: str):
|
| 68 |
logger.info(f"chat predict: {message}, {history}, {_chat}, {_current_assistant}")
|
| 69 |
files_len = len(message.files)
|
| 70 |
+
_chat = get_chat_or_default(_chat)
|
|
|
|
| 71 |
if files_len > 0:
|
| 72 |
if _chat.model_name not in support_vision_models:
|
| 73 |
raise gr.Error("当前模型不支持图片,请更换模型。")
|
| 74 |
|
| 75 |
_lc_history = []
|
| 76 |
+
_lc_history = append_system_prompt(_current_assistant, _lc_history)
|
| 77 |
+
_lc_history = convert_history_to_langchain_history(history, _lc_history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
if files_len == 0:
|
| 80 |
_lc_history.append(HumanMessage(content=message.text))
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
def explain_code(_code_type: str, _code: str, _chat):
|
| 103 |
+
_chat = get_chat_or_default(_chat)
|
|
|
|
| 104 |
chat_messages = [
|
| 105 |
SystemMessage(content=explain_code_template),
|
| 106 |
HumanMessage(content=_code),
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
def optimize_code(_code_type: str, _code: str, _chat):
|
| 115 |
+
_chat = get_chat_or_default(_chat)
|
|
|
|
| 116 |
prompt = PromptTemplate.from_template(optimize_code_template)
|
| 117 |
prompt = prompt.format(code_type=_code_type)
|
| 118 |
chat_messages = [
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
def debug_code(_code_type: str, _code: str, _chat):
|
| 129 |
+
_chat = get_chat_or_default(_chat)
|
|
|
|
| 130 |
prompt = PromptTemplate.from_template(debug_code_template)
|
| 131 |
prompt = prompt.format(code_type=_code_type)
|
| 132 |
chat_messages = [
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
def function_gen(_code_type: str, _code: str, _chat):
|
| 143 |
+
_chat = get_chat_or_default(_chat)
|
|
|
|
| 144 |
prompt = PromptTemplate.from_template(function_gen_template)
|
| 145 |
prompt = prompt.format(code_type=_code_type)
|
| 146 |
chat_messages = [
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
def translate_doc(_language_input, _language_output, _doc, _chat):
|
| 157 |
+
_chat = get_chat_or_default(_chat)
|
|
|
|
| 158 |
prompt = PromptTemplate.from_template(translate_doc_template)
|
| 159 |
prompt = prompt.format(language_input=_language_input, language_output=_language_output)
|
| 160 |
chat_messages = [
|
llm.py
CHANGED
|
@@ -60,8 +60,8 @@ class DeepSeekLLM(BaseLLM):
|
|
| 60 |
|
| 61 |
class OpenRouterLLM(BaseLLM):
|
| 62 |
_support_models = [
|
| 63 |
-
'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', '
|
| 64 |
-
'mistralai/mistral-large', 'meta-llama/llama-3.1-405b-instruct',
|
| 65 |
'nvidia/nemotron-4-340b-instruct', 'deepseek/deepseek-coder', 'google/gemma-2-27b-it',
|
| 66 |
'google/gemini-flash-1.5', 'deepseek/deepseek-chat', 'qwen/qwen-2-72b-instruct',
|
| 67 |
'liuhaotian/llava-yi-34b', 'qwen/qwen-110b-chat',
|
|
|
|
| 60 |
|
| 61 |
class OpenRouterLLM(BaseLLM):
|
| 62 |
_support_models = [
|
| 63 |
+
'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'openai/gpt-4o-2024-08-06',
|
| 64 |
+
'google/gemini-pro-1.5-exp', 'mistralai/mistral-large', 'meta-llama/llama-3.1-405b-instruct',
|
| 65 |
'nvidia/nemotron-4-340b-instruct', 'deepseek/deepseek-coder', 'google/gemma-2-27b-it',
|
| 66 |
'google/gemini-flash-1.5', 'deepseek/deepseek-chat', 'qwen/qwen-2-72b-instruct',
|
| 67 |
'liuhaotian/llava-yi-34b', 'qwen/qwen-110b-chat',
|