"""Gradio app that turns a Hugging Face recipe dataset into embeddings.

The app previews a dataset's columns, builds one text document per row
(optionally paired with an image), encodes the rows with the embedding model
named by ``DEFAULT_MODEL``, saves the result as a ``.safetensors`` file and can
optionally upload that file to a Hugging Face repo.
"""

import json
import os

import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo
from safetensors.torch import save_file
from transformers import AutoModel

DEFAULT_MODEL = "nvidia/llama-nemotron-embed-vl-1b-v2"
DEFAULT_DATASET = "rahul7star/food-recipes"
DEFAULT_OUTPUT = "embeddings/all_recipes_image_text_embeddings.safetensors"

# Process-wide cache of the loaded embedding model and the name it was loaded
# from, so repeated runs do not reload the weights.
_embedding_model = None
_embedding_model_name = None


def get_hf_token():
    """Return the Hugging Face token from the environment, or None.

    ``HF_TOKEN`` takes precedence over ``HUGGINGFACE_HUB_TOKEN``.
    """
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")


def infer_columns(dataset_name):
    """Auto-select sensible columns based on known/common recipe datasets.

    Returns:
        A ``(title_column, text_column, image_column)`` tuple of column names.
        The names are suggestions only; callers must check they exist.
    """
    if dataset_name == "rahul7star/food-recipes":
        return "name", "markdown", "image"
    return "title", "markdown", "image"


def preview_dataset_columns(dataset_name, split):
    """Load a dataset split and describe its columns for the UI.

    Args:
        dataset_name: Hugging Face dataset repo id.
        split: Name of the split to load.

    Returns:
        A 4-tuple ``(title_col, text_col, image_col, preview_json)``. Each
        column entry is the recommended column name, or ``""`` when that
        column does not exist in the dataset. On any failure the three column
        entries are ``""`` and the last element is an error message.
    """
    try:
        dataset = load_dataset(dataset_name, split=split)
        columns = list(dataset.column_names)
        if len(dataset) == 0:
            raise ValueError("dataset split is empty")

        # Keep only the suggestions that actually exist in this dataset.
        title_col, text_col, image_col = (
            name if name in columns else ""
            for name in infer_columns(dataset_name)
        )

        sample = dataset[0]
        # Image payloads are not useful as text, so leave them out of the sample.
        skipped = {"image", image_col}
        preview = {
            "dataset": dataset_name,
            "split": split,
            "rows": len(dataset),
            "columns": columns,
            "recommended_mapping": {
                "title_column": title_col,
                "text_column": text_col,
                "image_column": image_col,
            },
            "sample": {
                k: str(sample[k])[:300] for k in columns if k not in skipped
            },
        }
        return (
            title_col,
            text_col,
            image_col,
            json.dumps(preview, indent=2, ensure_ascii=False),
        )
    except Exception as e:  # UI boundary: report instead of crashing the app
        return "", "", "", f"❌ Failed to preview dataset: {e}"


def load_embedding_model(model_name):
    """Load (or reuse) the embedding model called ``model_name``.

    The model is cached at module level. The cache is keyed by name, so asking
    for a different model replaces the cached one instead of silently
    returning the model that was loaded first.

    Args:
        model_name: Hugging Face model repo id.

    Returns:
        The model in eval mode, on CUDA (float16) when available, otherwise on
        CPU (float32).
    """
    global _embedding_model, _embedding_model_name

    if _embedding_model is not None and _embedding_model_name == model_name:
        return _embedding_model

    # Drop the previous model before loading so its memory can be reclaimed.
    _embedding_model = None
    _embedding_model_name = None

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(device).eval()

    # Processor settings defined by the model's remote code; presumably limits
    # on prompt length and image tiling -- confirm against the model card.
    model.processor.p_max_length = 10240
    model.processor.max_input_tiles = 6
    model.processor.use_thumbnail = True

    _embedding_model = model
    _embedding_model_name = model_name
    return model


def safe_to_text(value):
    """Convert a dataset cell to text.

    ``None`` becomes ``""``, lists are comma-joined, dicts are JSON-encoded and
    everything else goes through ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, list):
        return ", ".join(str(x) for x in value)
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    return str(value)


def build_recipe_text(item, title_col, text_col, extra_cols):
    """Build the text document that is embedded for one dataset row.

    Args:
        item: Mapping of column name to value for a single row.
        title_col: Column holding the recipe title (falsy to skip).
        text_col: Column holding the main recipe text (falsy to skip).
        extra_cols: Iterable of additional column names; names that are blank
            or missing from ``item`` are ignored.

    Returns:
        A stripped string with "Recipe Name", "Recipe Content" and
        "Extra Metadata" sections, one ``column: value`` line per extra column.
    """
    title = safe_to_text(item.get(title_col, "")) if title_col else ""
    main_text = safe_to_text(item.get(text_col, "")) if text_col else ""

    extra_parts = []
    for col in extra_cols or []:
        col = col.strip()
        if col and col in item:
            extra_parts.append(f"{col}: {safe_to_text(item.get(col))}")

    sections = [
        "Recipe Name:",
        title,
        "",
        "Recipe Content:",
        main_text,
        "",
        "Extra Metadata:",
        "\n".join(extra_parts),
    ]
    return "\n".join(sections).strip()


def get_image(item, image_col):
    """Return the row's image value, or None when no image column is set."""
    if not image_col:
        return None
    return item.get(image_col)


