rahul7star's picture
Update app.py
258c965 verified
import os
import json
import torch
import gradio as gr
from datasets import load_dataset
from safetensors.torch import save_file
from transformers import AutoModel
from huggingface_hub import HfApi, create_repo
# Model identifier handed to load_embedding_model() when embeddings are generated.
DEFAULT_MODEL = "nvidia/llama-nemotron-embed-vl-1b-v2"
# Dataset repo pre-filled in the "Hugging Face Dataset Repo" textbox.
DEFAULT_DATASET = "rahul7star/food-recipes"
# Local output path used when the user leaves the output path empty.
DEFAULT_OUTPUT = "embeddings/all_recipes_image_text_embeddings.safetensors"
# Module-level cache for the loaded model; filled in by load_embedding_model().
_embedding_model = None
def get_hf_token():
    """Return the Hugging Face token from the environment.

    ``HF_TOKEN`` wins when it is set to a non-empty value; otherwise the
    value of ``HUGGINGFACE_HUB_TOKEN`` is returned as-is (possibly ``None``).
    """
    primary = os.getenv("HF_TOKEN")
    if primary:
        return primary
    return os.getenv("HUGGINGFACE_HUB_TOKEN")
def infer_columns(dataset_name):
    """Suggest ``(title, text, image)`` column names for a dataset.

    Only the title column depends on the dataset: the
    ``rahul7star/food-recipes`` dataset uses ``name``, everything else
    falls back to ``title``. The text and image columns are always
    ``markdown`` and ``image``.
    """
    title_column = "name" if dataset_name == "rahul7star/food-recipes" else "title"
    return title_column, "markdown", "image"
def preview_dataset_columns(dataset_name, split):
    """Load a dataset split and describe its columns for the UI.

    Returns a 4-tuple ``(title_column, text_column, image_column, report)``.
    The three column values are the names suggested by ``infer_columns``,
    replaced by ``""`` when the dataset has no such column. ``report`` is a
    JSON string with the row count, the column list, the suggested mapping
    and a truncated sample of the first row.

    Any exception is caught and reported through the fourth element, with
    the three column values left empty.
    """
    try:
        ds = load_dataset(dataset_name, split=split)
        available = list(ds.column_names)
        suggested = infer_columns(dataset_name)
        first_row = ds[0]

        # Keep a suggestion only if the dataset really has that column.
        title_pick, text_pick, image_pick = (
            name if name in available else "" for name in suggested
        )

        # Stringify and truncate each value; the "image" column is skipped.
        excerpt = {}
        for column in available:
            if column == "image":
                continue
            excerpt[column] = str(first_row[column])[:300]

        report = {
            "dataset": dataset_name,
            "split": split,
            "rows": len(ds),
            "columns": available,
            "recommended_mapping": {
                "title_column": title_pick,
                "text_column": text_pick,
                "image_column": image_pick,
            },
            "sample": excerpt,
        }
        rendered = json.dumps(report, indent=2, ensure_ascii=False)
        return title_pick, text_pick, image_pick, rendered
    except Exception as e:
        return "", "", "", f"❌ Failed to preview dataset: {e}"
def load_embedding_model(model_name):
    """Load the embedding model once and reuse it on later calls.

    The loaded model is kept in the module-level ``_embedding_model`` cache.
    The cache is reused only when ``model_name`` matches the name it was
    loaded under; asking for a different model reloads instead of silently
    returning the previously cached one.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.

    Returns:
        The model, moved to CUDA when available (float16) or left on CPU
        (float32), and switched to eval mode.
    """
    global _embedding_model

    # A cache populated without a recorded name is treated as matching, so
    # externally pre-seeded caches keep working.
    cached_name = getattr(load_embedding_model, "_cached_name", model_name)
    if _embedding_model is not None and cached_name == model_name:
        return _embedding_model

    # Drop the stale model before loading a different one so both are not
    # held in memory at the same time.
    _embedding_model = None

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(device).eval()

    # NOTE(review): these processor attributes are presumably provided by the
    # model's remote code; confirm against the model card before changing.
    model.processor.p_max_length = 10240
    model.processor.max_input_tiles = 6
    model.processor.use_thumbnail = True

    _embedding_model = model
    load_embedding_model._cached_name = model_name
    return model
