Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import torch | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from safetensors.torch import save_file | |
| from transformers import AutoModel | |
| from huggingface_hub import HfApi, create_repo | |
# Model repo handed to load_embedding_model() when embeddings are generated.
DEFAULT_MODEL = "nvidia/llama-nemotron-embed-vl-1b-v2"
# Dataset repo pre-filled in the "Hugging Face Dataset Repo" textbox.
DEFAULT_DATASET = "rahul7star/food-recipes"
# Output file used when the user leaves the output path empty; also the
# initial value of the output path textbox.
DEFAULT_OUTPUT = "embeddings/all_recipes_image_text_embeddings.safetensors"
# Module-level cache for the loaded embedding model; stays None until
# load_embedding_model() populates it.
_embedding_model = None
def get_hf_token():
    """Look up the Hugging Face access token in the environment.

    ``HF_TOKEN`` wins when it is set to a non-empty value; otherwise
    whatever ``HUGGINGFACE_HUB_TOKEN`` holds is returned (``None`` when
    that variable is unset).
    """
    primary = os.getenv("HF_TOKEN")
    if primary:
        return primary
    return os.getenv("HUGGINGFACE_HUB_TOKEN")
def infer_columns(dataset_name):
    """Suggest ``(title, text, image)`` column names for a dataset repo.

    Datasets listed in the lookup table get their specific mapping; every
    other name falls back to the generic recipe layout.
    """
    known_layouts = {
        "rahul7star/food-recipes": ("name", "markdown", "image"),
    }
    generic_layout = ("title", "markdown", "image")
    return known_layouts.get(dataset_name, generic_layout)
def preview_dataset_columns(dataset_name, split):
    """Load a dataset split and summarise its columns for the UI.

    Args:
        dataset_name: Hugging Face dataset repo ID.
        split: Name of the split to load.

    Returns:
        A 4-tuple ``(title_column, text_column, image_column, preview)``.
        The first three are the column names suggested by
        ``infer_columns``, each replaced by ``""`` when the dataset has no
        such column. ``preview`` is a JSON document with the row count,
        the column list, the suggested mapping and a truncated first row.
        On failure the three names are ``""`` and ``preview`` carries the
        error message instead.
    """
    # Same guard as the generation handler: fail fast with a clear message
    # instead of surfacing a loader error for an empty repo name.
    if not dataset_name:
        return "", "", "", "❌ Please enter a dataset repo."

    try:
        dataset = load_dataset(dataset_name, split=split)
        columns = list(dataset.column_names)

        # Keep only the suggestions that actually exist in this dataset.
        title_col, text_col, image_col = (
            name if name in columns else ""
            for name in infer_columns(dataset_name)
        )

        # An empty split has no row 0; show an empty sample instead of failing.
        sample = dataset[0] if len(dataset) > 0 else {}

        # Image values are left out of the textual sample.
        skipped = {"image", image_col}

        preview = {
            "dataset": dataset_name,
            "split": split,
            "rows": len(dataset),
            "columns": columns,
            "recommended_mapping": {
                "title_column": title_col,
                "text_column": text_col,
                "image_column": image_col,
            },
            "sample": {
                k: str(sample[k])[:300]
                for k in columns
                if k in sample and k not in skipped
            },
        }

        return (
            title_col,
            text_col,
            image_col,
            json.dumps(preview, indent=2, ensure_ascii=False),
        )
    except Exception as e:
        # UI boundary: report the problem in the preview pane rather than raising.
        return "", "", "", f"❌ Failed to preview dataset: {e}"
# Name of the checkpoint currently held in ``_embedding_model``. Tracked so
# that a request for a different model is not answered from the cache.
_embedding_model_name = None


