Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| import re | |
| import json | |
| import time | |
| import pandas as pd | |
| import labelbox | |
def validate_dataset_name(name):
    """Check a proposed dataset name against the naming rules enforced here.

    Args:
        name: The candidate dataset name.

    Returns:
        ``None`` if the name is acceptable. Otherwise a user-facing error
        message. Names may be at most 256 characters long and may only use
        ASCII letters, digits, spaces and the punctuation ``_-.,()/``. At
        least one character is required.
    """
    max_length = 256
    if len(name) > max_length:
        return "Dataset name should be limited to 256 characters."

    # Anchored pattern applied with re.match: every character must be allowed.
    if re.match(r'^[A-Za-z0-9 _\-.,()\/]+$', name) is None:
        return ("Dataset name can only contain letters, numbers, spaces, and the following punctuation symbols: _-.,()/. Other characters are not supported.")

    return None
def create_new_dataset_labelbox(new_dataset_name, api_key=None):
    """Create a Labelbox dataset and return its uid.

    Args:
        new_dataset_name: Name for the new dataset. Check it with
            ``validate_dataset_name`` before calling.
        api_key: Labelbox API key to authenticate with. When omitted, the
            module-level ``labelbox_api_key`` (read from the Streamlit form)
            is used, which keeps existing one-argument callers working.

    Returns:
        The ``uid`` of the newly created dataset.
    """
    # Resolve the key at call time so the form value entered on this run is used.
    key = labelbox_api_key if api_key is None else api_key
    client = labelbox.Client(api_key=key)
    dataset = client.create_dataset(name=new_dataset_name)
    return dataset.uid
def get_dataset_from_labelbox(labelbox_api_key):
    """Return the datasets visible to the given Labelbox API key.

    The result is whatever ``labelbox.Client.get_datasets()`` returns. Callers
    in this app iterate it and read ``.name`` / ``.uid`` from each item.
    """
    return labelbox.Client(api_key=labelbox_api_key).get_datasets()
def destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key):
    """Tear down a Databricks 1.2 API execution context.

    Args:
        cluster_id: Cluster the context lives on.
        context_id: ID returned by ``/api/1.2/contexts/create``.
        domain: Workspace host name without scheme (``https://`` is prepended).
        databricks_api_key: Token sent as a ``Bearer`` Authorization header.

    Raises:
        ValueError: if the destroy call does not answer with HTTP 200.
    """
    url = f"https://{domain}/api/1.2/contexts/destroy"
    request_headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }
    body = json.dumps({"clusterId": cluster_id, "contextId": context_id})

    reply = requests.post(url, headers=request_headers, data=body)
    if reply.status_code == 200:
        return
    raise ValueError("Failed to destroy context.")
