Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import json | |
| import textwrap | |
| from typing import Dict, Any, List | |
| from sql_formatter.core import format_sql | |
| from langchain.callbacks.streamlit.streamlit_callback_handler import ( | |
| LLMThought, | |
| StreamlitCallbackHandler, | |
| ) | |
| from langchain.schema.output import LLMResult | |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
    """Progress reporter for the self-query search flow.

    Draws one Streamlit progress bar. When a chain finishes with a ``repr``
    entry in its outputs, that value is shown as the generated filter.

    ``StreamlitCallbackHandler.__init__`` is deliberately not called, so the
    thought-tracking state it would create (e.g. ``_current_thought``) does
    not exist on instances. The inherited LLM token/end/error hooks that rely
    on that state are therefore overridden with no-ops below.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Working...")
        # NOTE(review): not read anywhere in this file; kept in case callers use it.
        self.tokens_stream = ""

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_llm_end(self, response: LLMResult, *args, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_llm_error(self, error, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_text(self, text: str, **kwargs) -> None:
        self.progress_bar.progress(value=0.2, text="Asking LLM...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Advance the bar and, if present, display the generated filter."""
        self.progress_bar.progress(value=0.6, text="Searching in DB...")
        if isinstance(outputs, dict) and "repr" in outputs:
            st.markdown("### Generated Filter")
            # st.code renders the LLM-derived text verbatim. A hand-built
            # markdown fence with unsafe_allow_html=True would let output
            # containing ``` escape the fence and inject raw HTML.
            st.code(outputs["repr"], language="python")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        pass
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    """Progress reporter for the retrieval-QA ("ask") flow.

    Known chains jump the progress bar to a fixed milestone. Each
    ``LLMChain`` start nudges the bar forward by 0.1.

    ``StreamlitCallbackHandler.__init__`` is deliberately not called, so the
    inherited LLM token/end/error hooks (which rely on thought state that
    ``__init__`` would create) are overridden with no-ops.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Searching DB...")
        self.status_bar = st.empty()
        self.prog_value = 0.0
        # Progress value to show when each known chain class starts.
        self.prog_map = {
            "langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
            "langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
            "langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
        }

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_llm_end(self, response: LLMResult, *args, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_llm_error(self, error, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_text(self, text: str, **kwargs) -> None:
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        """Update the progress bar for the chain that is starting."""
        cid = ".".join((serialized or {}).get("id") or [])
        if cid == "langchain.chains.llm.LLMChain":
            value = self.prog_value + 0.1
        else:
            # Chains missing from the map keep the current value instead of
            # raising KeyError.
            value = self.prog_map.get(cid, self.prog_value)
        # st.progress rejects floats above 1.0, and a map-reduce run can start
        # many LLMChains (float drift also lands just above 1.0), so clamp.
        self.prog_value = min(value, 1.0)
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
    """Progress reporter for the vector-SQL search flow.

    The bar advances by ``prog_interval`` on every chain start and every LLM
    completion. When the LLM output looks like a SELECT statement, it is
    shown as formatted SQL.

    ``StreamlitCallbackHandler.__init__`` is deliberately not called, so the
    inherited LLM token/error hooks (which rely on thought state that
    ``__init__`` would create) are overridden with no-ops.
    """

    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def _advance(self, text: str) -> None:
        """Move the bar forward one interval, capped at 1.0, and show *text*."""
        # st.progress rejects floats above 1.0; several chains plus LLM calls
        # can easily add up past that.
        self.prog_value = min(self.prog_value + self.prog_interval, 1.0)
        self.progress_bar.progress(value=self.prog_value, text=text)

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_llm_error(self, error, **kwargs) -> None:
        # Overridden: the inherited hook needs thought state this handler lacks.
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ) -> None:
        """Show the generated SQL (if any) and advance the progress bar."""
        generations = response.generations
        # Guard against an empty result instead of raising IndexError.
        text = generations[0][0].text if generations and generations[0] else ""
        # lstrip() also drops leading newlines/tabs that models often emit.
        if text.lstrip().upper().startswith("SELECT"):
            st.write("We generated Vector SQL for you:")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
            print(f"Vector SQL: {text}")
        self._advance("Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join((serialized or {}).get("id") or [])
        self._advance(f"Running Chain `{cid}`...")

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
    """SQL-search progress reporter that advances in 0.1 steps instead of 0.2."""

    def __init__(self) -> None:
        # The parent creates the same progress bar, status slot and counter,
        # in the same order; only the step size differs here.
        super().__init__()
        self.prog_interval = 0.1
class LLMThoughtWithKB(LLMThought):
    """LLMThought that renders knowledge-base tool results as a document list.

    If a tool's output is a JSON array of objects with a ``page_content``
    field, each entry is listed with a preview shortened to 80 characters.
    Any other output falls back to the stock ``LLMThought`` rendering.
    """

    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        try:
            docs = json.loads(output)
            previews = [
                f"**{i + 1}**: {textwrap.shorten(doc['page_content'], width=80)}"
                for i, doc in enumerate(docs)
            ]
        except (ValueError, TypeError, KeyError, AttributeError):
            # Not a retrieved-documents payload. This covers invalid JSON
            # (JSONDecodeError is a ValueError), non-str output, non-list or
            # non-object JSON, missing keys, and non-str page_content.
            super().on_tool_end(
                output, color, observation_prefix, llm_prefix, **kwargs
            )
            return
        self._container.markdown(
            "\n\n".join(["### Retrieved Documents:", *previews])
        )
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    """Streamlit agent callback handler whose thoughts are LLMThoughtWithKB.

    Only LLM start is customised. It lazily creates the current thought
    using ``LLMThoughtWithKB`` so that tool results are rendered as
    retrieved-document lists.
    """

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        thought = self._current_thought
        if thought is None:
            # Build the thought with the same display settings the parent
            # handler stores on itself.
            thought = LLMThoughtWithKB(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
            self._current_thought = thought
        thought.on_llm_start(serialized, prompts)