"""Gradio demo: decode interpolations between per-period mean VAE latents.

The app averages the per-tablet latent encodings stored in a CSV by period.
The user picks two periods and an interpolation weight. The linear mix of the
two period means is decoded by the VAE into a grayscale image.
"""
import io

import numpy as np
import pandas as pd
from PIL import Image as PILImage
import torch
import gradio as gr

from era_data import TabletPeriodDataset
from VAE_model_tablets_class import VAE

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)

# map_location lets tensors saved on a GPU be loaded on a CPU-only host.
class_weights = torch.load("class_weights_period.pt", map_location=device)

checkpoint_path = 'epoch=22-step=213621.ckpt'
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    map_location=device,
    image_channels=1,
    z_dim=12,
    lr=0.0001,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    # NOTE(review): presumably forwarded to VAE.__init__ as a hyperparameter;
    # confirm VAE actually uses it for placement.
    device=device,
)
vae_model.eval()

# Per-tablet latent encodings, averaged per period: one row per Period_Name.
df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
df_means = (
    df_encodings
    .drop(["Period", "Genre", "Genre_Name", "CDLI_id"], axis=1)
    .groupby("Period_Name")
    .mean()
    .reset_index()
)
period_names = df_means['Period_Name'].unique()


def get_image_from_period(period_name):
    """Return the mean latent vector of *period_name* as a float32 tensor.

    Despite its name, this returns a latent encoding (the ``df_means`` row for
    the period, without the ``Period_Name`` column), not an image.
    Assumes the remaining columns are exactly the z_dim latent dimensions.
    TODO: confirm against the CSV schema.

    Raises:
        IndexError: if *period_name* is not present in ``df_means``.
    """
    row = df_means[df_means["Period_Name"] == period_name].drop(["Period_Name"], axis=1)
    return torch.from_numpy(row.values[0].astype('float32'))


def generate_image(period1, period2, interpolation_value):
    """Decode a linear interpolation between two period latents into an image.

    Args:
        period1: Period name used at ``interpolation_value == 0``.
        period2: Period name used at ``interpolation_value == 1``.
        interpolation_value: Mixing weight ``i``; latent = (1 - i) * z1 + i * z2.

    Returns:
        A ``PIL.Image`` built from the first channel of the first decoded sample.
    """
    latent1 = get_image_from_period(period1)
    latent2 = get_image_from_period(period2)
    weight = interpolation_value
    # Put the latent on whatever device the model's weights live on.
    model_device = next(vae_model.parameters()).device
    with torch.no_grad():  # inference only: no autograd graph for fc3/decoder
        mixed = ((1 - weight) * latent1 + weight * latent2).to(model_device)
        decoder_input = vae_model.fc3(mixed).unsqueeze(0)  # add batch dim
        decoded = vae_model.decoder(decoder_input)
    # Assumes decoder output is (batch, channel, H, W) in roughly [0, 1]. TODO confirm.
    pixels = decoded[0][0].cpu().numpy()
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    return PILImage.fromarray(pixels)


def update_image(dropdown1, dropdown2, slider):
    """Gradio event handler: return the interpolated image for the current inputs.

    Returns None (clearing the output) until both periods are selected, so the
    first input change does not raise on a missing period.
    """
    if not dropdown1 or not dropdown2:
        return None
    return generate_image(dropdown1, dropdown2, slider)


with gr.Blocks() as inputOutput:
    with gr.Row():
        with gr.Column():
            dropdown1 = gr.Dropdown(choices=period_names.tolist(), label="Period 1")
            dropdown2 = gr.Dropdown(choices=period_names.tolist(), label="Period 2")
            slider = gr.Slider(0, 1, step=0.1, label="Interpolation")
        with gr.Column():
            output_image = gr.Image(label="", height=250, width=250)
    # Regenerate whenever any control changes (live update).
    for control in (dropdown1, dropdown2, slider):
        control.change(
            update_image,
            inputs=[dropdown1, dropdown2, slider],
            outputs=output_image,
        )

# Single UI object; the name `iface` is kept for existing launch code.
iface = inputOutput

if __name__ == "__main__":
    # share=True publishes a temporary public URL.
    iface.launch(share=True)