def safe_to_text(value):
    """Convert a dataset cell into plain text.

    ``None`` becomes ``""``, a list becomes its items joined by ``", "``,
    a dict becomes a JSON string (non-ASCII kept as-is), and anything else
    goes through ``str``.
    """
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    if isinstance(value, list):
        return ", ".join(map(str, value))
    return "" if value is None else str(value)
def build_recipe_text(item, title_col, text_col, extra_cols):
    """Assemble the text that gets embedded for one recipe row.

    The result has three labelled sections (name, content, extra metadata),
    one per line group, with surrounding whitespace stripped. A section is
    left empty when its column name is falsy; extra columns that are blank
    or missing from ``item`` are skipped.
    """
    heading = safe_to_text(item.get(title_col, "")) if title_col else ""
    content = safe_to_text(item.get(text_col, "")) if text_col else ""

    # Column names are stripped before lookup; a falsy extra_cols means none.
    cleaned_names = (raw.strip() for raw in extra_cols or ())
    metadata_lines = [
        f"{name}: {safe_to_text(item.get(name))}"
        for name in cleaned_names
        if name and name in item
    ]

    sections = (
        "Recipe Name:",
        heading,
        "Recipe Content:",
        content,
        "Extra Metadata:",
        "\n".join(metadata_lines),
    )
    return "\n".join(sections).strip()
def get_image(item, image_col):
    """Return ``item[image_col]``, or ``None`` when no image column is
    configured or the row has no such key."""
    return item.get(image_col) if image_col else None
