SIRILSAM77 commited on
Commit
d6be6dd
·
verified ·
1 Parent(s): 709898e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from typing import List, Union
3
+ from dotenv import load_dotenv, find_dotenv
4
+ from langchain.callbacks import get_openai_callback
5
+ from langchain.chat_models import ChatOpenAI
6
+ from langchain.schema import (SystemMessage, HumanMessage, AIMessage)
7
+ from langchain.llms import LlamaCpp
8
+ from langchain.callbacks.manager import CallbackManager
9
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
10
+ import streamlit as st
11
+
12
+
13
+ def init_page() -> None:
14
+ st.set_page_config(
15
+ page_title="Personal ChatGPT"
16
+ )
17
+ st.header("Personal ChatGPT")
18
+ st.sidebar.title("Options")
19
+
20
+
21
+ def init_messages() -> None:
22
+ clear_button = st.sidebar.button("Clear Conversation", key="clear")
23
+ if clear_button or "messages" not in st.session_state:
24
+ st.session_state.messages = [
25
+ SystemMessage(
26
+ content="You are a helpful AI assistant. Reply your answer in mardkown format.")
27
+ ]
28
+ st.session_state.costs = []
29
+
30
+
31
+ def select_llm() -> Union[ChatOpenAI, LlamaCpp]:
32
+ model_name = st.sidebar.radio("Choose LLM:",
33
+ ("gpt-3.5-turbo-0613", "gpt-4",
34
+ "llama-2-7b-chat.ggmlv3.q2_K"))
35
+ temperature = st.sidebar.slider("Temperature:", min_value=0.0,
36
+ max_value=1.0, value=0.0, step=0.01)
37
+ if model_name.startswith("gpt-"):
38
+ return ChatOpenAI(temperature=temperature, model_name=model_name)
39
+ elif model_name.startswith("llama-2-"):
40
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
41
+ return LlamaCpp(
42
+ model_path=f"/app/models/{model_name}.bin",
43
+ input={"temperature": temperature,
44
+ "max_length": 2000,
45
+ "top_p": 1
46
+ },
47
+ callback_manager=callback_manager,
48
+ verbose=False, # True
49
+ )
50
+
51
+
52
+ def get_answer(llm, messages) -> tuple[str, float]:
53
+ if isinstance(llm, ChatOpenAI):
54
+ with get_openai_callback() as cb:
55
+ answer = llm(messages)
56
+ return answer.content, cb.total_cost
57
+ if isinstance(llm, LlamaCpp):
58
+ return llm(llama_v2_prompt(convert_langchainschema_to_dict(messages))), 0.0
59
+
60
+
61
+ def find_role(message: Union[SystemMessage, HumanMessage, AIMessage]) -> str:
62
+ """
63
+ Identify role name from langchain.schema object.
64
+ """
65
+ if isinstance(message, SystemMessage):
66
+ return "system"
67
+ if isinstance(message, HumanMessage):
68
+ return "user"
69
+ if isinstance(message, AIMessage):
70
+ return "assistant"
71
+ raise TypeError("Unknown message type.")
72
+
73
+
74
+ def convert_langchainschema_to_dict(
75
+ messages: List[Union[SystemMessage, HumanMessage, AIMessage]]) \
76
+ -> List[dict]:
77
+ """
78
+ Convert the chain of chat messages in list of langchain.schema format to
79
+ list of dictionary format.
80
+ """
81
+ return [{"role": find_role(message),
82
+ "content": message.content
83
+ } for message in messages]
84
+
85
+
86
+ def llama_v2_prompt(messages: List[dict]) -> str:
87
+ """
88
+ Convert the messages in list of dictionary format to Llama2 compliant format.
89
+ """
90
+ B_INST, E_INST = "[INST]", "[/INST]"
91
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
92
+ BOS, EOS = "<s>", "</s>"
93
+ DEFAULT_SYSTEM_PROMPT = f"""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
94
+
95
+ if messages[0]["role"] != "system":
96
+ messages = [
97
+ {
98
+ "role": "system",
99
+ "content": DEFAULT_SYSTEM_PROMPT,
100
+ }
101
+ ] + messages
102
+ messages = [
103
+ {
104
+ "role": messages[1]["role"],
105
+ "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
106
+ }
107
+ ] + messages[2:]
108
+
109
+ messages_list = [
110
+ f"{BOS}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {EOS}"
111
+ for prompt, answer in zip(messages[::2], messages[1::2])
112
+ ]
113
+ messages_list.append(
114
+ f"{BOS}{B_INST} {(messages[-1]['content']).strip()} {E_INST}")
115
+
116
+ return "".join(messages_list)
117
+
118
+
119
+ def main() -> None:
120
+ _ = load_dotenv(find_dotenv())
121
+
122
+ init_page()
123
+ llm = select_llm()
124
+ init_messages()
125
+
126
+ # Supervise user input
127
+ if user_input := st.chat_input("Input your question!"):
128
+ st.session_state.messages.append(HumanMessage(content=user_input))
129
+ with st.spinner("ChatGPT is typing ..."):
130
+ answer, cost = get_answer(llm, st.session_state.messages)
131
+ st.session_state.messages.append(AIMessage(content=answer))
132
+ st.session_state.costs.append(cost)
133
+
134
+ # Display chat history
135
+ messages = st.session_state.get("messages", [])
136
+ for message in messages:
137
+ if isinstance(message, AIMessage):
138
+ with st.chat_message("assistant"):
139
+ st.markdown(message.content)
140
+ elif isinstance(message, HumanMessage):
141
+ with st.chat_message("user"):
142
+ st.markdown(message.content)
143
+
144
+ costs = st.session_state.get("costs", [])
145
+ st.sidebar.markdown("## Costs")
146
+ st.sidebar.markdown(f"**Total cost: ${sum(costs):.5f}**")
147
+ for cost in costs:
148
+ st.sidebar.markdown(f"- ${cost:.5f}")
149
+
150
+
151
+ # streamlit run app.py
152
+ if __name__ == "__main__":
153
+ main()