def execute_databricks_query(query, cluster_id, domain, databricks_api_key):
    """Run a SQL statement on a Databricks cluster and return the rows as a DataFrame.

    Uses the Databricks 1.2 command-execution API. It creates a SQL execution
    context on ``cluster_id``, submits ``query`` to it, and polls the command
    status once per second until it finishes. The result rows become a
    ``pandas.DataFrame`` whose column names come from the result schema.

    The execution context is destroyed on every exit path, including when the
    command fails, so failed queries do not leave contexts open on the cluster.

    Args:
        query: SQL text to execute.
        cluster_id: ID of the (running) cluster to execute on.
        domain: Workspace host name without scheme (``https://`` is prepended).
        databricks_api_key: Token sent as a ``Bearer`` Authorization header.

    Returns:
        A DataFrame of the result rows (empty if the command returned no data).

    Raises:
        ValueError: if the context cannot be created, the command cannot be
            submitted, the command ends as "Error"/"Cancelled" or reports an
            error result, or the context cannot be destroyed after a
            successful run.
    """
    base_url = f"https://{domain}"
    headers = {
        "Authorization": f"Bearer {databricks_api_key}",
        "Content-Type": "application/json",
    }

    # Create context
    context_response = requests.post(
        f"{base_url}/api/1.2/contexts/create",
        headers=headers,
        data=json.dumps({"clusterId": cluster_id, "language": "sql"}),
    )
    context_response_data = context_response.json()
    if 'id' not in context_response_data:
        raise ValueError("Failed to create context.")
    context_id = context_response_data['id']

    try:
        # Execute query
        command_response = requests.post(
            f"{base_url}/api/1.2/commands/execute",
            headers=headers,
            data=json.dumps({
                "clusterId": cluster_id,
                "contextId": context_id,
                "language": "sql",
                "command": query,
            }),
        ).json()
        if 'id' not in command_response:
            raise ValueError("Failed to execute command.")
        command_id = command_response['id']

        # Poll once per second until the command reaches a terminal status.
        while True:
            status_response = requests.get(
                f"{base_url}/api/1.2/commands/status",
                headers=headers,
                params={
                    "clusterId": cluster_id,
                    "contextId": context_id,
                    "commandId": command_id,
                },
            ).json()
            command_status = status_response.get("status")
            if command_status == "Finished":
                break
            if command_status in ["Error", "Cancelled"]:
                reason = (status_response.get('results') or {}).get('summary')
                raise ValueError(f"Command {command_status}. Reason: {reason}")
            time.sleep(1)

        results = status_response.get('results') or {}
        # A failing SQL statement can come back as status "Finished" with
        # resultType "error". Surface it instead of returning an empty frame.
        if results.get('resultType') == 'error':
            reason = results.get('summary') or results.get('cause')
            raise ValueError(f"Command Error. Reason: {reason}")

        # Convert the results into a pandas DataFrame
        data = results.get('data') or []
        columns = [col['name'] for col in (results.get('schema') or [])]
        df = pd.DataFrame(data, columns=columns)
    except Exception:
        # Best-effort cleanup. Keep the original error rather than masking it
        # with a secondary failure to destroy the context.
        try:
            destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
        except (ValueError, requests.RequestException):
            pass
        raise

    destroy_databricks_context(cluster_id, context_id, domain, databricks_api_key)
    return df
# Page header (emoji restored from a mis-encoded literal).
st.title("Labelbox 🤝 Databricks")
st.header("Pipeline Creator", divider='rainbow')
def is_valid_url_or_uri(value):
    """Check whether *value* looks like a URL or URI that data rows can point at.

    Accepts:
    - http(s) URLs (only the start of the value is checked);
    - generic ``scheme:...`` URIs and absolute paths starting with ``/``,
      which also covers cloud-storage URIs such as ``gs://`` or ``s3://``.

    Args:
        value: Candidate value, typically a cell read from a table. Cells may
            be ``None`` or numbers.

    Returns:
        A truthy ``re.Match`` when *value* is accepted, otherwise ``None``.
        Non-string values are rejected instead of raising ``TypeError``.
    """
    if not isinstance(value, str):
        return None

    # Check general URLs
    url_pattern = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )
    # Check general URIs including cloud storage URIs (like gs://, s3://, etc.)
    uri_pattern = re.compile(
        r'^(?:[a-z][a-z0-9+.-]*:|/)(?:/?[^\s]*)?$|^(gs|s3|azure|blob)://[^\s]+'
    )
    return url_pattern.match(value) or uri_pattern.match(value)
# Preview vs. production; forwarded to the pipeline service as `mode`.
is_preview = st.toggle('Run in Preview Mode', value=False)
if is_preview:
    st.success('Running in Preview mode!', icon="✅")
else:
    st.success('Running in Production mode!', icon="✅")
st.subheader("Tell us about your Databricks and Labelbox environments", divider='grey')

# Connection details. The domain is normalised further down (scheme and
# trailing slash stripped); both keys use masked password inputs.
title = st.text_input(
    label='Enter Databricks Domain (e.g., 3980281744248452.2.gcp.databricks.com)',
    value='',
)
databricks_api_key = st.text_input(label='Databricks API Key', type='password')
labelbox_api_key = st.text_input(label='Labelbox API Key', type='password')
# Nothing below can run until the Labelbox key is known: it is needed to list
# (or later create) the dataset the pipeline writes into.
if not labelbox_api_key:
    st.stop()