def load_embedding_model(model_name):
    """Return the embedding model for ``model_name``, loading it on first use.

    The most recently loaded model is cached at module level. A repeated
    call with the same name returns the cached instance; a call with a
    different name replaces the cache with the newly requested model.

    Args:
        model_name: Hugging Face model repo ID passed to
            ``AutoModel.from_pretrained``.

    Returns:
        The model, moved to CUDA (float16) when available and otherwise
        kept on CPU (float32), in eval mode.
    """
    global _embedding_model, _embedding_model_name

    if _embedding_model is not None and _embedding_model_name == model_name:
        return _embedding_model

    # Drop any previously cached model first so the old and new weights do
    # not both stay referenced while the new checkpoint loads.
    _embedding_model = None
    _embedding_model_name = None

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    # NOTE(review): trust_remote_code=True runs Python shipped in the model
    # repo; only use it with repos you trust.
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(device).eval()

    # Processor limits exposed by the model's remote code; their exact
    # semantics are defined by that code, not in this file.
    model.processor.p_max_length = 10240
    model.processor.max_input_tiles = 6
    model.processor.use_thumbnail = True

    _embedding_model = model
    _embedding_model_name = model_name
    return model
def safe_to_text(value):
    """Render a dataset cell as plain text.

    ``None`` becomes the empty string, a list becomes its items joined by
    ``", "``, a dict becomes its JSON encoding (non-ASCII kept as-is), and
    anything else goes through ``str``.
    """
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    if isinstance(value, list):
        return ", ".join(map(str, value))
    return "" if value is None else str(value)
def build_recipe_text(item, title_col, text_col, extra_cols):
    """Assemble the text that gets embedded for one recipe row.

    Args:
        item: Mapping of column name to value for a single row.
        title_col: Column holding the recipe name; falsy to leave it blank.
        text_col: Column holding the main recipe text; falsy to leave it blank.
        extra_cols: Iterable of additional column names (or a falsy value).
            Names are stripped; blank names and names missing from
            ``item`` are skipped.

    Returns:
        The three labelled sections joined by newlines, with surrounding
        whitespace stripped.
    """
    heading = safe_to_text(item.get(title_col, "")) if title_col else ""
    body = safe_to_text(item.get(text_col, "")) if text_col else ""

    metadata_lines = []
    for raw_name in extra_cols or ():
        name = raw_name.strip()
        if not name or name not in item:
            continue
        metadata_lines.append(f"{name}: {safe_to_text(item.get(name))}")

    sections = (
        "Recipe Name:",
        heading,
        "Recipe Content:",
        body,
        "Extra Metadata:",
        "\n".join(metadata_lines),
    )
    return "\n".join(sections).strip()
def get_image(item, image_col):
    """Fetch the image value for a row.

    Returns ``None`` when no image column is configured or when the row
    has no entry under that column.
    """
    return item.get(image_col) if image_col else None
def _rows_to_inputs(batch, row_count, title_col, text_col, image_col, extra_cols):
    """Turn a column-oriented dataset slice into parallel text and image lists."""
    texts = []
    images = []
    for i in range(row_count):
        item = {k: batch[k][i] for k in batch.keys()}
        texts.append(
            build_recipe_text(
                item=item,
                title_col=title_col,
                text_col=text_col,
                extra_cols=extra_cols,
            )
        )
        images.append(get_image(item, image_col))
    return texts, images


def _upload_embeddings(output_path, repo_id, repo_type, hf_path, log):
    """Upload the saved file to a Hub repo, logging why when it cannot."""
    # Use the shared lookup so HUGGINGFACE_HUB_TOKEN is honoured as well.
    token = get_hf_token()
    if not token:
        log("❌ HF_TOKEN not found in environment/secrets.")
        return
    if not repo_id:
        log("❌ Please enter HF repo ID.")
        return

    create_repo(
        repo_id=repo_id,
        repo_type=repo_type,
        token=token,
        exist_ok=True,
    )

    path_in_repo = hf_path or os.path.basename(output_path)
    HfApi(token=token).upload_file(
        path_or_fileobj=output_path,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
        repo_type=repo_type,
        token=token,
    )
    log(f"✅ Uploaded to {repo_id}/{path_in_repo}")


