import os
import csv
import ast
import pickle

import gdown
import torch
import torch.nn.functional as F
import streamlit as st
from transformers import BertTokenizer, BertModel

from config import GOOGLE_DRIVE_FILES
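# Note: config.py is not shown here. Based on how GOOGLE_DRIVE_FILES is used in
# ensure_files_downloaded() below, it is assumed to map local file paths to
# Google Drive file IDs, roughly like this (IDs are placeholders):
#
# GOOGLE_DRIVE_FILES = {
#     "assets/nlp/tag_based_bert_model.pth": "<drive-file-id>",
#     "assets/nlp/torch_recipe_embeddings_231630.pt": "<drive-file-id>",
#     "assets/nlp/RAW_recipes.csv": "<drive-file-id>",
#     "assets/nlp/recipe_statistics_231630.pkl": "<drive-file-id>",
#     "assets/nlp/recipe_scores_231630.pkl": "<drive-file-id>",
# }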
def download_file_from_drive(file_id: str, destination: str, file_name: str) -> bool:
    """Download a single file from Google Drive, returning True on success."""
    try:
        with st.spinner(f"Downloading {file_name}..."):
            url = f"https://drive.google.com/uc?id={file_id}"
            # gdown returns the output path on success and None on failure,
            # so check the return value instead of assuming success.
            output = gdown.download(url, destination, quiet=False)
        return output is not None
    except Exception as e:
        st.error(f"Failed to download {file_name}: {e}")
        return False
def ensure_files_downloaded() -> bool:
    """Download every required data file that is not already present locally."""
    for filename, file_id in GOOGLE_DRIVE_FILES.items():
        if not os.path.exists(filename):
            # Make sure the destination directory exists before downloading.
            directory = os.path.dirname(filename)
            if directory:
                os.makedirs(directory, exist_ok=True)
            if not download_file_from_drive(file_id, filename, filename):
                return False
    return True
class GoogleDriveRecipeSearch:
    """BERT-based semantic recipe search backed by files hosted on Google Drive."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if not ensure_files_downloaded():
            self.is_ready = False
            return

        # Queries are encoded with BERT; the fine-tuned weights are loaded on
        # top of the base model when the checkpoint file is available.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertModel.from_pretrained("bert-base-uncased")

        if os.path.exists("assets/nlp/tag_based_bert_model.pth"):
            self.model.load_state_dict(
                torch.load("assets/nlp/tag_based_bert_model.pth", map_location=self.device)
            )
            st.success("Trained model loaded successfully!")
        else:
            st.warning("Using untrained model")

        self.model.to(self.device)
        self.model.eval()

        self.load_data()
        self.is_ready = True
    def load_data(self):
        """Load precomputed recipe embeddings, recipe metadata, and rating data.

        Assumes recipe_embeddings[i] corresponds to self.recipes[i] after the
        filtering done in _load_recipes.
        """
        self.recipe_embeddings = torch.load(
            "assets/nlp/torch_recipe_embeddings_231630.pt", map_location=self.device
        )
        self.recipes = self._load_recipes("assets/nlp/RAW_recipes.csv")
        # recipe_stats maps recipe id -> (avg_rating, num_ratings, ...);
        # recipe_scores maps recipe id -> a precomputed per-recipe ranking score.
        with open("assets/nlp/recipe_statistics_231630.pkl", "rb") as f:
            self.recipe_stats = pickle.load(f)
        with open("assets/nlp/recipe_scores_231630.pkl", "rb") as f:
            self.recipe_scores = pickle.load(f)
    def _load_recipes(self, path):
        """Parse the raw recipes CSV into a list of recipe dictionaries."""
        recipes = []
        with open(path, "r", encoding="utf-8") as file:
            reader = csv.DictReader(file)
            for idx, row in enumerate(reader):
                name = row.get("name", "").strip()
                if not name or name.lower() in ["nan", "unknown recipe"]:
                    continue
                try:
                    recipe = {
                        "id": int(row.get("id", idx)),
                        "name": name,
                        "ingredients": ast.literal_eval(row.get("ingredients", "[]")),
                        "tags": ast.literal_eval(row.get("tags", "[]")),
                        "minutes": int(float(row.get("minutes", 0))),
                        "n_steps": int(float(row.get("n_steps", 0))),
                        "description": row.get("description", ""),
                        "steps": ast.literal_eval(row.get("steps", "[]")),
                    }
                    recipes.append(recipe)
                except (ValueError, SyntaxError, TypeError):
                    # Skip rows with malformed numeric or list fields.
                    continue
        return recipes
    def search_recipes(self, query, num_results=5, min_rating=3.0):
        """Return the top recipes for a free-text query, ranked by a blend of
        semantic similarity and the precomputed recipe score."""
        if not query.strip():
            return []

        # Embed the query using the [CLS] token and compare it against the
        # normalized recipe embeddings via cosine similarity.
        tokens = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
        tokens = {k: v.to(self.device) for k, v in tokens.items()}
        with torch.no_grad():
            outputs = self.model(**tokens)
            query_embedding = outputs.last_hidden_state[:, 0, :]

        query_embedding = F.normalize(query_embedding, dim=1)
        recipe_embeddings = F.normalize(self.recipe_embeddings, dim=1)
        similarity_scores = torch.matmul(recipe_embeddings, query_embedding.T).squeeze()

        # Blend cosine similarity (60%) with the precomputed recipe score (40%),
        # dropping recipes that are poorly rated or have fewer than 2 ratings.
        final_scores = []
        for i in range(len(self.recipe_embeddings)):
            recipe = self.recipes[i]
            avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
            if avg_rating < min_rating or num_ratings < 2:
                continue
            combined_score = (
                0.6 * similarity_scores[i].item()
                + 0.4 * self.recipe_scores.get(recipe["id"], 0)
            )
            final_scores.append((combined_score, i))

        top_matches = sorted(final_scores, key=lambda x: x[0], reverse=True)[:num_results]

        results = []
        for score, idx in top_matches:
            recipe = self.recipes[idx]
            avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
            results.append({
                "name": recipe["name"],
                "tags": recipe.get("tags", []),
                "ingredients": recipe.get("ingredients", []),
                "minutes": recipe.get("minutes", 0),
                "n_steps": recipe.get("n_steps", 0),
                "avg_rating": avg_rating,
                "num_ratings": num_ratings,
                "similarity_score": similarity_scores[idx].item(),
                "combined_score": score,
                "steps": recipe.get("steps", []),
                "description": recipe.get("description", ""),
            })
        return results
def load_search_system():
    """Instantiate the Google Drive backed recipe search system."""
    return GoogleDriveRecipeSearch()
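# Example usage (a sketch, not part of the original file): in the Streamlit
# app, the search system would typically be cached so that the model,
# embeddings, and CSV are loaded only once per session, e.g.:
#
# search_system = st.cache_resource(load_search_system)()
# query = st.text_input("Describe what you want to cook")
# if query and search_system.is_ready:
#     for result in search_system.search_recipes(query, num_results=5):
#         st.subheader(result["name"])
#         st.caption(f"{result['minutes']} min · rating {result['avg_rating']:.1f} "
#                    f"({result['num_ratings']} ratings)")
#         st.write(result["description"])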