# Fetching datasets
datasets = get_dataset_from_labelbox(labelbox_api_key)
create_new_dataset = st.toggle("Make me a new dataset", value=False)

# Explicit defaults replace the old `locals()` lookups. Exactly one of the two
# names is set to a usable value, depending on the toggle.
new_dataset_name = None
selected_dataset_name = None
dataset_id = None

if not create_new_dataset:
    dataset_name_to_id = {dataset.name: dataset.uid for dataset in datasets}
    selected_dataset_name = st.selectbox("Select an existing dataset:", list(dataset_name_to_id.keys()))
    if selected_dataset_name is None:
        # st.selectbox returns None when there are no options to choose from.
        st.warning('No existing datasets found. Turn on "Make me a new dataset" to create one.')
    else:
        dataset_id = dataset_name_to_id[selected_dataset_name]
else:
    entered_dataset_name = st.text_input("Enter the new dataset name:")
    if entered_dataset_name:
        validation_message = validate_dataset_name(entered_dataset_name)
        if validation_message:
            st.error(validation_message, icon="🚫")
        else:
            st.success("Valid dataset name!", icon="✅")
            # Only a validated name is allowed to drive the rest of the flow.
            new_dataset_name = entered_dataset_name

if not (new_dataset_name or selected_dataset_name):
    st.stop()
# Normalise the Databricks domain: drop any http(s):// scheme and a trailing
# slash, so it can be prefixed with "https://" below.
formatted_title = re.sub(r'^https?://', '', title)
formatted_title = re.sub(r'/$', '', formatted_title)
if not formatted_title:
    st.stop()

st.subheader(
    "Select an existing cluster or make a new one",
    divider='grey',
    help="Jobs in preview mode will use all purpose compute clusters to help you iterate faster. Jobs in production mode will use job clusters to reduce DBUs consumed.",
)

DOMAIN = f"https://{formatted_title}"
TOKEN = f"Bearer {databricks_api_key}"
HEADERS = {
    "Authorization": TOKEN,
    "Content-Type": "application/json",
}
# Endpoint to list clusters
ENDPOINT = "/api/2.0/clusters/list"

# Defaults so that no later check can hit an undefined name, whichever path
# is taken.
clusters = []
selected_cluster_name = None
cluster_id = None

try:
    response = requests.get(DOMAIN + ENDPOINT, headers=HEADERS)
    response.raise_for_status()
    clusters = response.json().get("clusters", [])
except requests.RequestException as e:
    st.write(f"Error communicating with Databricks API: {str(e)}")
    st.stop()
except ValueError:
    st.write("Received unexpected response from Databricks API.")
    st.stop()

# Include clusters with cluster_source "UI" or "API"
cluster_dict = {
    cluster["cluster_name"]: cluster["cluster_id"]
    for cluster in clusters if cluster.get("cluster_source") in ["UI", "API"]
}

make_cluster = st.toggle('Make me a new cluster', value=False)
if make_cluster:
    # Browsing databases/tables below runs queries on an existing cluster, so
    # the flow cannot continue without selecting one.
    st.info("Making a new cluster isn't available here yet. Turn this off and select an existing cluster to continue.")
    st.stop()

if not cluster_dict:
    st.write("No UI or API-based compute clusters found.")
    st.stop()

# Display dropdown with cluster names
selected_cluster_name = st.selectbox(
    'Select a cluster to run on',
    list(cluster_dict.keys()),
    key='unique_key_for_cluster_selectbox',
    index=None,
    placeholder="Select a cluster..",
)
if not selected_cluster_name:
    st.stop()
cluster_id = cluster_dict[selected_cluster_name]

# Check if the selected cluster is running
cluster_state = next(
    (cluster.get("state") for cluster in clusters if cluster["cluster_id"] == cluster_id),
    None,
)
if cluster_state == "RUNNING":
    st.success(f"{selected_cluster_name} is already running!", icon="🏃‍♂️")
