Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import time | |
| from pathlib import Path | |
| import ipyleaflet | |
| from openai import NotFoundError, OpenAI | |
| from openai.types.beta import Thread | |
| import solara | |
# Directory containing this script; used to locate bundled assets such as style.css.
HERE = Path(__file__).parent

# Initial map view.
center_default = (0, 0)
zoom_default = 2

# Reactive application state read and written by the functions/components in this file.
# `messages` holds the chat transcript: OpenAI thread messages and tool-output dicts.
messages = solara.reactive([])
zoom_level = solara.reactive(zoom_default)
# Map center as a (latitude, longitude) pair.
center = solara.reactive(center_default)
# Markers as dicts with a "location" ((latitude, longitude)) and a "label" key.
markers = solara.reactive([])

# Tile URL for the OpenStreetMap Mapnik base layer.
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()

# OpenAI client; the API key is taken from the OPENAI_API_KEY environment variable.
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# NOTE(review): `model` is not referenced anywhere in the visible code; runs are
# created against a pre-configured assistant id instead — confirm whether needed.
model = "gpt-4-1106-preview"

# Read the stylesheet explicitly as UTF-8 so loading does not depend on the
# platform's default locale encoding.
app_style = (HERE / "style.css").read_text(encoding="utf-8")
# Function-calling tool schemas passed to every assistant run. When the
# assistant invokes one, its "name" is looked up in the `functions` dispatch table.
tools = [
    dict(
        type="function",
        function=dict(
            name="update_map",
            description="Update map to center on a particular location",
            parameters=dict(
                type="object",
                properties=dict(
                    longitude=dict(
                        type="number",
                        description="Longitude of the location to center the map on",
                    ),
                    latitude=dict(
                        type="number",
                        description="Latitude of the location to center the map on",
                    ),
                    zoom=dict(
                        type="integer",
                        description="Zoom level of the map",
                    ),
                ),
                required=["longitude", "latitude", "zoom"],
            ),
        ),
    ),
    dict(
        type="function",
        function=dict(
            name="add_marker",
            description="Add marker to the map",
            parameters=dict(
                type="object",
                properties=dict(
                    longitude=dict(
                        type="number",
                        description="Longitude of the location to the marker",
                    ),
                    latitude=dict(
                        type="number",
                        description="Latitude of the location to the marker",
                    ),
                    label=dict(
                        type="string",
                        description="Text to display on the marker",
                    ),
                ),
                required=["longitude", "latitude", "label"],
            ),
        ),
    ),
]
def update_map(longitude, latitude, zoom):
    """Tool handler: re-center the map and set its zoom level.

    The center is stored latitude-first, as ``(latitude, longitude)``. The
    returned status string becomes the tool output reported to the assistant.
    """
    new_center = (latitude, longitude)
    center.set(new_center)
    zoom_level.set(zoom)
    return "Map updated"
def add_marker(longitude, latitude, label):
    """Tool handler: append a labelled marker to the map.

    A new list is assigned instead of mutating ``markers.value`` in place, so
    the reactive variable receives a new value. The returned status string
    becomes the tool output reported to the assistant.
    """
    marker = {"location": (latitude, longitude), "label": label}
    markers.set(markers.value + [marker])
    return "Marker added"
# Dispatch table: tool name (as declared in `tools`) -> Python handler.
functions = dict(
    update_map=update_map,
    add_marker=add_marker,
)
def assistant_tool_call(tool_call, registry=None):
    """Execute one tool call requested by the OpenAI assistant.

    Args:
        tool_call: Object with ``.id`` and ``.function``. ``.function.name``
            selects the handler and ``.function.arguments`` is a JSON object
            string of keyword arguments for it.
        registry: Optional mapping of tool name -> handler. Defaults to the
            module-level ``functions`` dispatch table.

    Returns:
        ``{"tool_call_id": <id>, "output": <str>}``, ready to submit as a tool
        output.

    Mistakes in the model's request are reported in ``output`` instead of
    raised: an unknown tool name, unparsable or non-object JSON arguments, or
    arguments the handler does not accept. This lets the caller still submit an
    output for every requested call.
    """
    if registry is None:
        registry = functions
    function = tool_call.function
    name = function.name

    def result(output):
        # Tool outputs must be strings.
        return {"tool_call_id": tool_call.id, "output": str(output)}

    handler = registry.get(name)
    if handler is None:
        return result(f"Error: unknown function {name!r}")
    try:
        arguments = json.loads(function.arguments or "{}")
    except json.JSONDecodeError as err:
        return result(f"Error: could not parse arguments for {name!r}: {err}")
    if not isinstance(arguments, dict):
        return result(f"Error: arguments for {name!r} must be a JSON object")
    try:
        return_value = handler(**arguments)
    except TypeError as err:
        # Most likely missing/unexpected keyword arguments supplied by the model.
        return result(f"Error: invalid arguments for {name!r}: {err}")
    return result(return_value)
