abanm commited on
Commit
28778c3
·
verified ·
1 Parent(s): 6f4229d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -41
app.py CHANGED
@@ -1,42 +1,194 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
-
5
- st.title('Uber pickups in NYC')
6
-
7
- DATE_COLUMN = 'date/time'
8
- DATA_URL = ('https://s3-us-west-2.amazonaws.com/'
9
- 'streamlit-demo-data/uber-raw-data-sep14.csv.gz')
10
-
11
- @st.cache_data
12
- def load_data(nrows):
13
- data = pd.read_csv(DATA_URL, nrows=nrows)
14
- lowercase = lambda x: str(x).lower()
15
- data.rename(lowercase, axis='columns', inplace=True)
16
- data[DATE_COLUMN] = pd.to_datetime(data[DATE_COLUMN])
17
- return data
18
-
19
- data_load_state = st.text('Loading data...')
20
- data = load_data(10000)
21
- data_load_state.text("Done! (using st.cache)")
22
-
23
- if st.checkbox('Show raw data'):
24
- st.subheader('Raw data')
25
- st.write(data)
26
-
27
- st.subheader('Number of pickups by hour')
28
- hist_values = np.histogram(data[DATE_COLUMN].dt.hour, bins=24, range=(0,24))[0]
29
- st.bar_chart(hist_values)
30
-
31
- # Some number in the range 0-23
32
- hour_to_filter = st.slider('hour', 0, 23, 17)
33
- filtered_data = data[data[DATE_COLUMN].dt.hour == hour_to_filter]
34
-
35
- st.subheader('Map of all pickups at %s:00' % hour_to_filter)
36
- st.map(filtered_data)
37
-
38
- uploaded_file = st.file_uploader("Choose a file")
39
- if uploaded_file is not None:
40
- st.write(uploaded_file.name)
41
- bytes_data = uploaded_file.getvalue()
42
- st.write(len(bytes_data), "bytes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import requests
3
+ import json
4
+ import os
5
+ import datetime
6
+ from huggingface_hub import InferenceClient
7
+
8
# Constants
SPACE_URL = "https://z7svds7k42bwhhgm.us-east-1.aws.endpoints.huggingface.cloud"  # dedicated HF Inference Endpoint
HF_API_KEY = os.getenv("HF_API_KEY") # Retrieve the Hugging Face API key from system variables
EOS_TOKEN = "<|end|>"  # end-of-turn marker used by the chat template below
CHAT_HISTORY_DIR = "chat_histories"  # per-session JSON histories live here
IMAGE_PATH = "DubsChat.png"  # page icon / header image
IMAGE_PATH_2 = "Reboot AI.png"  # sidebar logo
DUBS_PATH = "Dubs.png"  # assistant avatar

# Ensure the directory exists before any save/load is attempted
try:
    os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)
except OSError as e:
    st.error(f"Failed to create chat history directory: {e}")

# Streamlit Configurations
st.set_page_config(page_title="DUBSChat", page_icon=IMAGE_PATH, layout="wide")
st.logo(IMAGE_PATH_2,size="large")

# -------------------------
# Chat Template
# -------------------------
# Phi-style prompt; {history} and {user_input} are filled in by
# format_chat_template(). The tokens must match the endpoint's model.
CHAT_TEMPLATE = """
<|system|>
You are a helpful assistant.<|end|>
{history}
<|user|>
{user_input}<|end|>
<|assistant|>
"""
38
+
39
def format_chat_template(history, user_input):
    """Build the full model prompt from prior history plus the new user turn."""
    return CHAT_TEMPLATE.format(history=history, user_input=user_input)
44
+
45
+ # -------------------------
46
+ # Generate Chat History
47
+ # -------------------------
48
def format_chat_history(messages):
    """Render chat messages as template-formatted turns.

    Only user/assistant turns are emitted (the system prompt is supplied by
    the template itself); the result carries no trailing newline.
    """
    turns = []
    for msg in messages:
        role = msg["role"]
        if role in ("user", "assistant"):
            turns.append(f"<|{role}|>{msg['content']}<|end|>")
    return "\n".join(turns)
60
+
61
+ # -------------------------
62
+ # Utility Functions
63
+ # -------------------------
64
def save_chat_history(session_name, messages):
    """
    Save the chat history to a JSON file.

    Args:
        session_name: Base file name (no extension) for the session.
        messages: List of {"role", "content"} dicts to persist.
    """
    file_path = os.path.join(CHAT_HISTORY_DIR, f"{session_name}.json")
    try:
        # Explicit UTF-8 so non-ASCII chat content round-trips on any
        # platform (the default encoding is locale-dependent).
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(messages, f, ensure_ascii=False)
    except OSError as e:
        st.error(f"Failed to save chat history: {e}")