def generate_embeddings_ui(
    dataset_name,
    split,
    title_col,
    text_col,
    image_col,
    extra_cols_text,
    output_path,
    batch_size,
    limit,
    upload_to_hf,
    repo_id,
    repo_type,
    hf_path
):
    """Embed a recipe dataset, save the result, and optionally upload it.

    This is the callback behind the "Generate Embeddings" button.

    Args:
        dataset_name: Hugging Face dataset repo to load.
        split: Dataset split to load.
        title_col: Column holding the recipe title (may be empty).
        text_col: Column holding the main recipe text (may be empty).
        image_col: Column holding the image; when empty, only text is
            embedded.
        extra_cols_text: Comma-separated extra column names to append to
            the embedded text.
        output_path: Local ``.safetensors`` path; falls back to
            ``DEFAULT_OUTPUT`` when empty.
        batch_size: Rows per model call; must be a positive integer.
        limit: Maximum number of rows to embed; ``0``/empty means all.
        upload_to_hf: Whether to upload the saved file to the Hub.
        repo_id: Target Hub repo for the upload.
        repo_type: Hub repo type for the upload.
        hf_path: Path inside the repo; defaults to the output file name.

    Returns:
        ``(path, logs)`` where ``path`` is the saved file (or ``None`` if
        nothing was saved) and ``logs`` is the newline-joined log text.
        Errors are never raised to the caller; they are appended to the
        logs instead.
    """
    logs = []
    # Set once the file is written, so a later failure (e.g. the upload)
    # still hands the generated file back to the user.
    saved_path = None

    def log(msg):
        print(msg)
        logs.append(msg)

    try:
        if not dataset_name:
            return None, "❌ Please enter a dataset repo."
        if not output_path:
            output_path = DEFAULT_OUTPUT

        # Validate before the expensive dataset/model loading: a zero step
        # breaks range(), and a negative one would embed nothing.
        batch_size = int(batch_size) if batch_size else 0
        if batch_size < 1:
            raise ValueError("Batch size must be a positive integer.")

        log(f"Loading dataset: {dataset_name} | split={split}")
        dataset = load_dataset(dataset_name, split=split)
        columns = list(dataset.column_names)
        log(f"Dataset columns: {columns}")

        if title_col and title_col not in columns:
            raise ValueError(f"Title column '{title_col}' not found.")
        if text_col and text_col not in columns:
            raise ValueError(f"Text column '{text_col}' not found.")
        if image_col and image_col not in columns:
            raise ValueError(f"Image column '{image_col}' not found.")

        if limit and int(limit) > 0:
            dataset = dataset.select(range(min(int(limit), len(dataset))))

        total = len(dataset)
        log(f"Dataset size used: {total}")
        if total == 0:
            # torch.cat() rejects an empty list; fail with a clear message.
            raise ValueError("Dataset is empty; nothing to embed.")

        extra_cols = [
            c.strip()
            for c in extra_cols_text.split(",")
            if c.strip()
        ] if extra_cols_text else []

        model = load_embedding_model(DEFAULT_MODEL)
        all_embeddings = []

        with torch.inference_mode():
            for start in range(0, total, batch_size):
                end = min(start + batch_size, total)
                # Slicing yields a column -> values mapping; turn it back
                # into one dict per row.
                batch = dataset[start:end]
                items = [
                    {k: batch[k][i] for k in batch}
                    for i in range(end - start)
                ]
                texts = [
                    build_recipe_text(
                        item=item,
                        title_col=title_col,
                        text_col=text_col,
                        extra_cols=extra_cols
                    )
                    for item in items
                ]

                log(f"Embedding batch {start} → {end}")

                if image_col:
                    images = [get_image(item, image_col) for item in items]
                    embeddings = model.encode_documents(
                        texts=texts,
                        images=images
                    )
                else:
                    embeddings = model.encode_documents(
                        texts=texts
                    )

                # Collect on CPU as float32 regardless of the model dtype.
                all_embeddings.append(embeddings.detach().cpu().float())

        final_embeddings = torch.cat(all_embeddings, dim=0)

        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        save_file(
            {
                "image_text_embeddings": final_embeddings
            },
            output_path
        )
        saved_path = output_path

        log(f"✅ Saved embeddings: {output_path}")
        log(f"Embedding shape: {tuple(final_embeddings.shape)}")

        if upload_to_hf:
            # Use the shared helper so HUGGINGFACE_HUB_TOKEN is honoured too.
            token = get_hf_token()
            if not token:
                log("❌ HF_TOKEN not found in environment/secrets.")
                return output_path, "\n".join(logs)
            if not repo_id:
                log("❌ Please enter HF repo ID.")
                return output_path, "\n".join(logs)

            create_repo(
                repo_id=repo_id,
                repo_type=repo_type,
                token=token,
                exist_ok=True
            )

            if not hf_path:
                hf_path = os.path.basename(output_path)

            api = HfApi(token=token)
            api.upload_file(
                path_or_fileobj=output_path,
                path_in_repo=hf_path,
                repo_id=repo_id,
                repo_type=repo_type,
                token=token
            )
            log(f"✅ Uploaded to {repo_id}/{hf_path}")

        return output_path, "\n".join(logs)
    except Exception as e:
        log(f"❌ Error: {e}")
        return saved_path, "\n".join(logs)
