Spaces:
Runtime error
Runtime error
| import numpy as np # linear algebra | |
| import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) | |
| import pprint | |
| import os | |
| import ast | |
| import gradio as gr | |
| from gradio.themes.base import Base | |
| import weaviate | |
| from weaviate.embedded import EmbeddedOptions | |
| from langchain_community.vectorstores import Weaviate | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain.schema import Document | |
| from langchain_community.chat_models import ChatOpenAI | |
| from langchain.schema.runnable import RunnablePassthrough | |
| from langchain.schema.output_parser import StrOutputParser | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
# Read the OpenAI key from the environment (e.g. a Space secret) and fail fast
# with a clear message if it is missing. Assigning os.getenv()'s None into
# os.environ would otherwise raise an opaque TypeError. Failing here also
# happens before the slow embedding work below rather than after it.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if not _openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it as an environment variable "
        "(or Space secret) before starting the app."
    )
os.environ["OPENAI_API_KEY"] = _openai_api_key
df = pd.read_csv('./RAW_recipes.csv')

# Variables
max_length = len(df)  # total number of recipes aka rows (derived from the data)
curr_len = 10000  # how many recipes we want to process and embed

# Labels for the values in each row's `nutrition` list, in list order.
# Calories are reported per serving; every other value is a % of daily value.
NUTRITION_LABELS = [
    "Calorie",
    "Total Fat",
    "Sugar",
    "Sodium",
    "Protein",
    "Saturated Fat",
    "Total Carbohydrate",
]


def _label_nutrition(nutrition):
    """Return ``"<label> : <value> <unit>"`` strings for a list of nutrition values.

    Values beyond the known labels are dropped (``zip`` stops at the shorter input).
    """
    labeled = []
    for label, num in zip(NUTRITION_LABELS, nutrition):
        unit = "per serving" if label == "Calorie" else "% daily value"
        labeled.append(f"{label} : {num} {unit}")
    return labeled


def _number_steps(steps):
    """Prefix each step with its 1-based position, e.g. ``"1. preheat oven"``."""
    return [f"{i}. {step}" for i, step in enumerate(steps, start=1)]


def _format_recipe(record):
    """Flatten one CSV row (a plain tuple, in column order) into a single-line text blob.

    ``nutrition`` and ``steps`` are Python-literal strings in the CSV, so they are
    parsed with ``ast.literal_eval``, which evaluates literals only and never executes code.
    """
    (name, _recipe_id, minutes, _contributor_id, submitted, tags, nutrition,
     n_steps, steps, description, ingredients, n_ingredients) = record
    if pd.isna(description):  # guard against a missing (NaN) description
        description = ""
    nutrition_labeled = _label_nutrition(ast.literal_eval(nutrition))
    numbered_steps = _number_steps(ast.literal_eval(steps))
    text = (
        f"{name} : {minutes} minutes, submitted on {submitted} "
        f"description: {description}, "
        f"ingredients: {ingredients} "
        f"number of ingredients: {n_ingredients} "
        f"tags: {tags}, nutrition: {nutrition_labeled}, total steps: {n_steps} "
        f"steps: {numbered_steps}"
    )
    # Collapse newlines, tabs and repeated spaces into single spaces. Simply
    # deleting "\n" would glue neighbouring words and fields together.
    return " ".join(text.split())


# One text blob per recipe, for the first `curr_len` rows.
recipe_info = [
    _format_recipe(record)
    for record in df.head(curr_len).itertuples(index=False, name=None)
]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)

# Chunk each recipe separately (one Document per recipe string), keeping the
# per-recipe chunk lists in `docs`, then flatten them in recipe order.
docs = [
    text_splitter.split_documents([Document(page_content=recipe_text)])
    for recipe_text in recipe_info
]
merged_documents = [piece for recipe_chunks in docs for piece in recipe_chunks]
# Sentence-transformers model from Hugging Face, run on the CPU, used to embed chunks.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {"device": "cpu"}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Start an embedded (in-process launched) Weaviate instance and index every chunk.
# by_text=False: presumably searches by vector using `embeddings` instead of
# Weaviate's own text search (see LangChain Weaviate docs to confirm).
client = weaviate.Client(embedded_options=EmbeddedOptions())
vector_search = Weaviate.from_documents(
    client=client,
    documents=merged_documents,
    embedding=embeddings,
    by_text=False,
)
# Expose the Weaviate store as a LangChain retriever (basic RAG).
# search_type="mmr" (maximal marginal relevance) picks `k` documents balancing
# relevance to the query with diversity among the results.
# NOTE(review): "score_threshold" is honoured by search_type=
# "similarity_score_threshold"; under "mmr" it appears to be ignored by the
# Weaviate store, so no 0.77 relevance cut-off is actually applied. It is kept
# here unchanged to preserve the current configuration -- confirm before relying on it.
k = 10  # number of documents passed to the LLM as context
score_threshold = 0.77
retriever = vector_search.as_retriever(
    search_type="mmr",
    search_kwargs={"k": k, "score_threshold": score_threshold},
)
template = """
You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question at the end.
The following pieces of retrieved context are recipes.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Dont say anthing mean or offensive.
Context: {context}
Question: {question}
"""
custom_rag_prompt = ChatPromptTemplate.from_template(template)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0.2)


def format_docs(documents):
    """Join the retrieved chunks' text into one context string for the prompt.

    Without this step the retriever's list of Document objects would be
    stringified wholesale (reprs and metadata included) into ``{context}``.
    """
    return "\n\n".join(doc.page_content for doc in documents)


# Regular chain format: chain = prompt | model | output_parser
# The question string flows unchanged into {question}; the same string is sent
# to the retriever, whose documents are flattened into {context}.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)
def get_response(query):
    """Answer a recipe question using the RAG chain.

    Args:
        query: The question text from the Gradio textbox.

    Returns:
        The model's answer as a string. If ``query`` is empty, ``None`` or
        whitespace-only, a short prompt to enter a question is returned instead,
        and neither retrieval nor the LLM is called.
    """
    if not query or not query.strip():
        return "Please enter a question about a recipe."
    return rag_chain.invoke(query)
with gr.Blocks(theme=Base(), title="RAG Recipe AI") as demo:
    # Markdown is kept flush-left on purpose: tab/4-space-indented lines that
    # follow a blank line are parsed as a code block, not as a bullet list.
    # The recipe count reflects how many recipes were actually indexed.
    gr.Markdown(f"""
# RAG Recipe AI
This model will answer all your recipe-related questions.
Enter a question about a recipe, and the system will return an answer based on {len(recipe_info):,} food.com recipes stored in the vector database.

Features Considered:

- Cook Time
- Nutrition Information
- Steps
- Ingredients
- Dish Description

Sample Queries:

- What is an easy dessert I can make with apples?
- What is the nutritional information of a Caesar salad?
- How many calories is in an average American burger?
""")
    textbox = gr.Textbox(label="Question:")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
    with gr.Column():
        output1 = gr.Textbox(lines=1, max_lines=10, label="Answer:")
    # Call get_response when the Submit button is clicked or Enter is pressed
    # in the question box.
    button.click(get_response, textbox, outputs=[output1])
    textbox.submit(get_response, textbox, outputs=[output1])
demo.launch()