import fix + dealing w/ response in function of n parameter in generation
Browse files- OpenAIChatAtomicFlow.py +2 -2
- OpenAIChatAtomicFlow.yaml +1 -1
- run.py +1 -1
OpenAIChatAtomicFlow.py
CHANGED
|
@@ -13,7 +13,7 @@ from flows.messages.flow_message import UpdateMessage_ChatMessage
|
|
| 13 |
|
| 14 |
from flows.prompt_template import JinjaPrompt
|
| 15 |
|
| 16 |
-
from backends.llm_lite import LiteLLMBackend
|
| 17 |
|
| 18 |
log = logging.get_logger(__name__)
|
| 19 |
|
|
@@ -256,5 +256,5 @@ class OpenAIChatAtomicFlow(AtomicFlow):
|
|
| 256 |
role=self.flow_config["assistant_name"],
|
| 257 |
content=answer
|
| 258 |
)
|
| 259 |
-
|
| 260 |
return {"api_output": response}
|
|
|
|
| 13 |
|
| 14 |
from flows.prompt_template import JinjaPrompt
|
| 15 |
|
| 16 |
+
from flows.backends.llm_lite import LiteLLMBackend
|
| 17 |
|
| 18 |
log = logging.get_logger(__name__)
|
| 19 |
|
|
|
|
| 256 |
role=self.flow_config["assistant_name"],
|
| 257 |
content=answer
|
| 258 |
)
|
| 259 |
+
response = response if len(response) > 1 or len(response) == 0 else response[0]
|
| 260 |
return {"api_output": response}
|
OpenAIChatAtomicFlow.yaml
CHANGED
|
@@ -9,7 +9,7 @@ user_name: user
|
|
| 9 |
assistant_name: assistant
|
| 10 |
|
| 11 |
backend:
|
| 12 |
-
_target_: backends.llm_lite.LiteLLMBackend
|
| 13 |
api_infos: ???
|
| 14 |
model_name: "gpt-3.5-turbo"
|
| 15 |
n: 1
|
|
|
|
| 9 |
assistant_name: assistant
|
| 10 |
|
| 11 |
backend:
|
| 12 |
+
_target_: flows.backends.llm_lite.LiteLLMBackend
|
| 13 |
api_infos: ???
|
| 14 |
model_name: "gpt-3.5-turbo"
|
| 15 |
n: 1
|
run.py
CHANGED
|
@@ -4,7 +4,7 @@ import hydra
|
|
| 4 |
|
| 5 |
import flows
|
| 6 |
from flows.flow_launchers import FlowLauncher
|
| 7 |
-
from backends.api_info import ApiInfo
|
| 8 |
from flows.utils.general_helpers import read_yaml_file
|
| 9 |
|
| 10 |
from flows import logging
|
|
|
|
| 4 |
|
| 5 |
import flows
|
| 6 |
from flows.flow_launchers import FlowLauncher
|
| 7 |
+
from flows.backends.api_info import ApiInfo
|
| 8 |
from flows.utils.general_helpers import read_yaml_file
|
| 9 |
|
| 10 |
from flows import logging
|