def generate_embeddings_ui(
    dataset_name,
    split,
    title_col,
    text_col,
    image_col,
    extra_cols_text,
    output_path,
    batch_size,
    limit,
    upload_to_hf,
    repo_id,
    repo_type,
    hf_path
):
    """Embed a recipe dataset, save the vectors and optionally upload them.

    Args:
        dataset_name: Hugging Face dataset repo ID.
        split: Split to load.
        title_col: Column with the recipe name ("" to skip).
        text_col: Column with the main recipe text ("" to skip).
        image_col: Column with the image ("" for text-only embeddings).
        extra_cols_text: Comma-separated extra columns appended to the text.
        output_path: Destination ``.safetensors`` path; falls back to
            ``DEFAULT_OUTPUT`` when empty.
        batch_size: Rows per model call; must be a positive integer.
        limit: Maximum number of rows to embed; 0 or empty means all rows.
        upload_to_hf: Whether to upload the saved file to the Hub.
        repo_id: Target Hub repo for the upload.
        repo_type: Hub repo type passed to ``create_repo``/``upload_file``.
        hf_path: Path inside the repo; defaults to the output file's name.

    Returns:
        ``(output_path, logs)`` once the file has been written (even if the
        optional upload could not be completed), or ``(None, logs)`` when
        generation itself failed. ``logs`` is the newline-joined log text.
    """
    logs = []

    def log(msg):
        print(msg)
        logs.append(msg)

    try:
        if not dataset_name:
            return None, "❌ Please enter a dataset repo."

        # Validate before the (potentially slow) dataset download so a bad
        # setting is reported immediately and with a readable message.
        batch_size = int(batch_size) if batch_size is not None else 0
        if batch_size < 1:
            raise ValueError("Batch size must be a positive integer.")

        if not output_path:
            output_path = DEFAULT_OUTPUT

        log(f"Loading dataset: {dataset_name} | split={split}")
        dataset = load_dataset(dataset_name, split=split)
        columns = list(dataset.column_names)
        log(f"Dataset columns: {columns}")

        if title_col and title_col not in columns:
            raise ValueError(f"Title column '{title_col}' not found.")
        if text_col and text_col not in columns:
            raise ValueError(f"Text column '{text_col}' not found.")
        if image_col and image_col not in columns:
            raise ValueError(f"Image column '{image_col}' not found.")

        if limit and int(limit) > 0:
            dataset = dataset.select(range(min(int(limit), len(dataset))))

        total = len(dataset)
        log(f"Dataset size used: {total}")
        if total == 0:
            # torch.cat() on an empty list would fail with an opaque error.
            raise ValueError("Dataset split is empty; nothing to embed.")

        extra_cols = [
            c.strip()
            for c in extra_cols_text.split(",")
            if c.strip()
        ] if extra_cols_text else []

        model = load_embedding_model(DEFAULT_MODEL)
        all_embeddings = []

        with torch.inference_mode():
            for start in range(0, total, batch_size):
                end = min(start + batch_size, total)
                texts, images = _rows_to_inputs(
                    dataset[start:end],
                    end - start,
                    title_col,
                    text_col,
                    image_col,
                    extra_cols,
                )

                log(f"Embedding batch {start} → {end}")
                if image_col:
                    embeddings = model.encode_documents(
                        texts=texts,
                        images=images
                    )
                else:
                    embeddings = model.encode_documents(
                        texts=texts
                    )

                # Collect on CPU as float32 regardless of the model's dtype.
                all_embeddings.append(embeddings.detach().cpu().float())

        final_embeddings = torch.cat(all_embeddings, dim=0)

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        save_file(
            {
                "image_text_embeddings": final_embeddings
            },
            output_path
        )
        log(f"✅ Saved embeddings: {output_path}")
        log(f"Embedding shape: {tuple(final_embeddings.shape)}")

        if upload_to_hf:
            try:
                _upload_embeddings(output_path, repo_id, repo_type, hf_path, log)
            except Exception as e:
                # The file is already on disk; report the upload problem but
                # still hand the file back to the user.
                log(f"❌ Upload failed: {e}")

        return output_path, "\n".join(logs)
    except Exception as e:
        # UI boundary: surface the failure in the log box instead of raising.
        log(f"❌ Error: {e}")
        return None, "\n".join(logs)
