"""Gradio front end for asking questions about posts from a subreddit."""

import gradio as gr

from indexer import create_qa_chain, create_vector_db, load_reddit_posts

# Instruction sentence that is removed from the chain output when it is
# echoed back (presumably the QA prompt preamble -- confirm against indexer).
_PROMPT_PREAMBLE = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
)

# Everything from this marker onwards is dropped from the answer.
_EXAMPLE_MARKER = "Example:"


def _clean_answer(raw: str) -> str:
    """Strip the echoed instruction sentence and any trailing example text."""
    without_preamble = raw.replace(_PROMPT_PREAMBLE, "")
    # partition keeps only the text before the first marker (or all of it
    # when the marker is absent), matching split(marker)[0].
    answer, _, _ = without_preamble.partition(_EXAMPLE_MARKER)
    return answer.strip()


def answer_question(
    subreddit: str, num_posts: float, category: str, question: str
) -> str:
    """Answer *question* using posts loaded from *subreddit*.

    On every call the posts are loaded with ``load_reddit_posts``, indexed
    with ``create_vector_db`` and queried through ``create_qa_chain``; the
    chain's ``"result"`` entry is then cleaned up for display.

    Args:
        subreddit: Subreddit name as typed by the user.
        num_posts: Number of posts to load; converted with ``int``.
        category: Listing category as typed by the user (e.g. hot, new).
        question: The user's question.

    Returns:
        The cleaned answer text.

    Raises:
        gr.Error: If the subreddit or the question is blank.
    """
    subreddit = (subreddit or "").strip()
    category = (category or "").strip()
    question = (question or "").strip()

    # Fail early with a message the UI can show instead of sending blank
    # input through the loading/indexing pipeline.
    if not subreddit:
        raise gr.Error("Please enter a subreddit name.")
    if not question:
        raise gr.Error("Please enter a question.")

    # NOTE: the index is rebuilt on every request, so each answer reflects
    # whatever load_reddit_posts returns at call time.
    docs = load_reddit_posts(subreddit, int(num_posts), category)
    db = create_vector_db(docs)
    qa_chain = create_qa_chain(db)
    result = qa_chain(question)

    return _clean_answer(result["result"])


# Gradio UI
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="Subreddit (e.g. MachineLearning)"),
        # step=1 keeps the slider on whole numbers; without it Gradio picks
        # a fractional step for this range and int() would truncate.
        gr.Slider(
            minimum=1,
            maximum=5,
            value=1,
            step=1,
            label="Number of Reddit Posts",
        ),
        gr.Textbox(label="Category (e.g. hot, new, etc)"),
        gr.Textbox(label="Your Question"),
    ],
    outputs="text",
    title="Reddit RAG Assistant",
    description="Ask questions based on recent Reddit posts in a subreddit.",
)

# Only start the server when executed as a script, so importing this module
# (e.g. to reuse answer_question or demo) does not launch the app.
if __name__ == "__main__":
    demo.launch()