import gradio as gr
import spaces
from sentence_transformers import SentenceTransformer
import pandas as pd
from datasets import Dataset, load_dataset
import os
import json
import torch

# Global model cache to avoid reloading the same model repeatedly.
MODELS = {}
REPO_ID = "stevenbucaille/semantic-transformers"


def get_model(model_name):
    """Return a cached SentenceTransformer for *model_name*, loading it on first use."""
    if model_name not in MODELS:
        print(f"Loading model: {model_name}")
        MODELS[model_name] = SentenceTransformer(model_name, trust_remote_code=True)
    return MODELS[model_name]


@spaces.GPU(size="xlarge", duration=120)
def encode_batch_gpu(texts, model_name):
    """
    GPU-accelerated function to encode a list of texts.

    Takes a list of strings, returns a numpy array of embeddings.
    """
    print(f"Encoding batch of {len(texts)} items with {model_name}...")
    model = get_model(model_name)
    device = model.device
    # Large internal batches on GPU; tiny ones on CPU to keep memory in check.
    internal_batch_size = 512 if device.type == "cuda" else 4
    embeddings = model.encode(
        texts,
        batch_size=internal_batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    return embeddings


def process_dataset(model_name, progress=gr.Progress()):
    """Embed every unprocessed row of REPO_ID with *model_name*.

    Processes rows in chunks and checkpoints the DataFrame (locally and to the
    Hub) after each chunk, so a ZeroGPU timeout loses at most one chunk and a
    re-run resumes where it left off.

    Returns:
        (output_file_or_None, status_message) for the Gradio File/Textbox outputs.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        return None, "Error: HF_TOKEN environment variable is not set."

    try:
        # 1. Load data from the Hub.
        progress(0.1, desc="Loading Dataset from Hub...")
        print(f"Loading dataset {REPO_ID}...")
        try:
            ds = load_dataset(REPO_ID, split="train")
            df = ds.to_pandas()
        except Exception as e:
            return None, f"Error loading dataset: {e}. Make sure it exists."
        print(f"Loaded {len(df)} rows.")

        # 2. Ensure the embedding columns exist and can hold arrays / None.
        if "embedding" not in df.columns:
            print("Initializing embedding column...")
            df["embedding"] = None
            df["embedding_model"] = None
        if "embedding_model" not in df.columns:
            # Robustness: an earlier run may have created "embedding" only.
            df["embedding_model"] = None
        # The embedding column must be object dtype to hold numpy arrays.
        if df["embedding"].dtype != "object":
            df["embedding"] = df["embedding"].astype("object")

        # 3. Find unprocessed rows (embedding still None/NaN).
        unprocessed_mask = df["embedding"].isnull()
        unprocessed_indices = df[unprocessed_mask].index.tolist()
        total_unprocessed = len(unprocessed_indices)
        print(f"Total unprocessed rows: {total_unprocessed}")

        if total_unprocessed == 0:
            return None, "Dataset is already fully embedded!"

        # 4. Processing loop. If a GPU timeout happens we catch it and save
        # progress. 5k rows per call keeps us safely under the 120 s limit.
        CHUNK_SIZE = 5000
        processed_count = 0
        error_occurred = False

        progress(0.2, desc=f"Starting processing of {total_unprocessed} rows...")

        for i in range(0, total_unprocessed, CHUNK_SIZE):
            batch_indices = unprocessed_indices[i : i + CHUNK_SIZE]
            batch_texts = df.loc[batch_indices, "content"].tolist()

            current_progress = 0.2 + 0.7 * (i / total_unprocessed)
            progress(
                current_progress, desc=f"Encoding batch {i}/{total_unprocessed}..."
            )

            try:
                # This call is protected by the @spaces.GPU timeout.
                embeddings = encode_batch_gpu(batch_texts, model_name)

                # Assign row-by-row: assigning a list of arrays to a pandas
                # slice raises "Must have equal len keys and value".
                for idx, emb in zip(batch_indices, embeddings):
                    df.at[idx, "embedding"] = emb
                    df.at[idx, "embedding_model"] = model_name
                processed_count += len(batch_indices)

                # --- Checkpoint saving ---
                print(
                    f"Batch completed. Saving checkpoint for {processed_count} processed rows..."
                )
                # Save locally first (fast).
                df.to_parquet("embeddings_checkpoint.parquet")
                # Push to Hub (slower but persistent across machines).
                if hf_token and REPO_ID:
                    try:
                        # Creating a new Dataset each chunk costs memory but
                        # guarantees the Hub copy stays up to date.
                        temp_ds = Dataset.from_pandas(df)
                        temp_ds.push_to_hub(REPO_ID, token=hf_token)
                        print("Checkpoint pushed to Hub.")
                        del temp_ds  # release the copy before the next chunk
                    except Exception as hub_err:
                        print(f"Warning: Failed to push checkpoint to Hub: {hub_err}")
            except Exception as e:
                print(f"Error during GPU encoding batch {i}: {e}")
                error_occurred = True
                # Stop processing but proceed to save what we have.
                break

        # 5. Final save & push.
        progress(0.95, desc="Saving progress to Hub...")
        output_msg = f"Processed {processed_count} rows out of {total_unprocessed}.\n"
        if error_occurred:
            output_msg += "āš ļø Run interrupted (timeout/error). Saving progress...\n"
            output_msg += "Please click 'Generate' again to continue."
        else:
            output_msg += "āœ… All batches completed successfully."

        try:
            updated_ds = Dataset.from_pandas(df)
            updated_ds.push_to_hub(REPO_ID, token=hf_token)
            output_msg += f"\nDataset saved to {REPO_ID}"
        except Exception as e:
            output_msg += f"\nāŒ Error saving to Hub: {e}"

        # Best-effort local parquet for the download widget. Narrowed from a
        # bare except: never swallow SystemExit/KeyboardInterrupt silently.
        output_file = "embeddings_partial.parquet"
        try:
            df.to_parquet(output_file)
        except Exception as parquet_err:
            print(f"Warning: failed to write local parquet: {parquet_err}")

        return output_file, output_msg
    except Exception as e:
        import traceback

        traceback.print_exc()
        return None, f"Global Error: {str(e)}"


# --- UI ---
with gr.Blocks(title="Code Embedding Generator") as demo:
    gr.Markdown("# šŸš€ ZeroGPU Code Embedding Generator")
    gr.Markdown(
        f"Generates embeddings for **{REPO_ID}**. "
        "If the process times out, successful batches are saved. **Run again to resume.**"
    )

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=[
                    "Snowflake/snowflake-arctic-embed-m",
                    "BAAI/bge-m3",
                    "sentence-transformers/all-MiniLM-L6-v2",
                ],
                value="Snowflake/snowflake-arctic-embed-m",
                label="Embedding Model",
            )
            submit_btn = gr.Button("Generate Embeddings (Resume)", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Parquet (Partial/Full)")
            status_output = gr.Textbox(label="Status Log", lines=10)

    submit_btn.click(
        fn=process_dataset,
        inputs=[model_selector],
        outputs=[output_file, status_output],
    )

if __name__ == "__main__":
    demo.launch()