Update trace.py
Browse files
trace.py
CHANGED
|
@@ -4,9 +4,7 @@ from wandb.sdk.data_types.trace_tree import Trace
|
|
| 4 |
|
| 5 |
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def wandb_trace(rag_option,
|
| 10 |
prompt,
|
| 11 |
completion,
|
| 12 |
result,
|
|
@@ -23,27 +21,27 @@ def wandb_trace(rag_option,
|
|
| 23 |
name = "" if (chain == None) else type(chain).__name__,
|
| 24 |
status_code = "success" if (str(err_msg) == "") else "error",
|
| 25 |
status_message = str(err_msg),
|
| 26 |
-
metadata = {"chunk_overlap": "" if (
|
| 27 |
-
"chunk_size": "" if (
|
| 28 |
} if (str(err_msg) == "") else {},
|
| 29 |
-
inputs = {"
|
| 30 |
"prompt": prompt,
|
| 31 |
-
"chain_prompt": (str(chain.prompt) if (
|
| 32 |
str(chain.combine_documents_chain.llm_chain.prompt)),
|
| 33 |
-
"source_documents": "" if (
|
| 34 |
} if (str(err_msg) == "") else {},
|
| 35 |
outputs = {"result": result,
|
| 36 |
"generation_info": str(generation_info),
|
| 37 |
"llm_output": str(llm_output),
|
| 38 |
"completion": str(completion),
|
| 39 |
} if (str(err_msg) == "") else {},
|
| 40 |
-
model_dict = {"client": (str(chain.llm.client) if (
|
| 41 |
str(chain.combine_documents_chain.llm_chain.llm.client)),
|
| 42 |
-
"model_name": (str(chain.llm.model_name) if (
|
| 43 |
str(chain.combine_documents_chain.llm_chain.llm.model_name)),
|
| 44 |
-
"temperature": (str(chain.llm.temperature) if (
|
| 45 |
str(chain.combine_documents_chain.llm_chain.llm.temperature)),
|
| 46 |
-
"retriever": ("" if (
|
| 47 |
} if (str(err_msg) == "") else {},
|
| 48 |
start_time_ms = start_time_ms,
|
| 49 |
end_time_ms = end_time_ms
|
|
|
|
| 4 |
|
| 5 |
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
|
| 6 |
|
| 7 |
+
def wandb_trace(is_rag_off,
|
|
|
|
|
|
|
| 8 |
prompt,
|
| 9 |
completion,
|
| 10 |
result,
|
|
|
|
| 21 |
name = "" if (chain == None) else type(chain).__name__,
|
| 22 |
status_code = "success" if (str(err_msg) == "") else "error",
|
| 23 |
status_message = str(err_msg),
|
| 24 |
+
metadata = {"chunk_overlap": "" if (is_rag_off) else config["chunk_overlap"],
|
| 25 |
+
"chunk_size": "" if (is_rag_off) else config["chunk_size"],
|
| 26 |
} if (str(err_msg) == "") else {},
|
| 27 |
+
inputs = {"is_rag": not is_rag_off,
|
| 28 |
"prompt": prompt,
|
| 29 |
+
"chain_prompt": (str(chain.prompt) if (is_rag_off) else
|
| 30 |
str(chain.combine_documents_chain.llm_chain.prompt)),
|
| 31 |
+
"source_documents": "" if (is_rag_off) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
|
| 32 |
} if (str(err_msg) == "") else {},
|
| 33 |
outputs = {"result": result,
|
| 34 |
"generation_info": str(generation_info),
|
| 35 |
"llm_output": str(llm_output),
|
| 36 |
"completion": str(completion),
|
| 37 |
} if (str(err_msg) == "") else {},
|
| 38 |
+
model_dict = {"client": (str(chain.llm.client) if (is_rag_off) else
|
| 39 |
str(chain.combine_documents_chain.llm_chain.llm.client)),
|
| 40 |
+
"model_name": (str(chain.llm.model_name) if (is_rag_off) else
|
| 41 |
str(chain.combine_documents_chain.llm_chain.llm.model_name)),
|
| 42 |
+
"temperature": (str(chain.llm.temperature) if (is_rag_off) else
|
| 43 |
str(chain.combine_documents_chain.llm_chain.llm.temperature)),
|
| 44 |
+
"retriever": ("" if (is_rag_off) else str(chain.retriever)),
|
| 45 |
} if (str(err_msg) == "") else {},
|
| 46 |
start_time_ms = start_time_ms,
|
| 47 |
end_time_ms = end_time_ms
|