Spaces:
Runtime error
Runtime error
remove another comment
Browse files- wanderlust.py +38 -56
wanderlust.py
CHANGED
|
@@ -4,7 +4,6 @@ import os
|
|
| 4 |
import ipyleaflet
|
| 5 |
from openai import OpenAI, NotFoundError
|
| 6 |
from openai.types.beta import Thread
|
| 7 |
-
from openai.types.beta.threads import Run
|
| 8 |
|
| 9 |
import time
|
| 10 |
|
|
@@ -13,9 +12,7 @@ import solara
|
|
| 13 |
center_default = (0, 0)
|
| 14 |
zoom_default = 2
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
messages = solara.reactive(messages_default)
|
| 19 |
zoom_level = solara.reactive(zoom_default)
|
| 20 |
center = solara.reactive(center_default)
|
| 21 |
markers = solara.reactive([])
|
|
@@ -25,6 +22,7 @@ openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
|
| 25 |
model = "gpt-4-1106-preview"
|
| 26 |
|
| 27 |
|
|
|
|
| 28 |
tools = [
|
| 29 |
{
|
| 30 |
"type": "function",
|
|
@@ -80,7 +78,6 @@ tools = [
|
|
| 80 |
|
| 81 |
|
| 82 |
def update_map(longitude, latitude, zoom):
|
| 83 |
-
print("update_map", longitude, latitude, zoom)
|
| 84 |
center.set((latitude, longitude))
|
| 85 |
zoom_level.set(zoom)
|
| 86 |
return "Map updated"
|
|
@@ -111,12 +108,9 @@ def ai_call(tool_call):
|
|
| 111 |
|
| 112 |
@solara.component
|
| 113 |
def Map():
|
| 114 |
-
print("Map", zoom_level.value, center.value, markers.value)
|
| 115 |
ipyleaflet.Map.element( # type: ignore
|
| 116 |
zoom=zoom_level.value,
|
| 117 |
-
# on_zoom=zoom_level.set,
|
| 118 |
center=center.value,
|
| 119 |
-
# on_center=center.set,
|
| 120 |
scroll_wheel_zoom=True,
|
| 121 |
layers=[
|
| 122 |
ipyleaflet.TileLayer.element(url=url),
|
|
@@ -134,7 +128,6 @@ def ChatInterface():
|
|
| 134 |
run_id: solara.Reactive[str] = solara.use_reactive(None)
|
| 135 |
|
| 136 |
thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
|
| 137 |
-
print("thread id:", thread.id)
|
| 138 |
|
| 139 |
def add_message(value: str):
|
| 140 |
if value == "":
|
|
@@ -149,7 +142,6 @@ def ChatInterface():
|
|
| 149 |
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
|
| 150 |
tools=tools,
|
| 151 |
).id
|
| 152 |
-
print("Run id:", run_id.value)
|
| 153 |
|
| 154 |
def poll():
|
| 155 |
if not run_id.value:
|
|
@@ -159,7 +151,8 @@ def ChatInterface():
|
|
| 159 |
try:
|
| 160 |
run = openai.beta.threads.runs.retrieve(
|
| 161 |
run_id.value, thread_id=thread.id
|
| 162 |
-
)
|
|
|
|
| 163 |
except NotFoundError:
|
| 164 |
continue
|
| 165 |
if run.status == "requires_action":
|
|
@@ -167,6 +160,7 @@ def ChatInterface():
|
|
| 167 |
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
| 168 |
tool_output = ai_call(tool_call)
|
| 169 |
tool_outputs.append(tool_output)
|
|
|
|
| 170 |
openai.beta.threads.runs.submit_tool_outputs(
|
| 171 |
thread_id=thread.id,
|
| 172 |
run_id=run_id.value,
|
|
@@ -182,27 +176,10 @@ def ChatInterface():
|
|
| 182 |
run_id.set(None)
|
| 183 |
completed = True
|
| 184 |
time.sleep(0.1)
|
| 185 |
-
retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
|
| 186 |
-
messages.set(retrieved_messages.data)
|
| 187 |
|
| 188 |
result = solara.use_thread(poll, dependencies=[run_id.value])
|
| 189 |
|
| 190 |
-
|
| 191 |
-
print("handle", message)
|
| 192 |
-
messages = []
|
| 193 |
-
if message.role == "assistant":
|
| 194 |
-
tools_calls = message.get("tool_calls", [])
|
| 195 |
-
for tool_call in tools_calls:
|
| 196 |
-
messages.append(ai_call(tool_call))
|
| 197 |
-
return messages
|
| 198 |
-
|
| 199 |
-
def handle_initial():
|
| 200 |
-
print("handle initial", messages.value)
|
| 201 |
-
for message in messages.value:
|
| 202 |
-
handle_message(message)
|
| 203 |
-
|
| 204 |
-
solara.use_effect(handle_initial, [])
|
| 205 |
-
# result = solara.use_thread(ask, dependencies=[messages.value])
|
| 206 |
with solara.Column(
|
| 207 |
classes=["chat-interface"],
|
| 208 |
):
|
|
@@ -214,16 +191,25 @@ def ChatInterface():
|
|
| 214 |
"overflow-y": "auto",
|
| 215 |
"height": "100px",
|
| 216 |
"flex-direction": "column-reverse",
|
| 217 |
-
}
|
|
|
|
| 218 |
):
|
| 219 |
for message in reversed(messages.value):
|
| 220 |
with solara.Row(style={"align-items": "flex-start"}):
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
solara.Text(
|
| 223 |
message.content[0].text.value,
|
| 224 |
classes=["chat-message", "user-message"],
|
| 225 |
)
|
| 226 |
-
assert len(message.content) == 1
|
| 227 |
elif message.role == "assistant":
|
| 228 |
if message.content[0].text.value:
|
| 229 |
solara.v.Icon(
|
|
@@ -246,8 +232,6 @@ def ChatInterface():
|
|
| 246 |
repr(message),
|
| 247 |
classes=["chat-message", "assistant-message"],
|
| 248 |
)
|
| 249 |
-
elif message["role"] == "tool":
|
| 250 |
-
pass # no need to display
|
| 251 |
else:
|
| 252 |
solara.v.Icon(
|
| 253 |
children=["mdi-compass-outline"],
|
|
@@ -272,21 +256,6 @@ def ChatInterface():
|
|
| 272 |
|
| 273 |
@solara.component
|
| 274 |
def Page():
|
| 275 |
-
reset_counter, set_reset_counter = solara.use_state(0)
|
| 276 |
-
print("reset", reset_counter, f"chat-{reset_counter}")
|
| 277 |
-
|
| 278 |
-
def reset_ui():
|
| 279 |
-
set_reset_counter(reset_counter + 1)
|
| 280 |
-
|
| 281 |
-
def save():
|
| 282 |
-
with open("log.json", "w") as f:
|
| 283 |
-
json.dump(messages.value, f)
|
| 284 |
-
|
| 285 |
-
def load():
|
| 286 |
-
with open("log.json", "r") as f:
|
| 287 |
-
messages.set(json.load(f))
|
| 288 |
-
reset_ui()
|
| 289 |
-
|
| 290 |
with solara.Column(
|
| 291 |
classes=["ui-container"],
|
| 292 |
gap="5vh",
|
|
@@ -299,16 +268,12 @@ def Page():
|
|
| 299 |
unsafe_innerHTML="Wanderlust",
|
| 300 |
style={"display": "inline-block"},
|
| 301 |
)
|
| 302 |
-
# with solara.Row(gap="10px"):
|
| 303 |
-
# solara.Button("Save", on_click=save)
|
| 304 |
-
# solara.Button("Load", on_click=load)
|
| 305 |
-
# solara.Button("Soft reset", on_click=reset_ui)
|
| 306 |
with solara.Row(
|
| 307 |
justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
|
| 308 |
):
|
| 309 |
-
ChatInterface()
|
| 310 |
with solara.Column(classes=["map-container"]):
|
| 311 |
-
Map()
|
| 312 |
|
| 313 |
solara.Style(
|
| 314 |
"""
|
|
@@ -335,13 +300,30 @@ def Page():
|
|
| 335 |
height: 100%;
|
| 336 |
width: 38vw;
|
| 337 |
justify-content: center;
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
}
|
| 340 |
.map-container{
|
| 341 |
width: 50vw;
|
| 342 |
height: 100%;
|
| 343 |
justify-content: center;
|
| 344 |
}
|
|
|
|
|
|
|
|
|
|
| 345 |
@media screen and (max-aspect-ratio: 1/1) {
|
| 346 |
.ui-container{
|
| 347 |
padding: 30px;
|
|
|
|
| 4 |
import ipyleaflet
|
| 5 |
from openai import OpenAI, NotFoundError
|
| 6 |
from openai.types.beta import Thread
|
|
|
|
| 7 |
|
| 8 |
import time
|
| 9 |
|
|
|
|
| 12 |
center_default = (0, 0)
|
| 13 |
zoom_default = 2
|
| 14 |
|
| 15 |
+
messages = solara.reactive([])
|
|
|
|
|
|
|
| 16 |
zoom_level = solara.reactive(zoom_default)
|
| 17 |
center = solara.reactive(center_default)
|
| 18 |
markers = solara.reactive([])
|
|
|
|
| 22 |
model = "gpt-4-1106-preview"
|
| 23 |
|
| 24 |
|
| 25 |
+
# Declare tools for openai assistant to use
|
| 26 |
tools = [
|
| 27 |
{
|
| 28 |
"type": "function",
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
def update_map(longitude, latitude, zoom):
|
|
|
|
| 81 |
center.set((latitude, longitude))
|
| 82 |
zoom_level.set(zoom)
|
| 83 |
return "Map updated"
|
|
|
|
| 108 |
|
| 109 |
@solara.component
|
| 110 |
def Map():
|
|
|
|
| 111 |
ipyleaflet.Map.element( # type: ignore
|
| 112 |
zoom=zoom_level.value,
|
|
|
|
| 113 |
center=center.value,
|
|
|
|
| 114 |
scroll_wheel_zoom=True,
|
| 115 |
layers=[
|
| 116 |
ipyleaflet.TileLayer.element(url=url),
|
|
|
|
| 128 |
run_id: solara.Reactive[str] = solara.use_reactive(None)
|
| 129 |
|
| 130 |
thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
|
|
|
|
| 131 |
|
| 132 |
def add_message(value: str):
|
| 133 |
if value == "":
|
|
|
|
| 142 |
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
|
| 143 |
tools=tools,
|
| 144 |
).id
|
|
|
|
| 145 |
|
| 146 |
def poll():
|
| 147 |
if not run_id.value:
|
|
|
|
| 151 |
try:
|
| 152 |
run = openai.beta.threads.runs.retrieve(
|
| 153 |
run_id.value, thread_id=thread.id
|
| 154 |
+
)
|
| 155 |
+
# Above will raise NotFoundError when run creation is still in progress
|
| 156 |
except NotFoundError:
|
| 157 |
continue
|
| 158 |
if run.status == "requires_action":
|
|
|
|
| 160 |
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
| 161 |
tool_output = ai_call(tool_call)
|
| 162 |
tool_outputs.append(tool_output)
|
| 163 |
+
messages.set([*messages.value, tool_output])
|
| 164 |
openai.beta.threads.runs.submit_tool_outputs(
|
| 165 |
thread_id=thread.id,
|
| 166 |
run_id=run_id.value,
|
|
|
|
| 176 |
run_id.set(None)
|
| 177 |
completed = True
|
| 178 |
time.sleep(0.1)
|
|
|
|
|
|
|
| 179 |
|
| 180 |
result = solara.use_thread(poll, dependencies=[run_id.value])
|
| 181 |
|
| 182 |
+
# Create DOM for chat interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
with solara.Column(
|
| 184 |
classes=["chat-interface"],
|
| 185 |
):
|
|
|
|
| 191 |
"overflow-y": "auto",
|
| 192 |
"height": "100px",
|
| 193 |
"flex-direction": "column-reverse",
|
| 194 |
+
},
|
| 195 |
+
classes=["chat-box"],
|
| 196 |
):
|
| 197 |
for message in reversed(messages.value):
|
| 198 |
with solara.Row(style={"align-items": "flex-start"}):
|
| 199 |
+
# Catch "messages" that are actually tool calls
|
| 200 |
+
if isinstance(message, dict):
|
| 201 |
+
icon = (
|
| 202 |
+
"mdi-map"
|
| 203 |
+
if message["output"] == "Map updated"
|
| 204 |
+
else "mdi-map-marker"
|
| 205 |
+
)
|
| 206 |
+
solara.v.Icon(children=[icon], style_="padding-top: 10px;")
|
| 207 |
+
solara.Markdown(message["output"])
|
| 208 |
+
elif message.role == "user":
|
| 209 |
solara.Text(
|
| 210 |
message.content[0].text.value,
|
| 211 |
classes=["chat-message", "user-message"],
|
| 212 |
)
|
|
|
|
| 213 |
elif message.role == "assistant":
|
| 214 |
if message.content[0].text.value:
|
| 215 |
solara.v.Icon(
|
|
|
|
| 232 |
repr(message),
|
| 233 |
classes=["chat-message", "assistant-message"],
|
| 234 |
)
|
|
|
|
|
|
|
| 235 |
else:
|
| 236 |
solara.v.Icon(
|
| 237 |
children=["mdi-compass-outline"],
|
|
|
|
| 256 |
|
| 257 |
@solara.component
|
| 258 |
def Page():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
with solara.Column(
|
| 260 |
classes=["ui-container"],
|
| 261 |
gap="5vh",
|
|
|
|
| 268 |
unsafe_innerHTML="Wanderlust",
|
| 269 |
style={"display": "inline-block"},
|
| 270 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
with solara.Row(
|
| 272 |
justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
|
| 273 |
):
|
| 274 |
+
ChatInterface()
|
| 275 |
with solara.Column(classes=["map-container"]):
|
| 276 |
+
Map()
|
| 277 |
|
| 278 |
solara.Style(
|
| 279 |
"""
|
|
|
|
| 300 |
height: 100%;
|
| 301 |
width: 38vw;
|
| 302 |
justify-content: center;
|
| 303 |
+
position: relative;
|
| 304 |
+
}
|
| 305 |
+
.chat-interface:after {
|
| 306 |
+
content: "";
|
| 307 |
+
position: absolute;
|
| 308 |
+
z-index: 1;
|
| 309 |
+
top: 0;
|
| 310 |
+
left: 0;
|
| 311 |
+
pointer-events: none;
|
| 312 |
+
background-image: linear-gradient(to top, rgba(255,255,255,0), rgba(255,255,255, 1) 100%);
|
| 313 |
+
width: 100%;
|
| 314 |
+
height: 15%;
|
| 315 |
+
}
|
| 316 |
+
.chat-box > :last-child{
|
| 317 |
+
padding-top: 7.5vh;
|
| 318 |
}
|
| 319 |
.map-container{
|
| 320 |
width: 50vw;
|
| 321 |
height: 100%;
|
| 322 |
justify-content: center;
|
| 323 |
}
|
| 324 |
+
.user-message{
|
| 325 |
+
font-weight: bold;
|
| 326 |
+
}
|
| 327 |
@media screen and (max-aspect-ratio: 1/1) {
|
| 328 |
.ui-container{
|
| 329 |
padding: 30px;
|