else:
    final_state = None
    start_error = None
    with st.spinner("Starting the selected cluster. This typically takes 10 minutes. Please wait..."):
        try:
            start_response = requests.post(f"{DOMAIN}/api/2.0/clusters/start", headers=HEADERS, json={"cluster_id": cluster_id})
            start_response.raise_for_status()
            # Poll until the cluster is up, has failed, or the timeout elapses.
            start_time = time.time()
            timeout = 1200  # 20 minutes in seconds
            while True:
                cluster_response = requests.get(
                    f"{DOMAIN}/api/2.0/clusters/get", headers=HEADERS, params={"cluster_id": cluster_id}
                ).json()
                final_state = cluster_response.get("state")
                if final_state in ("RUNNING", "TERMINATED", "ERROR"):
                    break
                if (time.time() - start_time) > timeout:
                    break
                time.sleep(10)  # Check every 10 seconds
        except (requests.RequestException, ValueError) as e:
            start_error = str(e)

    # Only report success when the cluster actually reached RUNNING.
    if start_error is not None:
        st.error(f"Error starting cluster: {start_error}", icon="🚫")
        st.stop()
    if final_state == "RUNNING":
        st.success(f"{selected_cluster_name} is now running!", icon="🏃‍♂️")
    elif final_state in ("TERMINATED", "ERROR"):
        st.error(f"Error starting cluster. Current state: {final_state}", icon="🚫")
        st.stop()
    else:
        st.error("Timeout reached while starting the cluster.", icon="🚫")
        st.stop()
def generate_cron_expression(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """Build a six-field, seconds-first cron expression for a schedule choice.

    Args:
        freq: One of "1 minute", "1 hour", "1 day", "1 week", "1 month".
        hour: Hour of day, used by daily/weekly/monthly schedules.
        minute: Minute of hour, used by every schedule except "1 minute".
        day_of_week: Day token (e.g. "MON"); required for "1 week".
        day_of_month: Day number; required for "1 month".

    Returns:
        The cron string. Its unused day field is written as ``?``.

    Raises:
        ValueError: for an unknown *freq* or a missing weekly/monthly day.
    """
    templates = (
        ("1 minute", "0 * * * * ?"),
        ("1 hour", "0 {minute} * * * ?"),
        ("1 day", "0 {minute} {hour} * * ?"),
        ("1 week", "0 {minute} {hour} ? * {day_of_week}"),
        ("1 month", "0 {minute} {hour} {day_of_month} * ?"),
    )
    template = next((pattern for label, pattern in templates if freq == label), None)
    if template is None:
        raise ValueError("Invalid frequency provided")
    if freq == "1 week" and not day_of_week:
        raise ValueError("Day of week not provided for weekly frequency.")
    if freq == "1 month" and not day_of_month:
        raise ValueError("Day of month not provided for monthly frequency.")
    return template.format(
        minute=minute,
        hour=hour,
        day_of_week=day_of_week,
        day_of_month=day_of_month,
    )
# Streamlit UI
st.subheader("Run Frequency", divider='grey')

# Schedule granularities offered to the user.
freq_options = ["1 minute", "1 hour", "1 day", "1 week", "1 month"]
selected_freq = st.selectbox("Select frequency:", freq_options, placeholder="Select frequency..")

# Only the weekly / monthly schedules fill these in.
day_of_week = None
day_of_month = None

if selected_freq == "1 minute":
    # Every-minute jobs have no specific time of day.
    hour, minute = 0, 0
else:
    # Hourly, daily, weekly and monthly schedules ask for a time.
    col1, col2 = st.columns(2)
    hour = col1.selectbox("Hour:", list(range(24)))
    minute = col2.selectbox("Minute:", list(range(60)))
    if selected_freq == "1 week":
        days_options = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]
        day_of_week = st.selectbox("Select day of the week:", days_options)
    elif selected_freq == "1 month":
        day_of_month = st.selectbox("Select day of the month:", list(range(1, 32)))