# Custom stylesheet passed to gr.Blocks: constrains the page width and styles
# the "hero" banner rendered at the top of the app.
css = """
.gradio-container {
max-width: 1250px !important;
margin: auto !important;
}
.hero {
padding: 30px;
border-radius: 26px;
background: linear-gradient(135deg, #ffffff, #f1f5f9);
border: 1px solid #e2e8f0;
box-shadow: 0 12px 30px rgba(15, 23, 42, 0.06);
margin-bottom: 24px;
}
.hero h1 {
font-size: 38px;
color: #0f172a;
margin-bottom: 8px;
}
.hero p {
font-size: 16px;
color: #475569;
}
"""
# Gradio UI: the left column collects dataset, column-mapping and run
# settings, the right column shows the dataset preview and upload options,
# and the bottom row shows the generated file and the logs.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="emerald", secondary_hue="blue"),
    css=css,
    title="Recipe Embedding Generator"
) as demo:
    gr.HTML("""
<div class="hero">
<h1>🧠 Recipe Embedding Generator</h1>
<p>
Add a Hugging Face recipe dataset, preview columns, generate multimodal embeddings,
save as .safetensors, and optionally upload to a Hugging Face repo.
</p>
</div>
""")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 📦 Dataset")
            dataset_name = gr.Textbox(
                label="Hugging Face Dataset Repo",
                value=DEFAULT_DATASET,
                placeholder="rahul7star/food-recipes"
            )
            split = gr.Textbox(
                label="Split",
                value="train"
            )
            preview_btn = gr.Button("🔍 Preview Dataset Columns")

            gr.Markdown("## 🧩 Column Mapping")
            title_col = gr.Textbox(
                label="Recipe Title Column",
                value="name"
            )
            text_col = gr.Textbox(
                label="Main Recipe Text Column",
                value="markdown"
            )
            image_col = gr.Textbox(
                label="Image Column",
                value="image"
            )
            extra_cols = gr.Textbox(
                label="Extra Columns to Include in Embedding Text",
                value="description,tags,steps,minutes,n_ingredients,rating",
                placeholder="description,tags,steps,minutes"
            )

            gr.Markdown("## ⚙️ Settings")
            output_path = gr.Textbox(
                label="Output SafeTensor Path",
                value=DEFAULT_OUTPUT
            )
            with gr.Row():
                batch_size = gr.Number(
                    label="Batch Size",
                    value=4,
                    precision=0
                )
                limit = gr.Number(
                    label="Limit Rows 0 = All",
                    value=20,
                    precision=0
                )

        with gr.Column(scale=1):
            gr.Markdown("## 📋 Dataset Preview")
            preview_output = gr.Code(
                label="Columns + Sample",
                language="json",
                lines=20
            )

            gr.Markdown("## ☁️ Upload to Hugging Face")
            upload_to_hf = gr.Checkbox(
                label="Upload generated embeddings to HF repo",
                value=False
            )
            repo_id = gr.Textbox(
                label="HF Repo ID",
                placeholder="rahul7star/embedvector"
            )
            repo_type = gr.Dropdown(
                label="Repo Type",
                choices=["dataset", "model", "space"],
                value="dataset"
            )
            hf_path = gr.Textbox(
                label="Path in Repo",
                value="embeddings/all_recipes_image_text_embeddings.safetensors"
            )

    run_btn = gr.Button("🚀 Generate Embeddings", variant="primary")

    with gr.Row():
        file_output = gr.File(label="Generated SafeTensor File")
        logs_output = gr.Textbox(label="Logs", lines=18)

    # Preview fills the three column-mapping textboxes and the JSON panel.
    preview_btn.click(
        fn=preview_dataset_columns,
        inputs=[dataset_name, split],
        outputs=[title_col, text_col, image_col, preview_output]
    )

    # Input order must match the parameter order of generate_embeddings_ui.
    run_btn.click(
        fn=generate_embeddings_ui,
        inputs=[
            dataset_name,
            split,
            title_col,
            text_col,
            image_col,
            extra_cols,
            output_path,
            batch_size,
            limit,
            upload_to_hf,
            repo_id,
            repo_type,
            hf_path
        ],
        outputs=[
            file_output,
            logs_output
        ]
    )
# Serve the app only when this file is executed as a script.
if __name__ == "__main__":
    # Turn on the Gradio queue, then start serving the UI.
    demo.queue()
    demo.launch()