# Hugging Face Space page header ("Spaces", status: Sleeping) — site chrome captured by extraction, not code.
import json
import os

import gradio as gr
import pandas as pd
import spaces  # ZeroGPU support: provides the spaces.GPU decorator
import torch
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer

# Global model cache so repeated runs reuse an already-loaded SentenceTransformer
# instead of re-downloading/re-initializing it on every click.
MODELS = {}
# Hub dataset repository that is read, embedded, and written back.
REPO_ID = "stevenbucaille/semantic-transformers"
def get_model(model_name):
    """Return a cached SentenceTransformer for *model_name*, loading it on first use."""
    try:
        return MODELS[model_name]
    except KeyError:
        # First request for this model: load it (with custom code enabled) and cache it.
        print(f"Loading model: {model_name}")
        loaded = SentenceTransformer(model_name, trust_remote_code=True)
        MODELS[model_name] = loaded
        return loaded
# FIX(review): process_dataset's comment says this call "is protected by
# @spaces.GPU timeout" and the file imports `spaces`, but the decorator was
# missing — so on ZeroGPU hardware no GPU would ever be allocated for it.
# duration=120 matches the 120 s budget the chunk sizing assumes.
@spaces.GPU(duration=120)
def encode_batch_gpu(texts, model_name):
    """
    GPU-accelerated function to encode a list of texts.

    Args:
        texts: List of strings to embed.
        model_name: Hub id of the SentenceTransformer model to use.

    Returns:
        numpy.ndarray of shape (len(texts), embedding_dim), one row per text.
    """
    print(f"Encoding batch of {len(texts)} items with {model_name}...")
    model = get_model(model_name)
    device = model.device
    # Large internal batches are fine on GPU; keep CPU batches tiny.
    internal_batch_size = 512 if device.type == "cuda" else 4
    embeddings = model.encode(
        texts,
        batch_size=internal_batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    return embeddings
def process_dataset(model_name, progress=gr.Progress()):
    """
    Resume-able embedding job over the REPO_ID dataset.

    Loads the dataset from the Hub, embeds every row whose "embedding" is
    still empty using `model_name`, checkpoints after each chunk (local
    parquet + push_to_hub so progress survives timeouts), and returns
    (local_parquet_path, status_message). Returns (None, error_message)
    when nothing could be processed.

    Args:
        model_name: Hub id of the SentenceTransformer to use.
        progress: Gradio progress tracker (injected by the UI).
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        return None, "Error: HF_TOKEN environment variable is not set."
    try:
        # 1. Load data from the Hub.
        progress(0.1, desc="Loading Dataset from Hub...")
        print(f"Loading dataset {REPO_ID}...")
        try:
            ds = load_dataset(REPO_ID, split="train")
            df = ds.to_pandas()
        except Exception as e:
            return None, f"Error loading dataset: {e}. Make sure it exists."
        print(f"Loaded {len(df)} rows.")

        # 2. Ensure both bookkeeping columns exist. Created independently:
        # a dataset from a partial run could have "embedding" without
        # "embedding_model", which previously crashed the assignment loop.
        if "embedding" not in df.columns:
            print("Initializing embedding column...")
            df["embedding"] = None
        if "embedding_model" not in df.columns:
            df["embedding_model"] = None
        # The column must hold arbitrary objects (numpy arrays or None).
        if df["embedding"].dtype != "object":
            df["embedding"] = df["embedding"].astype("object")

        # 3. Collect rows that still need an embedding.
        unprocessed_mask = df["embedding"].isnull()
        unprocessed_indices = df[unprocessed_mask].index.tolist()
        total_unprocessed = len(unprocessed_indices)
        print(f"Total unprocessed rows: {total_unprocessed}")
        if total_unprocessed == 0:
            return None, "Dataset is already fully embedded!"

        # 4. Process in chunks so a GPU timeout only loses one chunk.
        # 5k rows per call keeps each GPU call within the ~120 s budget.
        CHUNK_SIZE = 5000
        processed_count = 0
        error_occurred = False
        progress(0.2, desc=f"Starting processing of {total_unprocessed} rows...")
        for i in range(0, total_unprocessed, CHUNK_SIZE):
            batch_indices = unprocessed_indices[i : i + CHUNK_SIZE]
            batch_texts = df.loc[batch_indices, "content"].tolist()
            current_progress = 0.2 + 0.7 * (i / total_unprocessed)
            progress(
                current_progress, desc=f"Encoding batch {i}/{total_unprocessed}..."
            )
            try:
                # GPU call; may raise on ZeroGPU timeout.
                embeddings = encode_batch_gpu(batch_texts, model_name)
                # Assign row by row: bulk-assigning a list of arrays to a
                # pandas slice raises "Must have equal len keys and value".
                for idx, emb in zip(batch_indices, embeddings):
                    df.at[idx, "embedding"] = emb
                    df.at[idx, "embedding_model"] = model_name
                processed_count += len(batch_indices)
                # --- Checkpoint: local parquet (fast) + Hub push (durable) ---
                print(
                    f"Batch completed. Saving checkpoint for {processed_count} processed rows..."
                )
                df.to_parquet("embeddings_checkpoint.parquet")
                if hf_token and REPO_ID:
                    try:
                        # Rebuilding the Dataset each chunk costs memory but
                        # guarantees the Hub copy is always current.
                        temp_ds = Dataset.from_pandas(df)
                        temp_ds.push_to_hub(REPO_ID, token=hf_token)
                        print("Checkpoint pushed to Hub.")
                        del temp_ds
                    except Exception as hub_err:
                        print(f"Warning: Failed to push checkpoint to Hub: {hub_err}")
            except Exception as e:
                print(f"Error during GPU encoding batch {i}: {e}")
                error_occurred = True
                # Stop encoding, but fall through to save what we have.
                break

        # 5. Final save & push.
        progress(0.95, desc="Saving progress to Hub...")
        output_msg = f"Processed {processed_count} rows out of {total_unprocessed}.\n"
        if error_occurred:
            output_msg += "⚠️ Run interrupted (timeout/error). Saving progress...\n"
            output_msg += "Please click 'Generate' again to continue."
        else:
            output_msg += "✅ All batches completed successfully."
        try:
            # Push the final state even on the happy path (last chunk's
            # checkpoint may have failed to reach the Hub).
            updated_ds = Dataset.from_pandas(df)
            updated_ds.push_to_hub(REPO_ID, token=hf_token)
            output_msg += f"\nDataset saved to {REPO_ID}"
        except Exception as e:
            output_msg += f"\n❌ Error saving to Hub: {e}"
        # Best-effort local copy for the download widget.
        output_file = "embeddings_partial.parquet"
        try:
            df.to_parquet(output_file)
        except Exception:
            # Was a bare `except:`; never mask SystemExit/KeyboardInterrupt.
            pass
        return output_file, output_msg
    except Exception as e:
        import traceback

        traceback.print_exc()
        return None, f"Global Error: {str(e)}"
# UI wiring: model picker + run button on the left, outputs on the right.
with gr.Blocks(title="Code Embedding Generator") as demo:
    gr.Markdown("# 🚀 ZeroGPU Code Embedding Generator")
    gr.Markdown(
        f"Generates embeddings for **{REPO_ID}**. <br>"
        # FIX(review): user-visible typo "successfull" -> "successful".
        "If the process times out, successful batches are saved. **Run again to resume.**"
    )
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=[
                    "Snowflake/snowflake-arctic-embed-m",
                    "BAAI/bge-m3",
                    "sentence-transformers/all-MiniLM-L6-v2",
                ],
                value="Snowflake/snowflake-arctic-embed-m",
                label="Embedding Model",
            )
            submit_btn = gr.Button("Generate Embeddings (Resume)", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Parquet (Partial/Full)")
            status_output = gr.Textbox(label="Status Log", lines=10)
    # Button drives the long-running job; outputs feed the file + log widgets.
    submit_btn.click(
        fn=process_dataset,
        inputs=[model_selector],
        outputs=[output_file, status_output],
    )

if __name__ == "__main__":
    demo.launch()