# Generate the cron expression
frequency = generate_cron_expression(selected_freq, hour, minute, day_of_week, day_of_month)
def generate_human_readable_message(freq, hour=0, minute=0, day_of_week=None, day_of_month=None):
    """Describe a schedule choice in plain English.

    Takes the same arguments as ``generate_cron_expression``. Times are
    rendered as zero-padded ``HH:MM``.

    Raises:
        ValueError: for an unknown *freq* or a missing weekly/monthly day.
    """
    messages = (
        ("1 minute", "Job will run every minute."),
        ("1 hour", "Job will run once an hour at minute {minute}."),
        ("1 day", "Job will run daily at {hour:02}:{minute:02}."),
        ("1 week", "Job will run every {day_of_week} at {hour:02}:{minute:02}."),
        ("1 month", "Job will run once a month on day {day_of_month} at {hour:02}:{minute:02}."),
    )
    message = next((text for label, text in messages if freq == label), None)
    if message is None:
        raise ValueError("Invalid frequency provided")
    if freq == "1 week" and not day_of_week:
        raise ValueError("Day of week not provided for weekly frequency.")
    if freq == "1 month" and not day_of_month:
        raise ValueError("Day of month not provided for monthly frequency.")
    return message.format(
        hour=hour,
        minute=minute,
        day_of_week=day_of_week,
        day_of_month=day_of_month,
    )
# Generate the human-readable message
readable_msg = generate_human_readable_message(selected_freq, hour, minute, day_of_week, day_of_month)
if not frequency:
    st.stop()
st.success(readable_msg, icon="📅")


def _quote_identifier(name):
    """Return *name* as a backtick-quoted Databricks SQL identifier.

    Embedded backticks are doubled, so database/table names containing
    spaces, dashes or backticks cannot break (or alter) the generated SQL.
    """
    return "`" + str(name).replace("`", "``") + "`"


def _run_query_or_stop(query):
    """Run *query* on the selected cluster, or show the error and halt the script."""
    try:
        return execute_databricks_query(query, cluster_id, formatted_title, databricks_api_key)
    except (ValueError, requests.RequestException) as err:
        st.error(f"Databricks query failed: {err}", icon="🚫")
        st.stop()


st.subheader("Select a table", divider="grey")
with st.spinner('Querying Databricks...'):
    result_data = _run_query_or_stop("SHOW DATABASES;")
    # Extract the databaseName values from the DataFrame
    database_names = result_data['databaseName'].tolist()
selected_database = st.selectbox("Select a Database:", database_names, index=None, placeholder="Select a database..")
if not selected_database:
    st.stop()

with st.spinner('Querying Databricks...'):
    result_data = _run_query_or_stop(f"SHOW TABLES IN {_quote_identifier(selected_database)};")
    # Extract the tableName values from the DataFrame
    table_names = result_data['tableName'].tolist()
selected_table = st.selectbox("Select a Table:", table_names, index=None, placeholder="Select a table..")
if not selected_table:
    st.stop()

# Fully qualified, quoted name used by every query against the chosen table.
quoted_table = f"{_quote_identifier(selected_database)}.{_quote_identifier(selected_table)}"
with st.spinner('Querying Databricks...'):
    result_data = _run_query_or_stop(f"SHOW COLUMNS IN {quoted_table};")
    column_names = result_data['col_name'].tolist()

st.subheader("Map table schema to Labelbox schema", divider="grey")
# Fetch the first 5 rows of the selected table
with st.spinner('Fetching first 5 rows of the selected table...'):
    table_sample_data = _run_query_or_stop(f"SELECT * FROM {quoted_table} LIMIT 5;")
# Display the sample data in the Streamlit UI
st.write(table_sample_data)

# Define two columns for side-by-side selectboxes
col1, col2 = st.columns(2)
with col1:
    selected_row_data = st.selectbox(
        "row_data (required):",
        column_names,
        index=None,
        placeholder="Select a column..",
        help="Select the column that contains the URL/URI bucket location of the data rows you wish to import into Labelbox."
    )
with col2:
    selected_global_key = st.selectbox(
        "global_key (optional):",
        column_names,
        index=None,
        placeholder="Select a column..",
        help="Select the column that contains the global key. If not provided, a new key will be generated for you."
    )

if not selected_row_data:
    st.stop()
if table_sample_data.empty:
    st.error("The selected table has no rows to check the row_data column against.", icon="🚫")
    st.stop()

