Spaces:
No application file
No application file
cleaning up gradio messages and showing player profile
Browse files- api/event_handlers/gradio_handler.py +41 -13
- api/event_handlers/print_handler.py +23 -1
- api/scripts/workflow_playground.py +2 -1
- api/server_gradio.py +6 -4
- api/tools/player_search.py +18 -2
api/event_handlers/gradio_handler.py
CHANGED
|
@@ -5,6 +5,19 @@ from langchain_core.outputs.llm_result import LLMResult
|
|
| 5 |
from typing import List
|
| 6 |
from langchain_core.messages import BaseMessage
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class GradioEventHandler(AsyncCallbackHandler):
|
| 10 |
"""
|
|
@@ -33,32 +46,43 @@ class GradioEventHandler(AsyncCallbackHandler):
|
|
| 33 |
)
|
| 34 |
|
| 35 |
async def on_chat_model_start(self, *args, **kwargs):
|
| 36 |
-
|
| 37 |
-
self.
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
async def on_llm_new_token(self, token: str, **kwargs):
|
| 45 |
if token:
|
| 46 |
self.queue.put(token)
|
| 47 |
|
| 48 |
async def on_llm_end(self, result: LLMResult, *args, **kwargs):
|
| 49 |
-
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
async def on_tool_end(self, output: any, **kwargs):
|
| 53 |
-
print(f"\n{Fore.CYAN}[TOOL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
| 56 |
-
self.info_box(
|
| 57 |
|
| 58 |
async def on_workflow_end(self, state, *args, **kwargs):
|
| 59 |
print(f"\n{Fore.CYAN}[WORKFLOW END]{Style.RESET_ALL}")
|
| 60 |
-
|
| 61 |
-
|
|
|
|
| 62 |
|
| 63 |
@staticmethod
|
| 64 |
def is_chat_stream_end(result: LLMResult) -> bool:
|
|
@@ -67,3 +91,7 @@ class GradioEventHandler(AsyncCallbackHandler):
|
|
| 67 |
return bool(content and content.strip())
|
| 68 |
except (IndexError, AttributeError):
|
| 69 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from typing import List
|
| 6 |
from langchain_core.messages import BaseMessage
|
| 7 |
|
| 8 |
+
image_base = """
|
| 9 |
+
<img
|
| 10 |
+
src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/profiles/players_pics/{filename}"
|
| 11 |
+
style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
|
| 12 |
+
/>
|
| 13 |
+
"""
|
| 14 |
+
team_image_map = {
|
| 15 |
+
'everglade-fc': 'Everglade_FC',
|
| 16 |
+
'fraser-valley-united': 'Fraser_Valley_United',
|
| 17 |
+
'tierra-alta-fc': 'Tierra_Alta_FC',
|
| 18 |
+
'yucatan-force': 'Yucatan_Force',
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
|
| 22 |
class GradioEventHandler(AsyncCallbackHandler):
|
| 23 |
"""
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
async def on_chat_model_start(self, *args, **kwargs):
|
| 49 |
+
pass
|
| 50 |
+
# self.info_box('[CHAT START]')
|
| 51 |
+
# self.ots_box("""
|
| 52 |
+
# <img
|
| 53 |
+
# src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/landing.png"
|
| 54 |
+
# style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
|
| 55 |
+
# />
|
| 56 |
+
# """)
|
| 57 |
|
| 58 |
async def on_llm_new_token(self, token: str, **kwargs):
|
| 59 |
if token:
|
| 60 |
self.queue.put(token)
|
| 61 |
|
| 62 |
async def on_llm_end(self, result: LLMResult, *args, **kwargs):
|
| 63 |
+
pass
|
| 64 |
+
# if self.is_chat_stream_end(result):
|
| 65 |
+
# self.queue.put(None)
|
| 66 |
|
| 67 |
async def on_tool_end(self, output: any, **kwargs):
|
| 68 |
+
print(f"\n{Fore.CYAN}[TOOL END] {output}{Style.RESET_ALL}")
|
| 69 |
+
for doc in output:
|
| 70 |
+
if True:#doc.metadata.get("show_profile_card"):
|
| 71 |
+
img = image_base.format(filename=self.get_image_filename(doc))
|
| 72 |
+
print(f"\n{Fore.YELLOW}[TOOL END] {img}{Style.RESET_ALL}")
|
| 73 |
+
self.ots_box(img)
|
| 74 |
+
break
|
| 75 |
+
# else:
|
| 76 |
+
# self.info_box(doc)
|
| 77 |
|
| 78 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
| 79 |
+
self.info_box(input.get("name", "[TOOL START]"))
|
| 80 |
|
| 81 |
async def on_workflow_end(self, state, *args, **kwargs):
|
| 82 |
print(f"\n{Fore.CYAN}[WORKFLOW END]{Style.RESET_ALL}")
|
| 83 |
+
self.queue.put(None)
|
| 84 |
+
# for msg in state["messages"]:
|
| 85 |
+
# print(f'{Fore.YELLOW}{msg.content}{Style.RESET_ALL}')
|
| 86 |
|
| 87 |
@staticmethod
|
| 88 |
def is_chat_stream_end(result: LLMResult) -> bool:
|
|
|
|
| 91 |
return bool(content and content.strip())
|
| 92 |
except (IndexError, AttributeError):
|
| 93 |
return False
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def get_image_filename(doc):
|
| 97 |
+
return f'{team_image_map.get(doc.metadata.get("team"))}_{doc.metadata.get("number")}.png'
|
api/event_handlers/print_handler.py
CHANGED
|
@@ -4,6 +4,19 @@ from langchain_core.outputs.llm_result import LLMResult
|
|
| 4 |
from typing import List
|
| 5 |
from langchain_core.messages import BaseMessage
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
class PrintEventHandler(AsyncCallbackHandler):
|
| 8 |
"""
|
| 9 |
Example async event handler: prints streaming tokens and tool results.
|
|
@@ -26,7 +39,12 @@ class PrintEventHandler(AsyncCallbackHandler):
|
|
| 26 |
print('\n[END]')
|
| 27 |
|
| 28 |
async def on_tool_end(self, output: any, **kwargs):
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
| 32 |
print(f"\n{Fore.CYAN}[TOOL START]{Style.RESET_ALL}")
|
|
@@ -42,6 +60,10 @@ class PrintEventHandler(AsyncCallbackHandler):
|
|
| 42 |
except (IndexError, AttributeError):
|
| 43 |
return False
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# def __getattribute__(self, name):
|
| 46 |
# attr = super().__getattribute__(name)
|
| 47 |
# if callable(attr) and name.startswith("on_"):
|
|
|
|
| 4 |
from typing import List
|
| 5 |
from langchain_core.messages import BaseMessage
|
| 6 |
|
| 7 |
+
image_base = """
|
| 8 |
+
<img
|
| 9 |
+
src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/profiles/{filename}"
|
| 10 |
+
style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
|
| 11 |
+
/>
|
| 12 |
+
"""
|
| 13 |
+
team_image_map = {
|
| 14 |
+
'everglade-fc': 'Everglade_FC',
|
| 15 |
+
'fraser-valley-united': 'Fraser_Valley_United',
|
| 16 |
+
'tierra-alta-fc': 'Tierra_Alta_FC',
|
| 17 |
+
'yucatan-force': 'Yucatan_Force',
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
class PrintEventHandler(AsyncCallbackHandler):
|
| 21 |
"""
|
| 22 |
Example async event handler: prints streaming tokens and tool results.
|
|
|
|
| 39 |
print('\n[END]')
|
| 40 |
|
| 41 |
async def on_tool_end(self, output: any, **kwargs):
|
| 42 |
+
for doc in output:
|
| 43 |
+
if doc.metadata.get("show_profile_card"):
|
| 44 |
+
img = image_base.format(filename=self.get_image_filename(doc))
|
| 45 |
+
print(f"\n{Fore.CYAN}[TOOL RESULT] {img}{Style.RESET_ALL}")
|
| 46 |
+
else:
|
| 47 |
+
print(f"\n{Fore.CYAN}[TOOL RESULT] {doc}{Style.RESET_ALL}")
|
| 48 |
|
| 49 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
| 50 |
print(f"\n{Fore.CYAN}[TOOL START]{Style.RESET_ALL}")
|
|
|
|
| 60 |
except (IndexError, AttributeError):
|
| 61 |
return False
|
| 62 |
|
| 63 |
+
@staticmethod
|
| 64 |
+
def get_image_filename(doc):
|
| 65 |
+
return f'{team_image_map.get(doc.metadata.get("team"))}_{doc.metadata.get("number")}.png'
|
| 66 |
+
|
| 67 |
# def __getattribute__(self, name):
|
| 68 |
# attr = super().__getattribute__(name)
|
| 69 |
# if callable(attr) and name.startswith("on_"):
|
api/scripts/workflow_playground.py
CHANGED
|
@@ -36,8 +36,9 @@ workflow_bundle, state = build_workflow_with_state(
|
|
| 36 |
last_name="Bigly",
|
| 37 |
persona="Casual Fan",
|
| 38 |
messages=[
|
| 39 |
-
HumanMessage(content="tell me about some players in everglade fc"),
|
| 40 |
# HumanMessage(content="tell me about the league")
|
|
|
|
| 41 |
],
|
| 42 |
)
|
| 43 |
|
|
|
|
| 36 |
last_name="Bigly",
|
| 37 |
persona="Casual Fan",
|
| 38 |
messages=[
|
| 39 |
+
# HumanMessage(content="tell me about some players in everglade fc"),
|
| 40 |
# HumanMessage(content="tell me about the league")
|
| 41 |
+
HumanMessage(content="tell me about Ryan Martinez of everglade fc")
|
| 42 |
],
|
| 43 |
)
|
| 44 |
|
api/server_gradio.py
CHANGED
|
@@ -114,13 +114,13 @@ def submit_helper(state, handler, user_query):
|
|
| 114 |
gr.Info(token["message"])
|
| 115 |
continue
|
| 116 |
if token["type"] == "ots":
|
|
|
|
| 117 |
state.ots_content = ots_default.format(content=token["message"])
|
| 118 |
state = AppState(**state.model_dump())
|
| 119 |
-
yield state, result
|
| 120 |
continue
|
| 121 |
result += token
|
| 122 |
yield state, result
|
| 123 |
-
|
| 124 |
state.history.append(AIMessage(content=result))
|
| 125 |
|
| 126 |
### Interface ###
|
|
@@ -210,12 +210,14 @@ with gr.Blocks() as demo:
|
|
| 210 |
|
| 211 |
@submit_btn.click(inputs=[state, user_query], outputs=[state, llm_response])
|
| 212 |
def submit(state, user_query):
|
| 213 |
-
user_query = user_query or "tell me about some players in everglade fc"
|
|
|
|
| 214 |
yield from submit_helper(state, handler, user_query)
|
| 215 |
|
| 216 |
@user_query.submit(inputs=[state, user_query], outputs=[state, llm_response])
|
| 217 |
def user_query_change(state, user_query):
|
| 218 |
-
user_query = user_query or "tell me about some players in everglade fc"
|
|
|
|
| 219 |
yield from submit_helper(state, handler, user_query)
|
| 220 |
|
| 221 |
@persona.change(inputs=[persona, state], outputs=[persona_disp])
|
|
|
|
| 114 |
gr.Info(token["message"])
|
| 115 |
continue
|
| 116 |
if token["type"] == "ots":
|
| 117 |
+
print('OTS: ' + token["message"])
|
| 118 |
state.ots_content = ots_default.format(content=token["message"])
|
| 119 |
state = AppState(**state.model_dump())
|
|
|
|
| 120 |
continue
|
| 121 |
result += token
|
| 122 |
yield state, result
|
| 123 |
+
|
| 124 |
state.history.append(AIMessage(content=result))
|
| 125 |
|
| 126 |
### Interface ###
|
|
|
|
| 210 |
|
| 211 |
@submit_btn.click(inputs=[state, user_query], outputs=[state, llm_response])
|
| 212 |
def submit(state, user_query):
|
| 213 |
+
# user_query = user_query or "tell me about some players in everglade fc"
|
| 214 |
+
user_query = user_query or "tell me about Ryan Martinez of everglade fc"
|
| 215 |
yield from submit_helper(state, handler, user_query)
|
| 216 |
|
| 217 |
@user_query.submit(inputs=[state, user_query], outputs=[state, llm_response])
|
| 218 |
def user_query_change(state, user_query):
|
| 219 |
+
# user_query = user_query or "tell me about some players in everglade fc"
|
| 220 |
+
user_query = user_query or "tell me about Ryan Martinez of everglade fc"
|
| 221 |
yield from submit_helper(state, handler, user_query)
|
| 222 |
|
| 223 |
@persona.change(inputs=[persona, state], outputs=[persona_disp])
|
api/tools/player_search.py
CHANGED
|
@@ -31,6 +31,10 @@ class PlayerSearchSchema(BaseModel):
|
|
| 31 |
" • Everglade FC (Miami, USA): Flashy, wild, South Florida flair.\n"
|
| 32 |
" • Fraser Valley United (Abbotsford, Canada): Vineyard roots, top youth academy."
|
| 33 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
class PlayerSearchTool(BaseTool):
|
|
@@ -43,22 +47,34 @@ class PlayerSearchTool(BaseTool):
|
|
| 43 |
|
| 44 |
def _run(self,
|
| 45 |
query: str,
|
|
|
|
| 46 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
| 47 |
) -> List[Document]:
|
| 48 |
k = 5 if query[0] == "*" else 3
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
query,
|
| 51 |
k=k,
|
| 52 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
| 53 |
)
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
async def _arun(self,
|
| 56 |
query: str,
|
|
|
|
| 57 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
| 58 |
) -> List[Document]:
|
| 59 |
k = 5 if query[0] == "*" else 3
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
query,
|
| 62 |
k=k,
|
| 63 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
| 64 |
)
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
" • Everglade FC (Miami, USA): Flashy, wild, South Florida flair.\n"
|
| 32 |
" • Fraser Valley United (Abbotsford, Canada): Vineyard roots, top youth academy."
|
| 33 |
))
|
| 34 |
+
show_profile_card: bool = Field(description=(
|
| 35 |
+
"If true, only the best-matching player will be returned. The UI will display a player profile card for this result, in addition to the LLM's response."
|
| 36 |
+
"The LLM should use this flag when the user expects a single, specific player and a card UI."
|
| 37 |
+
))
|
| 38 |
|
| 39 |
|
| 40 |
class PlayerSearchTool(BaseTool):
|
|
|
|
| 47 |
|
| 48 |
def _run(self,
|
| 49 |
query: str,
|
| 50 |
+
show_profile_card: bool = False,
|
| 51 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
| 52 |
) -> List[Document]:
|
| 53 |
k = 5 if query[0] == "*" else 3
|
| 54 |
+
if show_profile_card:
|
| 55 |
+
k = 1
|
| 56 |
+
results = vector_store.similarity_search(
|
| 57 |
query,
|
| 58 |
k=k,
|
| 59 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
| 60 |
)
|
| 61 |
+
for result in results:
|
| 62 |
+
result.metadata["show_profile_card"] = show_profile_card
|
| 63 |
+
return results
|
| 64 |
|
| 65 |
async def _arun(self,
|
| 66 |
query: str,
|
| 67 |
+
show_profile_card: bool = False,
|
| 68 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
| 69 |
) -> List[Document]:
|
| 70 |
k = 5 if query[0] == "*" else 3
|
| 71 |
+
if show_profile_card:
|
| 72 |
+
k = 1
|
| 73 |
+
results = await vector_store.asimilarity_search(
|
| 74 |
query,
|
| 75 |
k=k,
|
| 76 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
| 77 |
)
|
| 78 |
+
for result in results:
|
| 79 |
+
result.metadata["show_profile_card"] = show_profile_card
|
| 80 |
+
return results
|