# Scraped page header (not code): author DaniKaEp, commit 8ba0f0d "Update app.py" (verified).
import io
import numpy as np
import pandas as pd
from PIL import Image as PILImage
import torch
from era_data import TabletPeriodDataset
from VAE_model_tablets_class import VAE
import gradio as gr
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# One class per period index known to the dataset.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
# map_location remaps stored tensors onto `device`, so weights that were saved
# from a GPU session still load on a CPU-only host instead of raising during
# deserialization.
class_weights = torch.load("class_weights_period.pt", map_location=device)
checkpoint_path = 'epoch=22-step=213621.ckpt'
# Extra keyword arguments are handed by Lightning's load_from_checkpoint to the
# VAE constructor; presumably they must match the training configuration —
# TODO confirm against the training script.
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    image_channels=1,
    z_dim=12,
    lr=0.0001,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)
# Inference only: put the module in evaluation mode (affects layers such as
# dropout / batch-norm, if the VAE has any).
vae_model.eval()
# Per-tablet latent encodings produced by the VAE, read from CSV.
df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
# Metadata columns excluded before averaging; what remains is Period_Name plus
# (presumably) the numeric latent columns — verify against the CSV schema.
_metadata_columns = ["Period", "Genre", "Genre_Name", "CDLI_id"]
# One row per period: the column-wise mean of every remaining column.
df_means = (
    df_encodings
    .drop(columns=_metadata_columns)
    .groupby("Period_Name")
    .mean()
    .reset_index()
)
# Choices offered in the period dropdowns.
period_names = df_means['Period_Name'].unique()
def get_image_from_period(period_name, means=None):
    """Return the mean latent vector for ``period_name`` as a float32 tensor.

    Despite its name, this returns a latent code (the per-period row of the
    means table with ``Period_Name`` removed), not an image;
    ``generate_image`` is what decodes it.

    Args:
        period_name: Value matched against the ``Period_Name`` column.
        means: DataFrame with a ``Period_Name`` column plus one numeric column
            per latent dimension. Defaults to the module-level ``df_means``.

    Returns:
        1-D ``torch.float32`` tensor with the remaining columns of the first
        matching row, in column order.

    Raises:
        ValueError: if no row matches ``period_name`` (for example when a
            dropdown is still empty), rather than a bare ``IndexError``.
    """
    if means is None:
        means = df_means
    rows = means.loc[means["Period_Name"] == period_name]
    if rows.empty:
        raise ValueError(f"Unknown or missing period: {period_name!r}")
    latent = rows.drop(columns=["Period_Name"]).to_numpy()[0].astype("float32")
    return torch.from_numpy(latent)
def generate_image(period1, period2, interpolation_value):
    """Decode a tablet image interpolated between two periods' mean latents.

    The two per-period latent vectors are blended linearly as
    ``(1 - t) * latent1 + t * latent2``. The blend is passed through the VAE's
    ``fc3`` layer and decoder. The first channel of the single batch item is
    returned as an 8-bit grayscale PIL image.

    Args:
        period1: Period whose latent is returned unchanged when ``t == 0``.
        period2: Period whose latent is returned unchanged when ``t == 1``.
        interpolation_value: Blend weight ``t``; the UI slider supplies
            values in [0, 1].

    Returns:
        ``PIL.Image.Image`` built from the decoded array.
    """
    latent1 = get_image_from_period(period1)
    latent2 = get_image_from_period(period2)
    t = interpolation_value
    blended = (1 - t) * latent1 + t * latent2
    # Pure inference: keep fc3 inside no_grad too, so no autograd graph is
    # recorded for any part of the forward pass.
    with torch.no_grad():
        # unsqueeze(0) adds a leading batch dimension of size 1.
        decoder_input = vae_model.fc3(blended).unsqueeze(0)
        decoded = vae_model.decoder(decoder_input)
    pixels = decoded[0][0].cpu().numpy()
    # Scaling by 255 presumes decoder output roughly in [0, 1] (TODO confirm).
    # Clip before the cast: out-of-range float -> uint8 conversion is not
    # clamped and typically wraps, producing speckle artifacts.
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    return PILImage.fromarray(pixels)
def update_image(dropdown1, dropdown2, slider):
    """Return a freshly generated image for the current UI values.

    Meant as a Gradio event handler. Gradio delivers a handler's result by
    writing its return value into the event's output component, so the image
    must be returned. Calling ``iface.update(...)`` and returning None, as
    before, left no value for any output to receive.

    Args:
        dropdown1: Selected first period name.
        dropdown2: Selected second period name.
        slider: Interpolation weight between the two periods.

    Returns:
        The PIL image produced by ``generate_image``.
    """
    return generate_image(dropdown1, dropdown2, slider)
# NOTE(review): this Blocks context is never launched (only `iface` is served
# under __main__), so the listener below does not fire in the running app.
# The components created here are reused as the Interface's inputs.
with gr.Blocks() as inputOutput:
    dropdown1 = gr.Dropdown(choices=period_names.tolist(), label="Period 1")
    dropdown2 = gr.Dropdown(choices=period_names.tolist(), label="Period 2")
    slider = gr.Slider(0, 1, step=0.1, label="Interpolation")
    # Pass inputs by keyword. Gradio listeners take (fn, inputs, outputs,
    # api_name, ...), so the old positional call bound dropdown2 as the
    # *output* and slider as the api_name.
    slider.change(update_image, inputs=[dropdown1, dropdown2, slider])
# The served app: the two period dropdowns and the slider are fed to
# generate_image, whose PIL result is displayed in a 250x250 image panel.
# Flagging is turned off.
result_image = gr.Image(label="", height=250, width=250)
iface = gr.Interface(
    fn=generate_image,
    inputs=[dropdown1, dropdown2, slider],
    outputs=result_image,
    allow_flagging="never",
)
def main():
    """Launch the Gradio interface, requesting a public share link."""
    iface.launch(share=True)


if __name__ == "__main__":
    main()