# Validate the row_data column against the first sampled row. This reuses the
# preview above instead of running a second query.
sample_row_data_value = table_sample_data[selected_row_data].iloc[0]
# Cells can be None or numbers; only strings can be URLs/URIs.
if not (isinstance(sample_row_data_value, str) and is_valid_url_or_uri(sample_row_data_value)):
    st.error(f"row_data '{sample_row_data_value}' is not a valid URI or URL. Please select a different column.", icon="🚫")
    st.stop()
st.success(f"Sample URI/URL from selected row data column: {sample_row_data_value}", icon="✅")

# Assemble the job definition sent to the pipeline service.
mode = "preview" if is_preview else "production"
databricks_instance = formatted_title
new_dataset = 1 if create_new_dataset else 0
new_cluster = 1 if make_cluster else 0
cluster_id = cluster_id if not make_cluster else ""
table_path = f"{selected_database}.{selected_table}"

schema_map_dict = {'row_data': selected_row_data}
if selected_global_key:
    schema_map_dict['global_key'] = selected_global_key
# Convert the dict to a stringified JSON
schema_map_str = json.dumps(schema_map_dict)

data = {
    "mode": mode,
    "databricks_instance": databricks_instance,
    "databricks_api_key": databricks_api_key,
    "new_dataset": new_dataset,
    # A requested new dataset is only created on deploy (below). Creating it
    # here would make a fresh Labelbox dataset on every Streamlit rerun.
    "dataset_id": "" if create_new_dataset else dataset_id,
    "table_path": table_path,
    "labelbox_api_key": labelbox_api_key,
    "frequency": frequency,
    "new_cluster": new_cluster,
    "cluster_id": cluster_id,
    "schema_map": schema_map_str,
}

# Preview the payload without echoing the secrets typed into password fields.
secret_fields = ("databricks_api_key", "labelbox_api_key")
data_preview = {
    field: ("********" if field in secret_fields and value else value)
    for field, value in data.items()
}
if create_new_dataset:
    data_preview["dataset_id"] = f"(created on deploy as '{new_dataset_name}')"
st.json(data_preview)

if st.button("Deploy Pipeline!", type="primary"):
    # Ensure all fields are filled out. The 0/1 flags may legitimately be 0,
    # and the dataset/cluster ids are only required when not creating new ones.
    required_fields = {
        "mode": mode,
        "databricks_instance": databricks_instance,
        "databricks_api_key": databricks_api_key,
        "table_path": table_path,
        "labelbox_api_key": labelbox_api_key,
        "frequency": frequency,
        "schema_map": schema_map_str,
    }
    missing_fields = [field for field, value in required_fields.items() if not value]
    if create_new_dataset and not new_dataset_name:
        missing_fields.append("new dataset name")
    if not create_new_dataset and not dataset_id:
        missing_fields.append("dataset_id")
    if not make_cluster and not cluster_id:
        missing_fields.append("cluster_id")

    if missing_fields:
        st.error(f"Cannot deploy pipeline, missing: {', '.join(missing_fields)}.", icon="🚫")
    else:
        response = None
        # Sending a POST request to the Flask app endpoint
        with st.spinner("Deploying pipeline..."):
            if create_new_dataset:
                # Create the dataset at most once per name per session, so a
                # retried deploy reuses it instead of creating a duplicate.
                created_key = f"created_labelbox_dataset::{new_dataset_name}"
                if created_key not in st.session_state:
                    st.session_state[created_key] = create_new_dataset_labelbox(new_dataset_name)
                dataset_id = st.session_state[created_key]
                data["dataset_id"] = dataset_id
            try:
                response = requests.post("http://127.0.0.1:5000/create-databricks-job", json=data)
            except requests.RequestException as err:
                st.error(f"Failed to reach the pipeline service: {err}", icon="🚫")

        if response is not None:
            # Check if request was successful
            if response.status_code == 200:
                st.balloons()
                st.success("Pipeline deployed successfully!", icon="🎉")
                try:
                    st.json(response.json())
                except ValueError:
                    st.write(response.text)
            else:
                st.error(f"Failed to deploy pipeline. Response: {response.text}", icon="🚫")