# Hugging Face Space app by stevenbucaille — commit c56dc56 ("fix")
import gradio as gr
import spaces
from sentence_transformers import SentenceTransformer
import pandas as pd
from datasets import Dataset, load_dataset
import os
import json
import torch
# Global model cache to avoid reloading same model repeatedly
MODELS = {}
# Hub dataset repo that holds the rows to embed and receives checkpoint pushes
REPO_ID = "stevenbucaille/semantic-transformers"
def get_model(model_name):
    """Return a cached SentenceTransformer for *model_name*, loading it on first use."""
    model = MODELS.get(model_name)
    if model is None:
        print(f"Loading model: {model_name}")
        model = SentenceTransformer(model_name, trust_remote_code=True)
        MODELS[model_name] = model
    return model
@spaces.GPU(size="xlarge", duration=120)
def encode_batch_gpu(texts, model_name):
    """
    GPU-accelerated function to encode a list of texts.

    Args:
        texts: list of strings to embed.
        model_name: identifier of the SentenceTransformer model to use.

    Returns:
        numpy array of embeddings, one row per input string.
    """
    print(f"Encoding batch of {len(texts)} items with {model_name}...")
    model = get_model(model_name)
    # Large internal batches only pay off on CUDA; stay tiny on CPU.
    per_call_batch = 512 if model.device.type == "cuda" else 4
    return model.encode(
        texts,
        batch_size=per_call_batch,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
def process_dataset(model_name, progress=gr.Progress()):
    """Embed every unprocessed row of the Hub dataset with *model_name*.

    Loads the dataset at REPO_ID, finds rows whose "embedding" column is
    null, encodes them in fixed-size chunks via the GPU helper, and
    checkpoints after every chunk (local parquet + push to the Hub) so an
    interrupted/timed-out run can be resumed by clicking Generate again.

    Args:
        model_name: SentenceTransformer model identifier used for encoding.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of (output_file, status_message). ``output_file`` is the path
        to the locally saved parquet (or None on a fatal error);
        ``status_message`` is a human-readable log for the status textbox.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        return None, "Error: HF_TOKEN environment variable is not set."
    try:
        # 1. Load Data
        progress(0.1, desc="Loading Dataset from Hub...")
        print(f"Loading dataset {REPO_ID}...")
        try:
            ds = load_dataset(REPO_ID, split="train")
            df = ds.to_pandas()
        except Exception as e:
            return None, f"Error loading dataset: {e}. Make sure it exists."
        print(f"Loaded {len(df)} rows.")

        # 2. Check/Init Embeddings
        if "embedding" not in df.columns:
            print("Initializing embedding column...")
            df["embedding"] = None
            df["embedding_model"] = None
        # Ensure embedding column allows objects (arrays) or None
        if df["embedding"].dtype != "object":
            df["embedding"] = df["embedding"].astype("object")

        # 3. Find Unprocessed — rows where embedding is still null
        unprocessed_mask = df["embedding"].isnull()
        unprocessed_indices = df[unprocessed_mask].index.tolist()
        total_unprocessed = len(unprocessed_indices)
        print(f"Total unprocessed rows: {total_unprocessed}")
        if total_unprocessed == 0:
            return None, "Dataset is already fully embedded!"

        # 4. Processing Loop
        # Chunked so a GPU timeout loses at most one chunk; 5k chosen to
        # stay comfortably within the 120s @spaces.GPU limit.
        CHUNK_SIZE = 5000
        processed_count = 0
        error_occurred = False
        progress(0.2, desc=f"Starting processing of {total_unprocessed} rows...")
        for i in range(0, total_unprocessed, CHUNK_SIZE):
            batch_indices = unprocessed_indices[i : i + CHUNK_SIZE]
            batch_texts = df.loc[batch_indices, "content"].tolist()
            current_progress = 0.2 + 0.7 * (i / total_unprocessed)
            progress(
                current_progress, desc=f"Encoding batch {i}/{total_unprocessed}..."
            )
            try:
                # Call GPU function — protected by the @spaces.GPU timeout
                embeddings = encode_batch_gpu(batch_texts, model_name)
                # Assign row by row: slice-assigning a list of arrays raises
                # "Must have equal len keys and value" in pandas.
                for idx, emb in zip(batch_indices, embeddings):
                    df.at[idx, "embedding"] = emb
                    df.at[idx, "embedding_model"] = model_name
                processed_count += len(batch_indices)

                # --- Checkpoint Saving ---
                print(
                    f"Batch completed. Saving checkpoint for {processed_count} processed rows..."
                )
                # Save locally first (fast)
                df.to_parquet("embeddings_checkpoint.parquet")
                # Push to Hub (slower but persistent across machines)
                if hf_token and REPO_ID:
                    try:
                        # Creating a new dataset every time applies memory
                        # pressure but keeps the Hub copy up to date.
                        temp_ds = Dataset.from_pandas(df)
                        temp_ds.push_to_hub(REPO_ID, token=hf_token)
                        print("Checkpoint pushed to Hub.")
                        del temp_ds
                    except Exception as hub_err:
                        print(f"Warning: Failed to push checkpoint to Hub: {hub_err}")
            except Exception as e:
                print(f"Error during GPU encoding batch {i}: {e}")
                error_occurred = True
                # Stop processing but proceed to save what we have
                break

        # 5. Save & Push
        progress(0.95, desc="Saving progress to Hub...")
        output_msg = f"Processed {processed_count} rows out of {total_unprocessed}.\n"
        if error_occurred:
            output_msg += "⚠️ Run interrupted (timeout/error). Saving progress...\n"
            output_msg += "Please click 'Generate' again to continue."
        else:
            output_msg += "✅ All batches completed successfully."
        try:
            # Convert back to Dataset
            updated_ds = Dataset.from_pandas(df)
            updated_ds.push_to_hub(REPO_ID, token=hf_token)
            output_msg += f"\nDataset saved to {REPO_ID}"
        except Exception as e:
            output_msg += f"\n❌ Error saving to Hub: {e}"

        # Optional: Save parquet locally too. Best-effort by design, but a
        # bare `except:` also swallowed KeyboardInterrupt/SystemExit —
        # narrow to Exception and log instead of hiding the failure.
        output_file = "embeddings_partial.parquet"
        try:
            df.to_parquet(output_file)
        except Exception as e:
            print(f"Warning: could not save local parquet {output_file}: {e}")
        return output_file, output_msg
    except Exception as e:
        import traceback

        traceback.print_exc()
        return None, f"Global Error: {str(e)}"
# UI — single-page app: pick a model, kick off (or resume) embedding,
# download the resulting parquet and watch the status log.
with gr.Blocks(title="Code Embedding Generator") as demo:
    gr.Markdown("# 🚀 ZeroGPU Code Embedding Generator")
    gr.Markdown(
        f"Generates embeddings for **{REPO_ID}**. <br>"
        # Fixed typo: "successfull" -> "successful"
        "If the process times out, successful batches are saved. **Run again to resume.**"
    )
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=[
                    "Snowflake/snowflake-arctic-embed-m",
                    "BAAI/bge-m3",
                    "sentence-transformers/all-MiniLM-L6-v2",
                ],
                value="Snowflake/snowflake-arctic-embed-m",
                label="Embedding Model",
            )
            submit_btn = gr.Button("Generate Embeddings (Resume)", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Parquet (Partial/Full)")
            status_output = gr.Textbox(label="Status Log", lines=10)
    submit_btn.click(
        fn=process_dataset,
        inputs=[model_selector],
        outputs=[output_file, status_output],
    )

if __name__ == "__main__":
    demo.launch()