@solara.component
def Map():
    """Render the Leaflet map from the reactive ``center``, ``zoom_level`` and ``markers`` state."""
    ipyleaflet.Map.element(  # type: ignore
        zoom=zoom_level.value,
        # Write manual pans/zooms back into the reactive state. Without this, after
        # the user moves the map by hand, an assistant request for the previously
        # set view would likely not change the reactive value, and the map would
        # stay where the user left it.
        on_zoom=zoom_level.set,
        center=center.value,
        on_center=center.set,
        scroll_wheel_zoom=True,
        layers=[
            ipyleaflet.TileLayer.element(url=url),
            *[
                # Show the label stored by add_marker as the marker's title (hover text).
                ipyleaflet.Marker.element(
                    location=marker["location"],
                    title=marker["label"],
                    draggable=False,
                )
                for marker in markers.value
            ],
        ],
    )
def _message_text(message):
    """Return the text of the first text part of a thread message, or "" if there is none.

    Indexing ``message.content[0].text`` directly raises when ``content`` is
    empty or its first part has no ``text`` attribute.
    """
    for part in getattr(message, "content", None) or []:
        text = getattr(part, "text", None)
        if text is not None:
            return text.value
    return ""


@solara.component
def ChatMessage(message):
    """Render a single chat transcript entry.

    ``message`` is either a tool-output dict as produced by
    ``assistant_tool_call`` (with an ``"output"`` key) or an OpenAI thread
    message with a ``role``. Unrecognised entries are shown via ``repr``.
    """
    icon_style = "padding-top: 10px;"
    with solara.Row(style={"align-items": "flex-start"}):
        # Tool outputs are stored in the transcript as plain dicts, not messages.
        if isinstance(message, dict):
            icon = "mdi-map" if message["output"] == "Map updated" else "mdi-map-marker"
            solara.v.Icon(children=[icon], style_=icon_style)
            solara.Markdown(message["output"])
        elif message.role == "user":
            solara.Text(_message_text(message), style={"font-weight": "bold"})
        elif message.role == "assistant":
            text = _message_text(message)
            if text:
                solara.v.Icon(children=["mdi-compass-outline"], style_=icon_style)
                solara.Markdown(text)
            # `message.content` is a list, so the original `content.tool_calls`
            # access raised AttributeError; look up tool calls defensively instead.
            elif getattr(message, "tool_calls", None):
                solara.v.Icon(children=["mdi-map"], style_=icon_style)
                solara.Markdown("*Calling map functions*")
            else:
                solara.v.Icon(children=["mdi-compass-outline"], style_=icon_style)
                solara.Preformatted(repr(message))
        else:
            solara.v.Icon(children=["mdi-compass-outline"], style_=icon_style)
            solara.Preformatted(repr(message))
@solara.component
def ChatBox(children=[]):
    """Scrollable container for chat messages that stays scrolled to the newest one.

    This must be a Solara component because it is used as a container
    (``with ChatBox(): ...``). A plain function would return ``None``, which
    cannot be used in a ``with`` statement. ``children`` is only read, never
    mutated, so the shared default list is safe.
    """
    # The inner column uses flexbox `column-reverse` and the children are
    # reversed too. The two reversals cancel, so messages appear in order, and the
    # scroll position automatically sits at the bottom (newest message).
    with solara.Column(style={"flex-grow": "1"}):
        solara.Style(
            """
            .chat-box > :last-child{
                padding-top: 7.5vh;
            }
            """
        )
        # The height effectively acts as `min-height`: flex grows the container
        # to fill the available space.
        solara.Column(
            style={
                "flex-grow": "1",
                "overflow-y": "auto",
                "height": "100px",
                "flex-direction": "column-reverse",
            },
            classes=["chat-box"],
            children=list(reversed(children)),
        )