74
+
75
+
76
def load_chat_history(file_name):
    """
    Load the chat history from a JSON file.

    Args:
        file_name: File name (including the .json extension) inside
            CHAT_HISTORY_DIR.

    Returns:
        The stored list of message dicts, or [] if the file is missing
        or corrupt (an error is shown in the UI in that case).
    """
    file_path = os.path.join(CHAT_HISTORY_DIR, file_name)
    try:
        # Must match the UTF-8 encoding used by save_chat_history.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        st.error("Failed to load chat history. Starting with a new session.")
        return []
87
+
88
+
89
def get_saved_sessions():
    """
    Get the list of saved chat sessions.

    Returns:
        Base names (without the .json extension) of every saved session file.
    """
    # os.path.splitext strips only the trailing extension; the previous
    # f.replace(".json", "") removed *every* occurrence and would mangle a
    # name that happened to contain ".json" mid-string.
    return [
        os.path.splitext(f)[0]
        for f in os.listdir(CHAT_HISTORY_DIR)
        if f.endswith(".json")
    ]
94
+
95
+ # -------------------------
96
+ # Sidebar Configuration
97
+ # -------------------------
98
with st.sidebar:
    # Start a fresh session: reset the message list and persist an empty
    # history under a timestamp-based session name.
    if st.button("New Chat"):
        st.session_state["messages"] = [
            # Fixed typo: "created my RebootAI" -> "created by RebootAI".
            {"role": "system", "content": "You are Dubs, a helpful assistant created by RebootAI"},
        ]
        st.session_state["session_name"] = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        save_chat_history(st.session_state["session_name"], st.session_state["messages"])
        st.success("Chat reset and new session started.")

    # Offer previously saved sessions for restoration.
    saved_sessions = get_saved_sessions()
    if saved_sessions:
        selected_session = st.radio("Past Sessions:", saved_sessions)
        if st.button("Load Session"):
            st.session_state["messages"] = load_chat_history(f"{selected_session}.json")
            st.session_state["session_name"] = selected_session
            st.success(f"Loaded session: {selected_session}")
    else:
        st.write("No past sessions available.")
116
+
117
# -------------------------
# Chat History Initialization
# -------------------------
# Seed session state on first page load: a system prompt plus a
# timestamped session name used as the history file's base name.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        # Fixed typo: "created my RebootAI" -> "created by RebootAI".
        {"role": "system", "content": "You are Dubs, a helpful assistant created by RebootAI"}
    ]
if "session_name" not in st.session_state:
    st.session_state["session_name"] = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
126
+
127
# -------------------------
# Main Chat UI
# -------------------------
st.image(IMAGE_PATH, width=250)
st.markdown("Empowering you with a Sustainable AI")

# Replay the stored conversation so the page shows the full history.
for message in st.session_state["messages"]:
    role = message["role"]
    if role == "user":
        st.chat_message("user").write(message["content"])
    elif role == "assistant":
        st.chat_message("assistant", avatar=DUBS_PATH).write(message["content"])

# Client bound to the dedicated HF Inference Endpoint.
client = InferenceClient(SPACE_URL, token=HF_API_KEY)
141
+
142
# -------------------------
# Streaming Logic
# -------------------------
def stream_response(prompt_text):
    """
    Stream text from the HF Inference Endpoint using the InferenceClient.

    Args:
        prompt_text: Fully formatted prompt (see format_chat_template).

    Yields:
        Each partial chunk of generated text as it arrives; special tokens
        (e.g. the stop sequence) are filtered out.
    """
    gen_kwargs = {
        "max_new_tokens": 1024,
        "top_k": 30,
        "top_p": 0.9,
        "temperature": 0.2,
        "repetition_penalty": 1.02,
        # Use the shared EOS constant (same value as before) so the
        # template's end token and the stop sequence cannot drift apart.
        "stop_sequences": [EOS_TOKEN],
    }

    stream = client.text_generation(prompt_text, stream=True, details=True, **gen_kwargs)

    for response in stream:
        # Drop special tokens (EOS etc.) so they never reach the UI.
        if response.token.special:
            continue
        yield response.token.text
165
+
166
# -------------------------
# User Input
# -------------------------
prompt = st.chat_input()

if prompt:
    # Record and display the user's message.
    st.session_state["messages"].append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    # Build the model prompt from prior turns, excluding the message just added.
    history_text = format_chat_history(st.session_state["messages"][:-1])
    model_input = format_chat_template(history_text, prompt)

    # Stream the assistant's reply into the chat area.
    with st.spinner("Dubs is thinking... Woof Woof! 🐾"):
        msg = ""
        with st.chat_message("assistant", avatar=DUBS_PATH):
            msg = st.write_stream(stream_response(model_input))

    # Record the reply, then persist the whole session to disk.
    st.session_state["messages"].append({"role": "assistant", "content": msg})
    save_chat_history(st.session_state["session_name"], st.session_state["messages"])
192
+
193
+
194
+