def _require_columns(columns, title_col, text_col, image_col):
    """Raise ValueError when a selected column is not in ``columns``."""
    selected = (
        ("Title", title_col),
        ("Text", text_col),
        ("Image", image_col),
    )
    for label, name in selected:
        if name and name not in columns:
            raise ValueError(f"{label} column '{name}' not found.")


def _embed_dataset(model, dataset, title_col, text_col, image_col,
                   extra_cols, batch_size, log):
    """Encode ``dataset`` in batches and return one stacked float32 CPU tensor."""
    chunks = []
    total = len(dataset)

    with torch.inference_mode():
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            # Slicing yields a mapping of column name -> list of values;
            # turn it back into one dict per row.
            batch = dataset[start:end]
            names = list(batch.keys())
            rows = [
                dict(zip(names, values))
                for values in zip(*(batch[name] for name in names))
            ]

            texts = [
                build_recipe_text(
                    item=row,
                    title_col=title_col,
                    text_col=text_col,
                    extra_cols=extra_cols,
                )
                for row in rows
            ]

            log(f"Embedding batch {start} → {end}")

            if image_col:
                images = [get_image(row, image_col) for row in rows]
                embeddings = model.encode_documents(texts=texts, images=images)
            else:
                embeddings = model.encode_documents(texts=texts)

            # Move off the accelerator right away so GPU memory stays bounded.
            chunks.append(embeddings.detach().cpu().float())

    return torch.cat(chunks, dim=0)


def _upload_embeddings(output_path, repo_id, repo_type, hf_path, log):
    """Upload the saved file to the Hub; log problems instead of raising.

    The embeddings file already exists on disk when this runs, so a failed
    upload must not discard it: every failure is reported through ``log``.
    """
    token = get_hf_token()
    if not token:
        log("❌ HF_TOKEN (or HUGGINGFACE_HUB_TOKEN) not found in environment/secrets.")
        return
    if not repo_id:
        log("❌ Please enter HF repo ID.")
        return

    path_in_repo = hf_path or os.path.basename(output_path)
    try:
        create_repo(
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
            exist_ok=True,
        )
        HfApi(token=token).upload_file(
            path_or_fileobj=output_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
        )
    except Exception as e:  # network/auth errors: keep the local file usable
        log(f"❌ Upload failed: {e}")
        return

    log(f"✅ Uploaded to {repo_id}/{path_in_repo}")


def generate_embeddings_ui(
    dataset_name,
    split,
    title_col,
    text_col,
    image_col,
    extra_cols_text,
    output_path,
    batch_size,
    limit,
    upload_to_hf,
    repo_id,
    repo_type,
    hf_path
):
    """Gradio handler: embed a dataset, save it and optionally upload it.

    Args:
        dataset_name: Hugging Face dataset repo id (required).
        split: Dataset split to load.
        title_col: Title column name (may be empty).
        text_col: Main text column name (may be empty).
        image_col: Image column name; when empty only text is encoded.
        extra_cols_text: Comma-separated extra column names.
        output_path: Destination ``.safetensors`` path; falls back to
            ``DEFAULT_OUTPUT`` when empty.
        batch_size: Rows per encode call; must be a positive integer.
        limit: Maximum number of rows to use; 0 or empty means all rows.
        upload_to_hf: Whether to upload the saved file.
        repo_id: Target repo id for the upload.
        repo_type: Target repo type for the upload.
        hf_path: Path inside the repo; defaults to the output file name.

    Returns:
        ``(output_path, logs)`` once the file has been saved (even if the
        upload is skipped or fails), or ``(None, logs)`` when embedding fails.
    """
    logs = []

    def log(msg):
        print(msg)
        logs.append(msg)

    try:
        if not dataset_name:
            return None, "❌ Please enter a dataset repo."
        if not output_path:
            output_path = DEFAULT_OUTPUT

        # Reject bad batch sizes up front; otherwise they surface later as
        # opaque errors from range() or torch.cat().
        batch_size = int(batch_size) if batch_size else 0
        if batch_size < 1:
            raise ValueError("Batch size must be a positive integer.")

        log(f"Loading dataset: {dataset_name} | split={split}")
        dataset = load_dataset(dataset_name, split=split)
        columns = list(dataset.column_names)
        log(f"Dataset columns: {columns}")

        _require_columns(columns, title_col, text_col, image_col)

        if limit and int(limit) > 0:
            dataset = dataset.select(range(min(int(limit), len(dataset))))
        log(f"Dataset size used: {len(dataset)}")
        if len(dataset) == 0:
            raise ValueError("Dataset is empty; nothing to embed.")

        extra_cols = [
            c.strip() for c in (extra_cols_text or "").split(",") if c.strip()
        ]

        model = load_embedding_model(DEFAULT_MODEL)
        final_embeddings = _embed_dataset(
            model,
            dataset,
            title_col,
            text_col,
            image_col,
            extra_cols,
            batch_size,
            log,
        )

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        save_file({"image_text_embeddings": final_embeddings}, output_path)
        log(f"✅ Saved embeddings: {output_path}")
        log(f"Embedding shape: {tuple(final_embeddings.shape)}")

        if upload_to_hf:
            _upload_embeddings(output_path, repo_id, repo_type, hf_path, log)

        return output_path, "\n".join(logs)
    except Exception as e:  # UI boundary: surface the error in the log box
        log(f"❌ Error: {e}")
        return None, "\n".join(logs)