# Custom stylesheet passed to gr.Blocks(css=...): caps and centres the app
# container and styles the ".hero" banner rendered by the gr.HTML header.
css = """
.gradio-container {
max-width: 1250px !important;
margin: auto !important;
}
.hero {
padding: 30px;
border-radius: 26px;
background: linear-gradient(135deg, #ffffff, #f1f5f9);
border: 1px solid #e2e8f0;
box-shadow: 0 12px 30px rgba(15, 23, 42, 0.06);
margin-bottom: 24px;
}
.hero h1 {
font-size: 38px;
color: #0f172a;
margin-bottom: 8px;
}
.hero p {
font-size: 16px;
color: #475569;
}
"""
# Gradio UI: a dataset/settings column, a preview/upload column, the run
# button with its outputs, and the two event handlers that drive them.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="emerald", secondary_hue="blue"),
    css=css,
    title="Recipe Embedding Generator"
) as demo:
    # Initial column mapping comes from the same helper the preview button
    # uses, so the two cannot drift apart.
    _default_title_col, _default_text_col, _default_image_col = infer_columns(
        DEFAULT_DATASET
    )

    gr.HTML("""
<div class="hero">
<h1>🧠 Recipe Embedding Generator</h1>
<p>
Add a Hugging Face recipe dataset, preview columns, generate multimodal embeddings,
save as .safetensors, and optionally upload to a Hugging Face repo.
</p>
</div>
""")

    with gr.Row():
        # Left column: dataset selection, column mapping and run settings.
        with gr.Column(scale=1):
            gr.Markdown("## 📦 Dataset")
            dataset_name = gr.Textbox(
                label="Hugging Face Dataset Repo",
                value=DEFAULT_DATASET,
                placeholder="rahul7star/food-recipes"
            )
            split = gr.Textbox(
                label="Split",
                value="train"
            )
            preview_btn = gr.Button("🔍 Preview Dataset Columns")

            gr.Markdown("## 🧩 Column Mapping")
            title_col = gr.Textbox(
                label="Recipe Title Column",
                value=_default_title_col
            )
            text_col = gr.Textbox(
                label="Main Recipe Text Column",
                value=_default_text_col
            )
            image_col = gr.Textbox(
                label="Image Column",
                value=_default_image_col
            )
            extra_cols = gr.Textbox(
                label="Extra Columns to Include in Embedding Text",
                value="description,tags,steps,minutes,n_ingredients,rating",
                placeholder="description,tags,steps,minutes"
            )

            gr.Markdown("## ⚙️ Settings")
            output_path = gr.Textbox(
                label="Output SafeTensor Path",
                value=DEFAULT_OUTPUT
            )
            with gr.Row():
                batch_size = gr.Number(
                    label="Batch Size",
                    value=4,
                    precision=0
                )
                limit = gr.Number(
                    label="Limit Rows 0 = All",
                    value=20,
                    precision=0
                )

        # Right column: dataset preview and optional Hub upload settings.
        with gr.Column(scale=1):
            gr.Markdown("## 📋 Dataset Preview")
            preview_output = gr.Code(
                label="Columns + Sample",
                language="json",
                lines=20
            )

            gr.Markdown("## ☁️ Upload to Hugging Face")
            upload_to_hf = gr.Checkbox(
                label="Upload generated embeddings to HF repo",
                value=False
            )
            repo_id = gr.Textbox(
                label="HF Repo ID",
                placeholder="rahul7star/embedvector"
            )
            repo_type = gr.Dropdown(
                label="Repo Type",
                choices=["dataset", "model", "space"],
                value="dataset"
            )
            hf_path = gr.Textbox(
                label="Path in Repo",
                value=DEFAULT_OUTPUT
            )

    run_btn = gr.Button("🚀 Generate Embeddings", variant="primary")

    with gr.Row():
        file_output = gr.File(label="Generated SafeTensor File")
        logs_output = gr.Textbox(label="Logs", lines=18)

    # Preview fills the three mapping textboxes plus the JSON preview pane.
    preview_btn.click(
        fn=preview_dataset_columns,
        inputs=[dataset_name, split],
        outputs=[title_col, text_col, image_col, preview_output]
    )

    # Input order must match the parameter order of generate_embeddings_ui.
    run_btn.click(
        fn=generate_embeddings_ui,
        inputs=[
            dataset_name,
            split,
            title_col,
            text_col,
            image_col,
            extra_cols,
            output_path,
            batch_size,
            limit,
            upload_to_hf,
            repo_id,
            repo_type,
            hf_path
        ],
        outputs=[
            file_output,
            logs_output
        ]
    )
# Script entry point: turn on request queuing, then start the Gradio server.
if __name__ == "__main__":
    demo.queue().launch()