import json
from operator import itemgetter
from typing import Any, Dict, List, Optional, Union

from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

from backend.app.vectorstore import get_vector_db
| |
|
# Name of the OpenAI chat model handed to ChatOpenAI by ProblemGenerationPipeline.
MODEL = "gpt-3.5-turbo"
| |
|
# System message of the chat prompt built in ProblemGenerationPipeline.__init__.
SYSTEM_ROLE_PROMPT = """
You are a helpful assistant that generates questions based on a given context.
"""
| |
|
# User message template of the chat prompt. `{query}` and `{context}` are the
# template variables; the pipeline's chains supply them under the keys
# "query" and "context". The doubled braces around the JSON example escape
# literal braces so they are not read as template variables.
USER_ROLE_PROMPT = """
Based on the following context about {query}, generate 5 relevant and specific questions.
Make sure the questions can be answered using only the provided context.

Context: {context}

Generate 5 questions that test understanding of the material in the context.

Return only a json object with the following format:
{{
"questions": ["question1", "question2", "question3", "question4", "question5"]
}}
"""
| |
|
| |
|
class ProblemGenerationPipeline:
    """Retrieval-augmented pipeline that generates questions about a topic.

    The pipeline retrieves context for the query from the vector store,
    fills the system/user chat prompt with that context, and asks the chat
    model for a JSON object holding a ``"questions"`` list.

    Two chain shapes are built depending on ``return_context``:

    * ``return_context=False``: the chain takes the query string directly
      and yields the model's text output.
    * ``return_context=True``: the chain takes a ``{"query": ...}`` mapping
      and yields a dict with the model's text under ``"response"`` and the
      retriever output under ``"context"``.
    """

    def __init__(
        self,
        return_context: bool = False,
        embedding_model_id: Optional[str] = None,
    ):
        """Build the prompt, model, retriever and the RAG chain.

        Args:
            return_context: When true, ``generate_problems`` returns the
                retrieved context alongside the raw model response.
            embedding_model_id: Forwarded unchanged to ``get_vector_db``.
        """
        self.chat_prompt = ChatPromptTemplate.from_messages(
            [("system", SYSTEM_ROLE_PROMPT), ("user", USER_ROLE_PROMPT)]
        )

        self.llm = ChatOpenAI(model=MODEL, temperature=0.7)
        self.retriever = get_vector_db(embedding_model_id).as_retriever(
            search_kwargs={"k": 2}
        )

        self.return_context = return_context
        if not return_context:
            # Input is the bare query string: it is fed to the retriever for
            # "context" and passed through untouched for "query".
            self.rag_chain = (
                {"context": self.retriever, "query": RunnablePassthrough()}
                | self.chat_prompt
                | self.llm
                | StrOutputParser()
            )
        else:
            # Input is a mapping with a "query" key; the final step keeps the
            # retrieved context next to the model's text response.
            self.rag_chain = (
                {
                    "context": itemgetter("query") | self.retriever,
                    "query": itemgetter("query"),
                }
                | RunnablePassthrough.assign(context=itemgetter("context"))
                | {
                    "response": self.chat_prompt | self.llm | StrOutputParser(),
                    "context": itemgetter("context"),
                }
            )

    @staticmethod
    def _parse_questions(raw: str) -> List[str]:
        """Extract the ``"questions"`` list from the model's JSON reply.

        An enclosing Markdown code fence (three backticks, optionally
        followed by a language tag such as ``json``) is removed before
        parsing.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
            KeyError: If the JSON object has no ``"questions"`` key.
        """
        text = raw.strip()
        if text.startswith("```"):
            # Drop the whole opening fence line, including any language tag.
            newline_at = text.find("\n")
            text = text[newline_at + 1:] if newline_at != -1 else ""
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        return json.loads(text)["questions"]

    def generate_problems(
        self, query: str, debug: bool = False
    ) -> Union[List[str], Dict[str, Any]]:
        """Generate problems based on the user's query using RAG.

        Args:
            query (str): The topic to generate questions about.
            debug (bool): When true, print the chain's raw result.

        Returns:
            When ``return_context`` is false, the list of generated
            questions parsed from the model's JSON reply. When it is true,
            the chain's raw dict with ``"response"`` (unparsed model text)
            and ``"context"`` (retriever output).

        Raises:
            json.JSONDecodeError: If the model reply is not valid JSON
                (only when ``return_context`` is false).
            KeyError: If the reply lacks a ``"questions"`` key
                (only when ``return_context`` is false).
        """
        # The context-returning chain reads its input with itemgetter("query"),
        # so it needs a mapping; the plain chain expects the bare string.
        chain_input = {"query": query} if self.return_context else query
        raw_result = self.rag_chain.invoke(chain_input)
        if debug:
            print(raw_result)

        if self.return_context:
            return raw_result
        return self._parse_questions(raw_result)
| |
|