import gradio as gr
import os
import pymongo
import spaces  # Hugging Face Spaces helper (e.g. the @spaces.GPU decorator on ZeroGPU hardware)
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
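
# Assumed setup for this Space (not created here): recipe documents with "title",
# "ingredients", "directions", and a precomputed "embedding" field already live in
# MongoDB, and an Atlas Vector Search index named "vector_index" covers that field.
# A sketch of that index definition is included at the bottom of the file.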

def get_embedding(text: str) -> list[float]:
    """Embed `text` with the SentenceTransformer model; return an empty list for blank input."""
    if not text.strip():
        print("Attempted to get embedding for empty text.")
        return []
    embedding = embedding_model.encode(text)
    return embedding.tolist()

def get_mongo_client(mongo_uri):
    """Establish connection to the MongoDB."""
    try:
        client = pymongo.MongoClient(mongo_uri)
        # MongoClient connects lazily, so ping the server to surface connection errors here.
        client.admin.command("ping")
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None

def vector_search(user_query, collection):
    """Run an Atlas Vector Search query for `user_query` and return the matching documents."""
    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)
    if not query_embedding:
        # get_embedding returns an empty list on failure, so return an empty result set.
        return []

    # Define the vector search pipeline
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": 150,  # Number of candidate matches to consider
                "limit": 4,  # Return top 4 matches
            }
        },
        {
            "$project": {
                "_id": 0,
                "title": 1,
                "ingredients": 1,
                "directions": 1,
                "score": {"$meta": "vectorSearchScore"},  # Include the search score
            }
        },
    ]

    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)

def get_search_result(query, collection):
    get_knowledge = vector_search(query, collection)
    search_result = ""
    for result in get_knowledge:
        search_result += (
            f"Recipe Name: {result.get('title', 'N/A')}, "
            f"Ingredients: {result.get('ingredients', 'N/A')}, "
            f"Directions: {result.get('directions', 'N/A')}\n"
        )
    return search_result, get_knowledge

def process_response(message, history):
    source_information, matches = get_search_result(message, collection)

    # Keep the full recipe documents keyed by title so the matched one can be appended later.
    recipe_dict = {}
    for x in matches:
        name = x.pop("title")
        recipe_dict[name] = x

    combined_information = (
        f"Query: {message}\n"
        f"Continue to answer the query by using the Search Results:\n{source_information}."
    )
    inputs = tokenizer(combined_information, return_tensors="pt").to(model.device)
    response = model.generate(**inputs, max_new_tokens=500)
    # Drop the echoed prompt and the trailing <eos> marker from the decoded output.
    response_text = tokenizer.decode(response[0]).split("\n.\n")[-1].split("<eos>")[0].strip()

    if not recipe_dict:
        return response_text

    # Prefer a recipe the model actually mentioned; otherwise fall back to the top match.
    matched_recipe = ""
    for title in recipe_dict:
        if title in response_text:
            matched_recipe = title
            break
    if not matched_recipe:
        matched_recipe = next(iter(recipe_dict))

    recipe = recipe_dict[matched_recipe]
    response_text += f"\n\nRecipe for **{matched_recipe}**:"
    response_text += "\n### List of ingredients:\n- {0}".format("\n- ".join(recipe["ingredients"].split(", ")))
    response_text += "\n### Directions:\n- {0}".format(".\n- ".join(recipe["directions"].split(". ")))
    return response_text

if __name__ == "__main__":
    embedding_model = SentenceTransformer("thenlper/gte-large")

    mongo_uri = os.getenv("MONGO_URI")
    if not mongo_uri:
        raise ValueError("MONGO_URI not set in environment variables")

    mongo_client = get_mongo_client(mongo_uri)
    if mongo_client is None:
        raise ConnectionError("Could not connect to MongoDB")

    # Recipe data is assumed to be already ingested into this database and collection.
    db = mongo_client["recipe"]
    collection = db["recipe_collection"]

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")

    gr.ChatInterface(process_response).queue().launch()
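
# A minimal sketch of the Atlas Vector Search index definition this app assumes
# (configured in Atlas, not in this file). The index name and field path match the
# pipeline above; 1024 dimensions matches thenlper/gte-large, and cosine similarity
# is an assumption, not something this code enforces:
#
# {
#   "fields": [
#     {
#       "type": "vector",
#       "path": "embedding",
#       "numDimensions": 1024,
#       "similarity": "cosine"
#     }
#   ]
# }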