Spaces:
Build error
Build error
| import io | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image as PILImage | |
| import torch | |
| from era_data import TabletPeriodDataset | |
| from VAE_model_tablets_class import VAE | |
| import gradio as gr | |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)

# map_location restores the saved tensor onto the device chosen above; without
# it torch.load puts tensors back on the device they were saved from, which
# fails on a CPU-only host if the weights were saved from a GPU run.
class_weights = torch.load("class_weights_period.pt", map_location=device)

checkpoint_path = 'epoch=22-step=213621.ckpt'
# The keyword arguments are forwarded to the VAE constructor alongside the
# checkpoint's stored state.
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    image_channels=1,
    z_dim=12,
    lr=0.0001,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)
# Inference only: disable any training-only layer behaviour (e.g. dropout).
vae_model.eval()

# Load your dataframe encoding.
# Per-period mean of the stored encodings: drop the identifier/label columns,
# then average every remaining column (presumably the latent dimensions —
# confirm against the CSV) within each Period_Name.
df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
df_means = (
    df_encodings.drop(["Period", "Genre", "Genre_Name", "CDLI_id"], axis=1)
    .groupby("Period_Name")
    .mean()
    .reset_index()
)
period_names = df_means['Period_Name'].unique()
def get_image_from_period(period_name, means=None):
    """Return the mean encoding of one period as a float32 tensor.

    Args:
        period_name: Value to match in the ``Period_Name`` column.
        means: DataFrame with a ``Period_Name`` column plus one numeric column
            per encoding dimension. Defaults to the module-level ``df_means``.

    Returns:
        A 1-D ``torch.float32`` tensor holding every column except
        ``Period_Name`` from the first matching row.

    Raises:
        ValueError: If no row matches ``period_name`` (including ``None``).
    """
    if means is None:
        means = df_means
    rows = means.loc[means["Period_Name"] == period_name]
    if rows.empty:
        # Previously this surfaced as an opaque IndexError from `.values[0]`.
        raise ValueError(f"Unknown period: {period_name!r}")
    encoding = rows.drop(columns=["Period_Name"]).to_numpy()[0].astype("float32")
    return torch.from_numpy(encoding)
def generate_image(period1, period2, interpolation_value):
    """Decode an image interpolated between the mean encodings of two periods.

    Args:
        period1: Period name used when ``interpolation_value`` is 0.
        period2: Period name used when ``interpolation_value`` is 1.
        interpolation_value: Blend weight ``i``; the decoded encoding is
            ``(1 - i) * enc(period1) + i * enc(period2)``.

    Returns:
        A PIL image built from the first channel of the first decoder output,
        or ``None`` while either period is still unselected.
    """
    if not period1 or not period2:
        # Live UIs fire before both dropdowns have a value; show nothing
        # instead of raising on an empty lookup.
        return None

    start = get_image_from_period(period1)
    end = get_image_from_period(period2)
    weight = float(interpolation_value)
    blended = (1 - weight) * start + weight * end

    # Run the inputs on whatever device the model's parameters live on.
    model_device = next(vae_model.parameters()).device
    with torch.no_grad():
        # fc3 is inside no_grad too, so no autograd graph is built for it.
        hidden = vae_model.fc3(blended.to(model_device)).unsqueeze(0)
        decoded = vae_model.decoder(hidden)

    # Assumes decoder output is (batch, channel, H, W) with values nominally in
    # [0, 1] (the * 255 scaling implies this) — confirm against the VAE.
    pixels = decoded[0][0].cpu().numpy()
    # Clip first: casting out-of-range floats to uint8 is ill-defined.
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    return PILImage.fromarray(pixels)
def update_image(dropdown1, dropdown2, slider):
    """Event-handler wrapper: return the image for the current widget values.

    Gradio event handlers deliver their outputs by returning them, so this
    returns the generated image. The previous version called
    ``iface.update(...)`` and returned ``None``, which never updated the
    output.

    Args:
        dropdown1: Selected first period name.
        dropdown2: Selected second period name.
        slider: Interpolation weight between the two periods.

    Returns:
        Whatever ``generate_image`` returns for these values.
    """
    return generate_image(dropdown1, dropdown2, slider)
# Define the Gradio interface.
# The input widgets are created outside any gr.Blocks context so gr.Interface
# renders and wires them itself. live=True re-runs generate_image whenever a
# dropdown or the slider changes. That is what the earlier
# `slider.change(update_image, dropdown1, dropdown2, slider)` attempted. That
# call passed arguments positionally as (fn, inputs, outputs, ...), so
# update_image got only dropdown1 and dropdown2 was treated as an output. It
# was also registered on a Blocks that is never launched.
dropdown1 = gr.Dropdown(choices=period_names.tolist(), label="Period 1")
dropdown2 = gr.Dropdown(choices=period_names.tolist(), label="Period 2")
slider = gr.Slider(0, 1, step=0.1, label="Interpolation")

iface = gr.Interface(
    fn=generate_image,
    inputs=[dropdown1, dropdown2, slider],
    outputs=gr.Image(label="", height=250, width=250),
    allow_flagging="never",
    live=True,
)

# Kept so code that referenced the old Blocks object still finds an app.
inputOutput = iface
def _launch_app():
    """Start the Gradio server for ``iface`` with ``share=True``."""
    iface.launch(share=True)


if __name__ == "__main__":
    _launch_app()