SIRILSAM77 commited on
Commit
39b5ca3
·
verified ·
1 Parent(s): 9015d35
Files changed (1) hide show
  1. streamlit_app.py +0 -153
streamlit_app.py DELETED
@@ -1,153 +0,0 @@
1
- # app.py
2
- from typing import List, Union
3
- from dotenv import load_dotenv, find_dotenv
4
- from langchain.callbacks import get_openai_callback
5
- from langchain.chat_models import ChatOpenAI
6
- from langchain.schema import (SystemMessage, HumanMessage, AIMessage)
7
- from langchain.llms import LlamaCpp
8
- from langchain.callbacks.manager import CallbackManager
9
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
10
- import streamlit as st
11
-
12
-
13
def init_page() -> None:
    """Configure the Streamlit page and draw the static header and sidebar title."""
    st.set_page_config(page_title="Personal ChatGPT")
    st.header("Personal ChatGPT")
    st.sidebar.title("Options")
19
-
20
-
21
def init_messages() -> None:
    """Reset the chat history and running cost list in Streamlit session state.

    A fresh history (just the system message) is created when the sidebar
    "Clear Conversation" button is pressed, or on first load when no
    history exists yet.
    """
    clear_button = st.sidebar.button("Clear Conversation", key="clear")
    if clear_button or "messages" not in st.session_state:
        st.session_state.messages = [
            SystemMessage(
                # Fix: system prompt said "mardkown"; the misspelling meant the
                # model was never actually instructed to answer in Markdown.
                content="You are a helpful AI assistant. Reply your answer in markdown format.")
        ]
        st.session_state.costs = []
29
-
30
-
31
def select_llm() -> Union[ChatOpenAI, LlamaCpp]:
    """Build the LLM selected in the sidebar.

    Reads the model name and sampling temperature from sidebar widgets and
    returns a configured ChatOpenAI (for "gpt-*" names) or LlamaCpp
    (for "llama-2-*" names) instance.

    Raises:
        ValueError: if the selected model name matches neither family.
    """
    model_name = st.sidebar.radio("Choose LLM:",
                                  ("gpt-3.5-turbo-0613", "gpt-4",
                                   "llama-2-7b-chat.ggmlv3.q2_K"))
    temperature = st.sidebar.slider("Temperature:", min_value=0.0,
                                    max_value=1.0, value=0.0, step=0.01)
    if model_name.startswith("gpt-"):
        return ChatOpenAI(temperature=temperature, model_name=model_name)
    if model_name.startswith("llama-2-"):
        # Stream generated tokens to stdout as they arrive.
        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        return LlamaCpp(
            model_path=f"/app/models/{model_name}.bin",
            # NOTE(review): `input` is not a documented LlamaCpp constructor
            # field; these sampling options may be silently ignored. Verify
            # against langchain's LlamaCpp API — they likely belong as
            # top-level kwargs (temperature=, top_p=, max_tokens=).
            input={"temperature": temperature,
                   "max_length": 2000,
                   "top_p": 1
                   },
            callback_manager=callback_manager,
            verbose=False,  # set True for llama.cpp debug output
        )
    # Fix: previously fell through and returned None implicitly, which would
    # crash later in get_answer with a confusing error; fail fast instead.
    raise ValueError(f"Unknown model: {model_name}")
50
-
51
-
52
def get_answer(llm, messages) -> tuple[str, float]:
    """Run the conversation through ``llm`` and return ``(answer_text, cost_usd)``.

    ChatOpenAI calls are wrapped in ``get_openai_callback`` so the API cost
    can be reported; local LlamaCpp inference is free, so its cost is 0.0.

    Raises:
        TypeError: if ``llm`` is neither ChatOpenAI nor LlamaCpp.
    """
    if isinstance(llm, ChatOpenAI):
        with get_openai_callback() as cb:
            answer = llm(messages)
        return answer.content, cb.total_cost
    if isinstance(llm, LlamaCpp):
        return llm(llama_v2_prompt(convert_langchainschema_to_dict(messages))), 0.0
    # Fix: previously returned None implicitly for unsupported llm types,
    # which would blow up at the unpacking site in main(); raise here instead.
    raise TypeError(f"Unsupported LLM type: {type(llm).__name__}")
59
-
60
-
61
def find_role(message: Union[SystemMessage, HumanMessage, AIMessage]) -> str:
    """Map a langchain.schema message object to its chat-role string.

    Raises:
        TypeError: if the message is none of the three known types.
    """
    role_by_type = (
        (SystemMessage, "system"),
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
    )
    for message_type, role in role_by_type:
        if isinstance(message, message_type):
            return role
    raise TypeError("Unknown message type.")
72
-
73
-
74
def convert_langchainschema_to_dict(
        messages: List[Union[SystemMessage, HumanMessage, AIMessage]]) \
        -> List[dict]:
    """Convert langchain.schema message objects into OpenAI-style
    ``{"role": ..., "content": ...}`` dictionaries, preserving order.
    """
    converted = []
    for message in messages:
        converted.append({"role": find_role(message),
                          "content": message.content})
    return converted
84
-
85
-
86
def llama_v2_prompt(messages: List[dict]) -> str:
    """Render a role/content message list as a Llama-2 chat prompt string.

    When the list does not start with a system message, the default system
    prompt is prepended. The system text is then folded into the first user
    turn between <<SYS>> markers; each (user, assistant) pair becomes one
    "<s>[INST] ... [/INST] ... </s>" segment, and the trailing unanswered
    user turn is emitted with an open [INST] block.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    BOS, EOS = "<s>", "</s>"
    DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

    if messages[0]["role"] != "system":
        messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
                    *messages]

    # Fold the system prompt into the first user turn, then drop the
    # standalone system entry.
    merged_first = {
        "role": messages[1]["role"],
        "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
    }
    messages = [merged_first, *messages[2:]]

    # One closed segment per (prompt, answer) pair...
    segments = []
    for prompt, answer in zip(messages[::2], messages[1::2]):
        segments.append(
            f"{BOS}{B_INST} {prompt['content'].strip()} {E_INST} "
            f"{answer['content'].strip()} {EOS}")
    # ...plus the final, still-open user turn awaiting a completion.
    segments.append(f"{BOS}{B_INST} {messages[-1]['content'].strip()} {E_INST}")

    return "".join(segments)
117
-
118
-
119
def main() -> None:
    """Entry point: wire up the page, model selection, chat loop and cost display."""
    _ = load_dotenv(find_dotenv())

    init_page()
    llm = select_llm()
    init_messages()

    # Handle a newly submitted user message, if any.
    if user_input := st.chat_input("Input your question!"):
        st.session_state.messages.append(HumanMessage(content=user_input))
        with st.spinner("ChatGPT is typing ..."):
            answer, cost = get_answer(llm, st.session_state.messages)
        st.session_state.messages.append(AIMessage(content=answer))
        st.session_state.costs.append(cost)

    # Re-render the whole conversation; system messages are not displayed.
    for message in st.session_state.get("messages", []):
        if isinstance(message, AIMessage):
            with st.chat_message("assistant"):
                st.markdown(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("user"):
                st.markdown(message.content)

    # Per-call and cumulative OpenAI API spend in the sidebar.
    costs = st.session_state.get("costs", [])
    st.sidebar.markdown("## Costs")
    st.sidebar.markdown(f"**Total cost: ${sum(costs):.5f}**")
    for cost in costs:
        st.sidebar.markdown(f"- ${cost:.5f}")
149
-
150
-
151
# Launch with: streamlit run app.py
if __name__ == "__main__":
    main()