Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from astropy.io import fits | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import io | |
| from PIL import Image | |
| import astropy.units as u | |
| from astropy.wcs import WCS | |
| from astropy.coordinates import SkyCoord | |
| from astropy import coordinates as coord | |
| from astropy.wcs.utils import skycoord_to_pixel | |
| from astroquery.simbad import Simbad | |
| import pandas as pd | |
| import matplotlib.patches as patches | |
# Pillow refuses to open images above MAX_IMAGE_PIXELS; None switches that
# check off completely so the large rendered figures re-opened with
# Image.open are never rejected.
Image.MAX_IMAGE_PIXELS = None
# Every Matplotlib figure the app draws uses the dark theme.
plt.style.use('dark_background')

# State shared between the Gradio callbacks:
#   global_data      - image array of the last uploaded FITS file (None until then)
#   global_header    - FITS header belonging to global_data (None until then)
#   global_dataframe - current source catalogue used for the overlays
global_data = None
global_header = None
global_dataframe = pd.DataFrame()
def show_csv(file):
    """Load a CSV source catalogue and offer its ``TYPE`` values as checkbox choices.

    Args:
        file: path to the CSV (``str`` / ``os.PathLike``, as produced by
            ``gr.File(type="filepath")``) or an upload object exposing ``.name``.

    Returns:
        ``(dataframe, CheckboxGroup)`` on success, with every ``TYPE`` value
        preselected; ``("Error: ...", None)`` if the file cannot be read or
        lacks a ``TYPE`` column. ``global_dataframe`` is replaced only on success.
    """
    global global_dataframe
    import os  # local: only needed to recognise path-like arguments

    try:
        # Accept a plain path as well as a tempfile-like upload object.
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        df = pd.read_csv(path, index_col=0)
        if "TYPE" not in df.columns:
            # Validate before publishing: rendering filters on df["TYPE"].
            return "Error: CSV does not contain a 'TYPE' column.", None
        global_dataframe = df  # Store the dataframe globally for filtering
        unique_types = df["TYPE"].unique().tolist()
        return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
    except Exception as e:
        return f"Error: {str(e)}", None
# Callback for the "Query Simbad for Galaxies" button.
def query_update_table():
    """Query SIMBAD for galaxies inside the loaded image and publish them as the catalogue.

    Reads the module-level ``global_header`` and ``global_data`` (set by
    ``load_fits_image``). Each returned object gets its pixel position
    (``Pixel_Position`` plus integer ``px``/``py``) and a catalogue prefix
    ``TYPE`` derived from its SIMBAD ``main_id``. Only objects that project
    inside the image are kept. On success the table is stored in
    ``global_dataframe``.

    Returns:
        ``(dataframe, CheckboxGroup)`` with every ``TYPE`` preselected, or
        ``("Error: ...", None)`` on failure.
    """
    global global_dataframe, global_header, global_data
    if global_header is None or global_data is None:
        return "Error: upload a plate-solved FITS file before querying SIMBAD.", None
    try:
        # Local import keeps the file-level import block unchanged.
        from astropy.wcs.utils import proj_plane_pixel_scales

        Simbad.TIMEOUT = 120
        # 2-D celestial WCS: drop the third (colour) axis of the cube.
        wcs = WCS(global_header).dropaxis(2)
        height, width = global_data.shape[0], global_data.shape[1]
        # Centre the search on the middle of the frame. CRVAL1/2 is the
        # reference pixel's position, which need not be the image centre.
        target_coord = wcs.pixel_to_world((width - 1) / 2, (height - 1) / 2).icrs
        # Generous radius: the longest image side in degrees, which exceeds the
        # centre-to-corner distance. proj_plane_pixel_scales handles both
        # CDELT and CD-matrix headers, unlike reading CDELT1/2 directly.
        pixel_scale_deg = float(np.max(proj_plane_pixel_scales(wcs)))
        radius_deg = pixel_scale_deg * max(width, height)
        # ":+" always emits an explicit sign on the declination.
        custom_query = f"region(CIRCLE, {target_coord.ra.deg} {target_coord.dec.deg:+}, {radius_deg}d)"
        print(f'Query={custom_query}')
        result_table = Simbad.query_criteria(custom_query, otype='galaxy')
        if result_table is None or len(result_table) == 0:
            global_dataframe = pd.DataFrame()
            return "Error: SIMBAD returned no galaxies for this field.", None

        df = result_table.to_pandas()
        # Identifier first, so it leads the displayed table.
        df = df[['main_id'] + [col for col in df.columns if col != 'main_id']]
        # Project all sources in one vectorised call instead of one SkyCoord per row.
        coords = SkyCoord(df['ra'].to_numpy(), df['dec'].to_numpy(), unit=(u.deg, u.deg), frame='icrs')
        xs, ys = skycoord_to_pixel(coords, wcs)
        xs = np.asarray(xs, dtype=float)
        ys = np.asarray(ys, dtype=float)
        df['Pixel_Position'] = list(zip(xs, ys))
        # Keep sources that project inside the image (pixel 0 included; NaN dropped).
        inside = (np.isfinite(xs) & np.isfinite(ys)
                  & (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height))
        df = df[inside].copy()
        df['px'] = xs[inside].astype(int)
        df['py'] = ys[inside].astype(int)
        # Catalogue prefix of the identifier: text before the first space, then before any '+'.
        df['TYPE'] = df['main_id'].apply(lambda name: name.split(' ')[0].split('+')[0])
        df = df.sort_values(by=['px', 'py']).reset_index(drop=True)

        global_dataframe = df  # Store the dataframe globally for filtering
        unique_types = df['TYPE'].unique().tolist()
        return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
    except Exception as e:
        return f"Error: {str(e)}", None