css = """
.gradio-container {
    max-width: 1250px !important;
    margin: auto !important;
}
.hero {
    padding: 30px;
    border-radius: 26px;
    background: linear-gradient(135deg, #ffffff, #f1f5f9);
    border: 1px solid #e2e8f0;
    box-shadow: 0 12px 30px rgba(15, 23, 42, 0.06);
    margin-bottom: 24px;
}
.hero h1 {
    font-size: 38px;
    color: #0f172a;
    margin-bottom: 8px;
}
.hero p {
    font-size: 16px;
    color: #475569;
}
"""

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="emerald", secondary_hue="blue"),
    css=css,
    title="Recipe Embedding Generator"
) as demo:
    # The wrapper, heading and paragraph tags are what the .hero CSS rules
    # above target; without them the banner renders as unstyled text.
    gr.HTML("""
    <div class="hero">
        <h1>🧠 Recipe Embedding Generator</h1>
        <p>
            Add a Hugging Face recipe dataset, preview columns, generate multimodal embeddings, save as .safetensors, and optionally upload to a Hugging Face repo.
        </p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 📦 Dataset")
            dataset_name = gr.Textbox(
                label="Hugging Face Dataset Repo",
                value=DEFAULT_DATASET,
                placeholder="rahul7star/food-recipes"
            )
            split = gr.Textbox(
                label="Split",
                value="train"
            )
            preview_btn = gr.Button("🔍 Preview Dataset Columns")

            gr.Markdown("## 🧩 Column Mapping")
            title_col = gr.Textbox(
                label="Recipe Title Column",
                value="name"
            )
            text_col = gr.Textbox(
                label="Main Recipe Text Column",
                value="markdown"
            )
            image_col = gr.Textbox(
                label="Image Column",
                value="image"
            )
            extra_cols = gr.Textbox(
                label="Extra Columns to Include in Embedding Text",
                value="description,tags,steps,minutes,n_ingredients,rating",
                placeholder="description,tags,steps,minutes"
            )

            gr.Markdown("## ⚙️ Settings")
            output_path = gr.Textbox(
                label="Output SafeTensor Path",
                value=DEFAULT_OUTPUT
            )
            with gr.Row():
                batch_size = gr.Number(
                    label="Batch Size",
                    value=4,
                    precision=0
                )
                limit = gr.Number(
                    label="Limit Rows 0 = All",
                    value=20,
                    precision=0
                )

        with gr.Column(scale=1):
            gr.Markdown("## 📋 Dataset Preview")
            preview_output = gr.Code(
                label="Columns + Sample",
                language="json",
                lines=20
            )

            gr.Markdown("## ☁️ Upload to Hugging Face")
            upload_to_hf = gr.Checkbox(
                label="Upload generated embeddings to HF repo",
                value=False
            )
            repo_id = gr.Textbox(
                label="HF Repo ID",
                placeholder="rahul7star/embedvector"
            )
            repo_type = gr.Dropdown(
                label="Repo Type",
                choices=["dataset", "model", "space"],
                value="dataset"
            )
            hf_path = gr.Textbox(
                label="Path in Repo",
                value="embeddings/all_recipes_image_text_embeddings.safetensors"
            )

    run_btn = gr.Button("🚀 Generate Embeddings", variant="primary")

    with gr.Row():
        file_output = gr.File(label="Generated SafeTensor File")
        logs_output = gr.Textbox(label="Logs", lines=18)

    # Preview fills the three column-mapping boxes and the JSON preview pane.
    preview_btn.click(
        fn=preview_dataset_columns,
        inputs=[dataset_name, split],
        outputs=[title_col, text_col, image_col, preview_output]
    )

    # Input order must match the positional parameters of the handler.
    run_btn.click(
        fn=generate_embeddings_ui,
        inputs=[
            dataset_name,
            split,
            title_col,
            text_col,
            image_col,
            extra_cols,
            output_path,
            batch_size,
            limit,
            upload_to_hf,
            repo_id,
            repo_type,
            hf_path
        ],
        outputs=[
            file_output,
            logs_output
        ]
    )


if __name__ == "__main__":
    demo.queue()
    demo.launch()