@solara.component
def ChatInterface():
    """Chat panel that sends prompts to the OpenAI assistant and shows the transcript.

    Each submitted prompt is added to one OpenAI thread (created once per
    component instance) and starts a new assistant run. ``poll`` runs in a
    background thread via ``solara.use_thread``. It executes the tool calls the
    run requests and, once the run completes, appends the newest thread message
    to ``messages``.
    """
    prompt = solara.use_reactive("")
    # ID of the run currently being polled; None when no run is active.
    run_id: solara.Reactive[str] = solara.use_reactive(None)
    # Create the conversation thread only once, when this component is created.
    thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])

    def add_message(value: str):
        """Post the user's prompt to the thread and start a new assistant run."""
        if not value or not value.strip():
            return
        prompt.set("")
        new_message = openai.beta.threads.messages.create(
            thread_id=thread.id, content=value, role="user"
        )
        messages.set([*messages.value, new_message])
        # Creating a run changes run_id.value. That triggers a re-render, which
        # (through the use_thread dependency below) starts poll() in a thread.
        run_id.value = openai.beta.threads.runs.create(
            thread_id=thread.id,
            # Overridable via the environment; the default is the assistant this
            # app was originally configured with.
            assistant_id=os.getenv(
                "OPENAI_ASSISTANT_ID", "asst_RqVKAzaybZ8un7chIwPCIQdH"
            ),
            tools=tools,
        ).id

    def poll():
        """Drive the active run until it finishes, executing requested tool calls.

        Raises:
            RuntimeError: if the run ends in an unsuccessful terminal state. The
                error surfaces through ``use_thread``'s result, so polling stops
                instead of looping forever.
        """
        if not run_id.value:
            return
        while True:
            try:
                run = openai.beta.threads.runs.retrieve(
                    run_id.value, thread_id=thread.id
                )
            except NotFoundError:
                # Presumably raised while run creation is still in progress. Back
                # off before retrying rather than spinning against the API.
                time.sleep(0.1)
                continue
            if run.status == "requires_action":
                tool_outputs = []
                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                    tool_output = assistant_tool_call(tool_call)
                    tool_outputs.append(tool_output)
                    messages.set([*messages.value, tool_output])
                openai.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run_id.value,
                    tool_outputs=tool_outputs,
                )
            elif run.status == "completed":
                # The default list order is newest first, so data[0] is the
                # assistant's latest reply.
                messages.set(
                    [
                        *messages.value,
                        openai.beta.threads.messages.list(thread.id).data[0],
                    ]
                )
                run_id.set(None)
                return
            elif run.status in ("failed", "cancelled", "expired", "incomplete"):
                raise RuntimeError(
                    f"Assistant run {run.id} ended with status {run.status!r}: "
                    f"{getattr(run, 'last_error', None)}"
                )
            time.sleep(0.1)

    # (Re)start poll() in a background thread whenever run_id changes.
    result = solara.use_thread(poll, dependencies=[run_id.value])

    # Build the chat interface UI.
    with solara.Column(classes=["chat-interface"]):
        if len(messages.value) > 0:
            with ChatBox():
                for message in messages.value:
                    ChatMessage(message)
        with solara.Column():
            solara.InputText(
                label="Where do you want to go?"
                if len(messages.value) == 0
                else "Ask more question here",
                value=prompt,
                style={"flex-grow": "1"},
                on_value=add_message,
                disabled=result.state == solara.ResultState.RUNNING,
            )
            solara.ProgressLinear(result.state == solara.ResultState.RUNNING)
            if result.state == solara.ResultState.ERROR:
                solara.Error(repr(result.error))
@solara.component
def Page():
    """Top-level page: a header with the title and external links, then the chat panel beside the map."""
    with solara.Column(
        classes=["ui-container"],
        gap="5vh",
    ):
        with solara.Row(justify="space-between"):
            with solara.Row(gap="10px", style={"align-items": "center"}):
                solara.v.Icon(children=["mdi-compass-rose"], size="36px")
                solara.HTML(
                    tag="h2",
                    unsafe_innerHTML="Wanderlust",
                    style={"display": "inline-block"},
                )
            with solara.Row(
                gap="30px",
                style={"align-items": "center"},
                classes=["link-container"],
                justify="end",
            ):
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Source Code:", style="font-weight: bold;")
                    # target="_blank" links are still easiest to do via ipyvuetify
                    with solara.v.Btn(
                        icon=True,
                        tag="a",
                        attributes={
                            "href": "https://github.com/widgetti/wanderlust",
                            "title": "Wanderlust Source Code",
                            "target": "_blank",
                        },
                    ):
                        solara.v.Icon(children=["mdi-github-circle"])
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Powered by Solara:", style="font-weight: bold;")
                    with solara.v.Btn(
                        icon=True,
                        tag="a",
                        attributes={
                            "href": "https://solara.dev/",
                            "title": "Solara",
                            "target": "_blank",
                        },
                    ):
                        solara.HTML(
                            tag="img",
                            attributes={
                                "src": "https://solara.dev/static/public/logo.svg",
                                "width": "24px",
                            },
                        )
                    with solara.v.Btn(
                        icon=True,
                        tag="a",
                        attributes={
                            "href": "https://github.com/widgetti/solara",
                            "title": "Solara Source Code",
                            "target": "_blank",
                        },
                    ):
                        solara.v.Icon(children=["mdi-github-circle"])
        with solara.Row(
            justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
        ):
            ChatInterface()
            with solara.Column(classes=["map-container"]):
                Map()
        # App-wide stylesheet loaded from style.css next to this script.
        solara.Style(app_style)