def load_fits_image(file, type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
    """Load a FITS colour cube into the shared state and re-render every output.

    Args:
        file: path of the uploaded FITS file, or None when the upload is cleared.
        type_checkboxes, title, axis_options, num_rows, patch_size, fontsize,
        alpha, linewidth, scale, patch_color, sort_method: current widget
            values, forwarded unchanged to ``update_images_and_tables``.

    Returns:
        The ``(table, image, patches_image)`` triple from
        ``update_images_and_tables``; ``(None, None, None)`` when ``file`` is
        None; ``("Error: ...", None, None)`` when the primary HDU is not a
        3-axis cube (the previously loaded image is then kept).
    """
    global global_header, global_data, global_dataframe
    if file is None:
        # Upload cleared: forget the previous image and its catalogue.
        global_data = None
        global_header = None
        global_dataframe = pd.DataFrame()
        return None, None, None

    # The context manager closes the file. Data and header are copied so they
    # stay valid after closing (the data may be memory-mapped).
    with fits.open(file) as hdul:
        cube = np.array(hdul[0].data)
        header = hdul[0].header.copy()

    if cube.ndim != 3:
        return f"Error: expected a 3-axis (colour) FITS cube, got {cube.ndim} axes.", None, None

    # astropy orders array axes as (NAXIS3, NAXIS2, NAXIS1) = (channel, y, x).
    # Move the channel axis last to get the (y, x, channel) layout imshow takes.
    global_data = np.moveaxis(cube, 0, -1)
    global_header = header
    # A catalogue queried for a previous image does not apply to this one.
    global_dataframe = pd.DataFrame()
    return update_images_and_tables(type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method)
def _fig_to_pil(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return it as a PIL image."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=.1, dpi=200)
    finally:
        # Release the figure even if saving fails, so figures do not pile up.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _select_sources(catalogue, selected_types, half, shape, sort_method):
    """Return selected sources whose whole patch is in frame, sorted, or None if nothing is selected."""
    if not selected_types or catalogue is None or catalogue.empty:
        return None
    height, width = shape[0], shape[1]
    chosen = catalogue[catalogue["TYPE"].isin(selected_types)]
    # The whole patch must lie inside the image so every cutout slice is full size.
    in_frame = ((chosen.px - half >= 0) & (chosen.px + half <= width)
                & (chosen.py - half >= 0) & (chosen.py + half <= height))
    chosen = chosen[in_frame]
    # Use a single multi-key sort. Chaining single-column sorts is unreliable
    # because pandas' default single-column sort (quicksort) is not stable.
    sort_keys = {
        "by Catalogue": ['TYPE', 'px', 'py'],
        "by x": ['px', 'py'],
        "by y": ['py', 'px'],
    }.get(sort_method)
    if sort_keys is not None:
        chosen = chosen.sort_values(by=sort_keys)
    # Positional 0..n-1 index: box numbers and cutout order follow it.
    return chosen.reset_index(drop=True)


def _render_overview(data, header, sources, patch_size, title, axis_options, fontsize, alpha, linewidth, scale, patch_color):
    """Draw the image on WCS axes with numbered boxes; return (PIL image, list of cutout arrays)."""
    half = patch_size // 2
    wcs = WCS(header).dropaxis(2)
    ratio = data.shape[0] / data.shape[1]  # height / width
    # Width = scale and height follows the image aspect (previously swapped).
    fig = plt.figure(figsize=(scale, ratio * scale))
    cutouts = []
    try:
        ax = fig.add_subplot(projection=wcs, label='overlays')
        ax.imshow(data, origin='lower')
        if "with Grid" in axis_options:
            ax.coords.grid(True, color='white', ls='-', alpha=.5)
        if "with Axis Annotation" in axis_options:
            ax.coords[0].set_axislabel('Right Ascension (J2000)', fontsize=fontsize + 2)
            ax.coords[1].set_axislabel('Declination (J2000)', fontsize=fontsize + 2)
        else:
            ax.axis('off')
        ax.set_title(title, fontsize=fontsize + 4)
        if sources is not None:
            for number, (px, py) in enumerate(zip(sources['px'], sources['py']), start=1):
                px, py = int(px), int(py)
                box = patches.Rectangle((px - half, py - half), patch_size, patch_size,
                                        alpha=alpha, linewidth=linewidth,
                                        edgecolor=patch_color, facecolor='none')
                ax.add_patch(box)
                ax.text(px, py + half, str(number),
                        ha='center', va='bottom', color=patch_color, fontsize=fontsize)
                cutouts.append(data[py - half:py + half, px - half:px + half])
        fig.tight_layout()
    except Exception:
        plt.close(fig)
        raise
    return _fig_to_pil(fig), cutouts


def _render_cutouts(cutouts, labels, ncols, fontsize, scale):
    """Tile ``cutouts`` into a grid ``ncols`` wide, each titled and numbered; return a PIL image."""
    nrows = int(np.ceil(len(cutouts) / ncols))
    cell = max([1, scale // 3])
    # squeeze=False keeps a 2-D axes array even for a single row or column.
    # Previously, a single row returned a 1-D array and axarr[i//m, i%m] crashed.
    fig, axarr = plt.subplots(nrows, ncols, figsize=(ncols * cell, nrows * cell), squeeze=False)
    try:
        for k, ax in enumerate(axarr.flat):  # row-major, same as (k // ncols, k % ncols)
            if k >= len(cutouts):
                ax.axis('off')  # unused trailing cells
                continue
            # Flip rows: the overview uses origin='lower', this grid uses the default.
            ax.imshow(cutouts[k][::-1])
            ax.set_title(labels[k], fontsize=fontsize - 2)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.text(2, 2, str(k + 1), ha='left', va='top', fontsize=fontsize + 6)
        fig.tight_layout()
    except Exception:
        plt.close(fig)
        raise
    return _fig_to_pil(fig)


def update_images_and_tables(selected_types, title, selected_axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
    """Re-render the annotated overview, the cutout grid and the filtered table.

    Reads the shared ``global_data``, ``global_header`` and ``global_dataframe``.

    Args:
        selected_types: catalogue prefixes (``TYPE`` values) to show.
        title: title of the overview image.
        selected_axis_options: subset of ``["with Grid", "with Axis Annotation"]``.
        num_rows: number of cutouts per row of the cutout grid (despite its name).
        patch_size: side length in pixels of each box / cutout.
        fontsize, alpha, linewidth, scale, patch_color: drawing options.
        sort_method: "by Catalogue", "by x" or "by y".

    Returns:
        ``(table, overview_image, cutout_image)``.
        - ``table`` is the filtered DataFrame, or None when nothing is selected.
        - ``cutout_image`` is None when no source is in frame.
        - Before a FITS file is loaded, all three values are None.
        - On failure, the table slot holds ``"Error: ..."`` and both images are None.
    """
    global global_dataframe, global_header, global_data
    if global_data is None or global_header is None:
        # Option widgets can fire before any FITS file has been uploaded.
        return None, None, None
    try:
        patch_size = int(patch_size)
        sources = _select_sources(global_dataframe, selected_types, patch_size // 2,
                                  global_data.shape, sort_method)
        image, cutouts = _render_overview(global_data, global_header, sources, patch_size,
                                          title, selected_axis_options or [], fontsize,
                                          alpha, linewidth, scale, patch_color)
        if not cutouts:
            return sources, image, None
        if 'main_id' in sources.columns:
            labels = [str(name) for name in sources['main_id']]
        else:
            labels = [str(k) for k in range(1, len(cutouts) + 1)]
        ncols = max(1, int(num_rows)) if num_rows else 1
        patches_image = _render_cutouts(cutouts, labels, ncols, fontsize, scale)
        return sources, image, patches_image
    except Exception as e:
        # Three values to match the three Gradio outputs wired to this callback.
        return f"Error: {str(e)}", None, None
# Gradio interface: option widgets, FITS upload, SIMBAD query button and outputs.
with gr.Blocks(css=".btn-green {background-color: green; color: white;}") as gui:
    gr.Markdown("# What's in my image?")
    # Options Area
    with gr.Row():
        # Used as the number of cutouts per row of the patch grid
        # (cutouts are laid out ceil(count / value) rows deep).
        num_rows = gr.Number(label="Patches per Row", value=16, minimum=2, precision=0, interactive=True)
        title = gr.Textbox(label="Image Title", value="Custom Title", interactive=True)
        patch_size = gr.Slider(label="Patch Size", minimum=16, maximum=128, step=8, value=32,
                               interactive=True)
        fontsize = gr.Slider(label="Fontsize", minimum=6, maximum=26, step=1, value=10,
                             interactive=True)
        alpha = gr.Slider(label="Alpha", minimum=0., maximum=1., step=.1, value=1.,
                          interactive=True)
        linewidth = gr.Slider(label="Linewidth", minimum=1, maximum=4, step=1, value=1,
                              interactive=True)
        scale = gr.Slider(label="Scale", minimum=1, maximum=20, step=1, value=10,
                          interactive=True)
        patch_color = gr.ColorPicker(label="Patch Color", value="#FFFFFF", interactive=True)
        sort_method = gr.Dropdown(label="Sorting Method", choices=["by Catalogue", "by x", "by y"], value="by Catalogue", interactive=True)
        axis_options = gr.CheckboxGroup(
            label="Select options",
            choices=["with Grid", "with Axis Annotation"],
            value=["with Grid", "with Axis Annotation"],  # Preselected values
            interactive=True,
        )
    gr.Markdown("Upload a plate solved `.fits` file (32 bit) to display its content.")
    file_input = gr.File(label="Upload .fits File", type="filepath")
    query_button = gr.Button("Query Simbad for Galaxies")
    fits_image = gr.Image(label="Input Image", type="pil")
    type_checkboxes = gr.CheckboxGroup(label="Select Catalogue")
    patches_image = gr.Image(label="Patches Image", type="pil")
    csv_table = gr.DataFrame(label="CSV Table")

    # Inputs of every re-render, in update_images_and_tables' parameter order.
    track_options = [type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method]
    render_outputs = [csv_table, fits_image, patches_image]

    file_input.change(load_fits_image,
                      inputs=[file_input] + track_options,
                      outputs=render_outputs)
    # One change handler per option. type_checkboxes is already part of
    # track_options, so it must not be registered a second time (doing so
    # rendered every catalogue selection twice).
    for option in track_options:
        option.change(update_images_and_tables,
                      inputs=track_options,
                      outputs=render_outputs)
    # CSV catalogue upload (show_csv) is currently disabled. To enable it:
    # file_input_csv = gr.File(label="Upload .csv File")
    # file_input_csv.change(show_csv, inputs=file_input_csv, outputs=[csv_table, type_checkboxes])
    query_button.click(query_update_table, inputs=None, outputs=[csv_table, type_checkboxes])

if __name__ == "__main__":
    gui.launch(debug=True)