| from rag_fns.generation import do_generation | |
| from rag_fns.retrieval import do_retrieval | |
| from rag_fns.setup_load import import_data, load_oai_model | |
def do_rag(user_input: str, stream: bool = False, n_results: int = 3):
    """Answer ``user_input`` with a retrieve-then-generate pipeline.

    Loads the dataset via ``import_data`` and a client via
    ``load_oai_model``. It passes both, with the query, to ``do_retrieval``,
    then hands the retrieved documents to ``do_generation``.

    Args:
        user_input: The user's query. Used as both the retrieval query
            (``query0``) and the generation query (``query1``).
        stream: Forwarded unchanged to ``do_generation`` as ``stream``.
            NOTE(review): the shape of ``response`` when streaming depends
            on ``do_generation`` — confirm with its implementation.
        n_results: Forwarded to ``do_retrieval`` as ``n_results``.
            Presumably the number of documents to retrieve — TODO confirm.

    Returns:
        A 3-tuple ``(response, retrieved_docs, prompt_tokens)``.
        ``response`` and ``prompt_tokens`` are the two values unpacked from
        ``do_generation``. ``retrieved_docs`` is whatever ``do_retrieval``
        returned.
    """
    # NOTE(review): data and client are re-created on every call. If this
    # runs repeatedly, consider loading them once and reusing them.
    ids, embeddings, info = import_data()
    client = load_oai_model()

    # The same client object serves both the retrieval and generation steps.
    docs = do_retrieval(
        query0=user_input,
        n_results=n_results,
        api_client=client,
        talk_ids=ids,
        embeds=embeddings,
        talk_info=info,
    )

    answer, n_prompt_tokens = do_generation(
        query1=user_input,
        keep_texts=docs,
        gen_client=client,
        stream=stream,
    )
    return answer, docs, n_prompt_tokens