Spaces:
Runtime error
Runtime error
| import tempfile | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List, Optional, Sequence | |
| from langchain_core.agents import AgentAction, AgentFinish | |
| from langchain_core.outputs import Generation, LLMResult | |
| import langchain | |
| from langchain.callbacks.base import BaseCallbackHandler | |
| from langchain.callbacks.utils import ( | |
| BaseMetadataCallbackHandler, | |
| flatten_dict, | |
| import_pandas, | |
| import_spacy, | |
| import_textstat, | |
| ) | |
# Fallback name under which a saved langchain asset is registered as a Comet
# model when the callback handler was not given an explicit `name`.
LANGCHAIN_MODEL_NAME = "langchain-model"
def import_comet_ml() -> Any:
    """Import comet_ml and raise an error if it is not installed.

    Returns:
        The imported ``comet_ml`` module.

    Raises:
        ImportError: If ``comet_ml`` cannot be imported. The original import
            failure is attached as ``__cause__`` so the underlying reason
            (missing package, broken transitive dependency, ...) stays visible.
    """
    try:
        import comet_ml  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "To use the comet_ml callback manager you need to have the "
            "`comet_ml` python package installed. Please install it with"
            " `pip install comet_ml`"
        ) from e
    return comet_ml
def _get_experiment(
    workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
    """Create a new Comet experiment.

    Args:
        workspace: Forwarded to ``comet_ml.Experiment`` as ``workspace``.
        project_name: Forwarded to ``comet_ml.Experiment`` as ``project_name``.

    Returns:
        The newly constructed ``comet_ml.Experiment`` object.
    """
    return import_comet_ml().Experiment(  # type: ignore
        workspace=workspace,
        project_name=project_name,
    )
def _fetch_text_complexity_metrics(text: str) -> dict:
    """Compute a fixed set of textstat metrics for ``text``.

    Each result key is the name of the textstat function that produced it,
    and the functions are invoked in the order listed below.

    Args:
        text: The text to analyze.

    Returns:
        A dict mapping each metric name to the value textstat returned.
    """
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        "text_standard",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    textstat = import_textstat()
    return {name: getattr(textstat, name)(text) for name in metric_names}
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
    """Summarize per-generation metric records with ``DataFrame.describe``.

    Args:
        metrics: Sequence of metric records (one per generated output) that
            pandas can turn into a DataFrame.

    Returns:
        ``describe().to_dict()`` of that DataFrame, i.e. a mapping from each
        summarized column to its ``{statistic: value}`` dict.
    """
    pandas = import_pandas()
    return pandas.DataFrame(metrics).describe().to_dict()
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Comet.

    Parameters:
        task_type (str): Label for the kind of task (defaults to "inference").
            It is stored on the handler and carried across resets.
        workspace (str): The Comet workspace passed to the experiment.
        project_name (str): The Comet project name passed to the experiment.
        tags (list): Tags to add to the experiment.
        name (str): Name to set on the experiment; also used as the model
            name when a langchain asset is logged.
        visualizations (list): spaCy displacy styles to render for each
            generated output. When non-empty, the ``en_core_web_sm`` spaCy
            pipeline is loaded at construction time.
        complexity_metrics (bool): Whether to log textstat complexity metrics
            for every generated output.
        custom_metrics (callable): Optional function called as
            ``custom_metrics(generation, prompt_idx, gen_idx)`` that returns a
            dict of extra metrics for a generated output.
        stream_logs (bool): Whether to stream callback actions to Comet as
            text while they happen.

    This handler will utilize the associated callback method and formats
    the input of each callback function with metadata regarding the state of LLM run,
    and adds the response to the list of records for both the {method}_records and
    action. It then logs the response to Comet.
    """

    def __init__(
        self,
        task_type: Optional[str] = "inference",
        workspace: Optional[str] = None,
        project_name: Optional[str] = None,
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
        stream_logs: bool = True,
    ) -> None:
        """Initialize callback handler.

        Creates a new Comet experiment, applies ``tags`` and ``name`` to it,
        and emits a beta warning through the comet_ml logger.

        Raises:
            ImportError: If ``comet_ml`` is not installed.
        """
        # Fail fast on a missing dependency before any other state is set up.
        self.comet_ml = import_comet_ml()
        super().__init__()

        self.task_type = task_type
        self.workspace = workspace
        self.project_name = project_name
        self.tags = tags
        self.visualizations = visualizations
        self.complexity_metrics = complexity_metrics
        self.custom_metrics = custom_metrics
        self.stream_logs = stream_logs
        # Scratch directory used by `_log_model` to serialize the asset.
        self.temp_dir = tempfile.TemporaryDirectory()

        self.experiment = _get_experiment(workspace, project_name)
        self.experiment.log_other("Created from", "langchain")
        if tags:
            self.experiment.add_tags(tags)
        self.name = name
        if self.name:
            self.experiment.set_name(self.name)

        warning = (
            "The comet_ml callback is currently in beta and is subject to change "
            "based on updates to `langchain`. Please report any issues to "
            "https://github.com/comet-ml/issue-tracking/issues with the tag "
            "`langchain`."
        )
        self.comet_ml.LOGGER.warning(warning)

        self.callback_columns: list = []
        self.action_records: list = []

        # The spaCy pipeline is only needed when visualizations are requested.
        if self.visualizations:
            spacy = import_spacy()
            self.nlp = spacy.load("en_core_web_sm")
        else:
            self.nlp = None

    def _init_resp(self) -> Dict:
        """Return a fresh record with every known callback column set to None."""
        return {k: None for k in self.callback_columns}

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts.

        Records one entry per prompt in both ``on_llm_start_records`` and
        ``action_records`` and, when streaming is enabled, logs each prompt
        to Comet as text.
        """
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        metadata = self._init_resp()
        metadata.update({"action": "on_llm_start"})
        metadata.update(flatten_dict(serialized))
        metadata.update(self.get_custom_callback_meta())

        for prompt in prompts:
            # Each prompt gets its own copy so records do not share state.
            prompt_resp = deepcopy(metadata)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)

            if self.stream_logs:
                self._log_stream(prompt, metadata, self.step)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.action_records.append(resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        Records one entry per generation, enriched with complexity and custom
        metrics when those are enabled, then logs a summary of the collected
        metrics to Comet.
        """
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        metadata = self._init_resp()
        metadata.update({"action": "on_llm_end"})
        metadata.update(flatten_dict(response.llm_output or {}))
        metadata.update(self.get_custom_callback_meta())

        output_complexity_metrics = []
        output_custom_metrics = []

        for prompt_idx, generations in enumerate(response.generations):
            for gen_idx, generation in enumerate(generations):
                text = generation.text

                generation_resp = deepcopy(metadata)
                generation_resp.update(flatten_dict(generation.dict()))

                complexity_metrics = self._get_complexity_metrics(text)
                if complexity_metrics:
                    output_complexity_metrics.append(complexity_metrics)
                    generation_resp.update(complexity_metrics)

                custom_metrics = self._get_custom_metrics(
                    generation, prompt_idx, gen_idx
                )
                if custom_metrics:
                    output_custom_metrics.append(custom_metrics)
                    generation_resp.update(custom_metrics)

                if self.stream_logs:
                    # NOTE(review): streams the shared `metadata`, not the
                    # per-generation record (same shape as on_llm_start) --
                    # confirm this is intended.
                    self._log_stream(text, metadata, self.step)

                self.action_records.append(generation_resp)
                self.on_llm_end_records.append(generation_resp)

        self._log_text_metrics(output_complexity_metrics, step=self.step)
        self._log_text_metrics(output_custom_metrics, step=self.step)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Only string-valued inputs are recorded; any other value triggers a
        warning and is skipped.
        """
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for chain_input_key, chain_input_val in inputs.items():
            if isinstance(chain_input_val, str):
                input_resp = deepcopy(resp)
                if self.stream_logs:
                    self._log_stream(chain_input_val, resp, self.step)
                input_resp.update({chain_input_key: chain_input_val})
                self.action_records.append(input_resp)
            else:
                self.comet_ml.LOGGER.warning(
                    "Unexpected data format provided! "
                    f"Input Value for {chain_input_key} will not be logged"
                )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Only string-valued outputs are recorded; any other value triggers a
        warning and is skipped.
        """
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end"})
        resp.update(self.get_custom_callback_meta())

        for chain_output_key, chain_output_val in outputs.items():
            if isinstance(chain_output_val, str):
                output_resp = deepcopy(resp)
                if self.stream_logs:
                    self._log_stream(chain_output_val, resp, self.step)
                output_resp.update({chain_output_key: chain_output_val})
                self.action_records.append(output_resp)
            else:
                self.comet_ml.LOGGER.warning(
                    "Unexpected data format provided! "
                    f"Output Value for {chain_output_key} will not be logged"
                )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            # Streamed before `input_str` is added to the record.
            self._log_stream(input_str, resp, self.step)

        resp.update({"input_str": input_str})
        self.action_records.append(resp)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end"})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(output, resp, self.step)

        resp.update({"output": output})
        self.action_records.append(resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on a text event: record the text and optionally stream it."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text"})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(text, resp, self.step)

        resp.update({"text": text})
        self.action_records.append(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Reads the final answer from ``finish.return_values["output"]``.
        """
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        output = finish.return_values["output"]
        log = finish.log

        resp.update({"action": "on_agent_finish", "log": log})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(output, resp, self.step)

        resp.update({"output": output})
        self.action_records.append(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        tool = action.tool
        # Tool input may be structured; it is recorded in its string form.
        tool_input = str(action.tool_input)
        log = action.log

        resp = self._init_resp()
        resp.update({"action": "on_agent_action", "log": log, "tool": tool})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(tool_input, resp, self.step)

        resp.update({"tool_input": tool_input})
        self.action_records.append(resp)

    def _get_complexity_metrics(self, text: str) -> dict:
        """Compute text complexity metrics using textstat.

        Parameters:
            text (str): The text to analyze.

        Returns:
            (dict): A dictionary containing the complexity metrics, or an
                empty dict when complexity metrics are disabled.
        """
        resp = {}
        if self.complexity_metrics:
            text_complexity_metrics = _fetch_text_complexity_metrics(text)
            resp.update(text_complexity_metrics)

        return resp

    def _get_custom_metrics(
        self, generation: Generation, prompt_idx: int, gen_idx: int
    ) -> dict:
        """Compute Custom Metrics for an LLM Generated Output

        Args:
            generation (Generation): Output generation from an LLM
            prompt_idx (int): List index of the input prompt
            gen_idx (int): List index of the generated output

        Returns:
            dict: A dictionary containing the custom metrics, or an empty
                dict when no custom metrics function is configured.
        """
        resp = {}
        if self.custom_metrics:
            custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
            resp.update(custom_metrics)

        return resp

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        task_type: Optional[str] = "inference",
        workspace: Optional[str] = None,
        project_name: Optional[str] = "comet-langchain-demo",
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
        finish: bool = False,
        reset: bool = False,
    ) -> None:
        """Flush the tracker and setup the session.

        Everything after this will be a new table.

        Args:
            langchain_asset: The langchain asset to save. When given, its LLM
                parameters and serialized form are logged; a failure to do so
                is logged as an error and does not propagate.
            task_type, workspace, project_name, tags, name, visualizations,
                complexity_metrics, custom_metrics: Settings for the new
                session; only used when ``reset`` is true. Falsy values fall
                back to the handler's current settings.
            finish: Whether to finish the run (ends the current experiment).
            reset: Whether to re-initialize the handler with a new experiment.

        Returns:
            None
        """
        self._log_session(langchain_asset)

        if langchain_asset:
            try:
                self._log_model(langchain_asset)
            except Exception:
                self.comet_ml.LOGGER.error(
                    "Failed to export agent or LLM to Comet",
                    exc_info=True,
                    extra={"show_traceback": True},
                )

        if finish:
            self.experiment.end()

        if reset:
            self._reset(
                task_type,
                workspace,
                project_name,
                tags,
                name,
                visualizations,
                complexity_metrics,
                custom_metrics,
            )

    def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
        """Log a single piece of text with its metadata to the experiment."""
        self.experiment.log_text(prompt, metadata=metadata, step=step)

    def _log_model(self, langchain_asset: Any) -> None:
        """Log the asset's LLM parameters and its serialized form to Comet.

        The asset is written to ``model.json`` in the handler's temp directory
        via ``save``; ``save_agent`` is tried only if ``save`` raises one of
        the handled exception types.
        """
        model_parameters = self._get_llm_parameters(langchain_asset)
        self.experiment.log_parameters(model_parameters, prefix="model")

        langchain_asset_path = Path(self.temp_dir.name, "model.json")
        model_name = self.name if self.name else LANGCHAIN_MODEL_NAME

        try:
            if hasattr(langchain_asset, "save"):
                langchain_asset.save(langchain_asset_path)
                self.experiment.log_model(model_name, str(langchain_asset_path))
        except (ValueError, AttributeError, NotImplementedError) as e:
            if hasattr(langchain_asset, "save_agent"):
                langchain_asset.save_agent(langchain_asset_path)
                self.experiment.log_model(model_name, str(langchain_asset_path))
            else:
                self.comet_ml.LOGGER.error(
                    f"{e}"
                    " Could not save Langchain Asset "
                    f"for {langchain_asset.__class__.__name__}"
                )

    def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
        """Log the session table, raw action records and visualizations.

        Each of the three steps is best-effort: a failure is logged as a
        warning and the remaining steps still run.
        """
        # Stays None when the dataframe cannot be built, so the visualization
        # step can be skipped instead of failing on an unbound name.
        llm_session_df = None
        try:
            llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
            # Log the cleaned dataframe as a table
            self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log session data to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

        try:
            metadata = {"langchain_version": str(langchain.__version__)}
            # Log the langchain low-level records as a JSON file directly
            self.experiment.log_asset_data(
                self.action_records, "langchain-action_records.json", metadata=metadata
            )
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log session data to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

        if llm_session_df is None:
            # Already reported above; there is nothing to visualize.
            return

        try:
            self._log_visualizations(llm_session_df)
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log visualizations to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

    def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
        """Log summary statistics of per-generation metrics; no-op if empty."""
        if not metrics:
            return

        metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
        # Each metric's statistics are logged with the metric name as prefix.
        for key, value in metrics_summary.items():
            self.experiment.log_metrics(value, prefix=key, step=step)

    def _log_visualizations(self, session_df: Any) -> None:
        """Render spaCy displacy HTML for each output and log it as an asset.

        Does nothing unless visualizations were requested and the spaCy
        pipeline was loaded. Reads the ``prompts`` and ``text`` columns of
        ``session_df``. A rendering or logging failure for one visualization
        is logged as a warning and does not stop the others.
        """
        if not (self.visualizations and self.nlp):
            return

        spacy = import_spacy()

        prompts = session_df["prompts"].tolist()
        outputs = session_df["text"].tolist()

        for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
            doc = self.nlp(output)
            sentence_spans = list(doc.sents)

            for visualization in self.visualizations:
                try:
                    html = spacy.displacy.render(
                        sentence_spans,
                        style=visualization,
                        options={"compact": True},
                        jupyter=False,
                        page=True,
                    )
                    self.experiment.log_asset_data(
                        html,
                        name=f"langchain-viz-{visualization}-{idx}.html",
                        metadata={"prompt": prompt},
                        step=idx,
                    )
                except Exception as e:
                    self.comet_ml.LOGGER.warning(
                        e, exc_info=True, extra={"show_traceback": True}
                    )

    def _reset(
        self,
        task_type: Optional[str] = None,
        workspace: Optional[str] = None,
        project_name: Optional[str] = None,
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
    ) -> None:
        """Re-initialize the handler, starting a new Comet experiment.

        Every falsy argument falls back to the handler's current setting.
        ``stream_logs`` is not a parameter here, so the current value is
        carried over explicitly rather than reverting to the ``__init__``
        default.
        """
        self.__init__(  # type: ignore
            task_type=task_type or self.task_type,
            workspace=workspace or self.workspace,
            project_name=project_name or self.project_name,
            tags=tags or self.tags,
            name=name or self.name,
            visualizations=visualizations or self.visualizations,
            complexity_metrics=complexity_metrics or self.complexity_metrics,
            custom_metrics=custom_metrics or self.custom_metrics,
            stream_logs=self.stream_logs,
        )
        # `__init__` has already created a fresh temp directory, so no second
        # one is allocated here.
        self.reset_callback_meta()

    def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> Any:
        """Join LLM start and end records into one pandas DataFrame.

        Start records are repeated ``n`` times (``n`` taken from the asset's
        LLM parameters, default 1) and then merged with the end records by
        row position, with overlapping columns suffixed ``_llm_start`` and
        ``_llm_end``.

        Returns:
            The merged pandas DataFrame.
        """
        pd = import_pandas()

        llm_parameters = self._get_llm_parameters(langchain_asset)
        num_generations_per_prompt = llm_parameters.get("n", 1)

        llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        # Repeat each input row based on the number of outputs generated per prompt
        llm_start_records_df = llm_start_records_df.loc[
            llm_start_records_df.index.repeat(num_generations_per_prompt)
        ].reset_index(drop=True)
        llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_session_df = pd.merge(
            llm_start_records_df,
            llm_end_records_df,
            left_index=True,
            right_index=True,
            suffixes=["_llm_start", "_llm_end"],
        )

        return llm_session_df

    def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
        """Extract the LLM's parameter dict from a langchain asset.

        Looks for the LLM, in order, at ``asset.agent.llm_chain.llm``,
        ``asset.llm_chain.llm`` and ``asset.llm``, and otherwise treats the
        asset itself as the LLM.

        Returns:
            The result of the LLM's ``dict()`` call, or an empty dict when no
            asset is given or extraction raises.
        """
        if not langchain_asset:
            return {}
        try:
            if hasattr(langchain_asset, "agent"):
                llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
            elif hasattr(langchain_asset, "llm_chain"):
                llm_parameters = langchain_asset.llm_chain.llm.dict()
            elif hasattr(langchain_asset, "llm"):
                llm_parameters = langchain_asset.llm.dict()
            else:
                llm_parameters = langchain_asset.dict()
        except Exception:
            # Best-effort: parameters are optional extras for the session log.
            return {}

        return llm_parameters