File size: 2,410 Bytes
991bb01
 
 
 
 
 
 
 
 
 
 
 
6d53984
 
c25bef4
991bb01
de1fe70
c25bef4
 
 
 
3dc57ed
991bb01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc57ed
991bb01
 
3dc57ed
991bb01
 
 
14fb364
991bb01
5919981
 
 
 
 
 
 
 
 
 
 
 
991bb01
 
5919981
 
 
 
ecf5b45
991bb01
8ba0f0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import io

import numpy as np
import pandas as pd
from PIL import Image as PILImage
import torch

from era_data import TabletPeriodDataset
from VAE_model_tablets_class import VAE

import gradio as gr

# Prefer the GPU when one is present; otherwise everything runs on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Number of period classes, taken from the dataset's period index table.
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)

# map_location remaps stored tensors onto `device`, so a weights file saved from
# a GPU process still loads on a CPU-only machine (plain torch.load fails there).
class_weights = torch.load("class_weights_period.pt", map_location=device)

checkpoint_path = 'epoch=22-step=213621.ckpt'
# NOTE(review): these keyword arguments are presumably forwarded to VAE's
# constructor and should mirror the training configuration -- confirm against
# VAE_model_tablets_class.
vae_model = VAE.load_from_checkpoint(
    checkpoint_path,
    image_channels=1,
    z_dim=12,
    lr=0.0001,
    use_classification_loss=True,
    num_classes=num_classes,
    loss_type="weighted",
    class_weights=class_weights,
    device=device,
)
# Inference only: switch the model to evaluation mode.
vae_model.eval()

# Per-tablet VAE encodings. Metadata columns are dropped and the remaining
# columns are averaged per Period_Name, giving one mean row per period
# (presumably the latent dimensions -- confirm against the CSV layout).
df_encodings = pd.read_csv('df_vae_encoding_April16_all.csv')
df_means = (
    df_encodings
    .drop(["Period", "Genre", "Genre_Name", "CDLI_id"], axis=1)
    .groupby("Period_Name")
    .mean()
    .reset_index()
)
# Choices offered by the period dropdowns.
period_names = df_means['Period_Name'].unique()

def get_image_from_period(period_name, means=None):
    """Return the mean latent vector for ``period_name`` as a float32 tensor.

    Despite the name, this returns a latent encoding (one row of the
    per-period means table), not an image; ``generate_image`` decodes it.

    Args:
        period_name: Value to match in the ``Period_Name`` column.
        means: Optional DataFrame with a ``Period_Name`` column plus numeric
            columns. Defaults to the module-level ``df_means``.

    Returns:
        1-D ``torch.float32`` tensor holding the non-label columns of the first
        row whose ``Period_Name`` equals ``period_name``.

    Raises:
        ValueError: if no row matches ``period_name`` (for example ``None``
            coming from an unselected dropdown).
    """
    if means is None:
        means = df_means
    rows = means[means["Period_Name"] == period_name]
    if rows.empty:
        # Fail with a clear message instead of an opaque IndexError on [0].
        raise ValueError(f"Unknown period: {period_name!r}")
    # Drop the label column so only the numeric latent values remain.
    latent = rows.drop(columns=["Period_Name"]).to_numpy()[0].astype('float32')
    return torch.from_numpy(latent)

def generate_image(period1, period2, interpolation_value):
    """Decode a tablet image interpolated between two periods' mean latents.

    The latent is the linear blend ``(1 - t) * mean(period1) + t * mean(period2)``
    with ``t = interpolation_value``; it is passed through ``vae_model.fc3`` and
    then ``vae_model.decoder``.

    Args:
        period1: Period name used at ``t == 0``; ``None`` when the dropdown is unset.
        period2: Period name used at ``t == 1``; ``None`` when the dropdown is unset.
        interpolation_value: Blend factor ``t`` (the UI slider spans 0 to 1).

    Returns:
        A ``PIL.Image`` built from ``decoded[0][0]`` (presumably the single
        channel of the first decoded sample -- confirm against the decoder), or
        ``None`` (which clears the image output) if either period is unset.
    """
    if period1 is None or period2 is None:
        return None

    start = get_image_from_period(period1)
    end = get_image_from_period(period2)
    t = interpolation_value
    latent = (1 - t) * start + t * end

    # Put the latent on whatever device holds the model's weights.
    latent = latent.to(next(vae_model.parameters()).device)

    # fc3 belongs inside no_grad too; otherwise autograd records an unused graph.
    with torch.no_grad():
        decoder_input = vae_model.fc3(latent).unsqueeze(0)
        decoded = vae_model.decoder(decoder_input)

    pixels = decoded[0][0].detach().cpu().numpy()
    # Clip before the uint8 cast: out-of-range values would otherwise wrap around.
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    return PILImage.fromarray(pixels)

def update_image(dropdown1, dropdown2, slider):
    """Event handler that regenerates the interpolated tablet image.

    Gradio writes an event handler's return value into the event's output
    component(s), so the image is returned rather than pushed via an update call.

    Args:
        dropdown1: Selected first period.
        dropdown2: Selected second period.
        slider: Interpolation value between the two periods.

    Returns:
        The result of ``generate_image`` for these inputs.
    """
    return generate_image(dropdown1, dropdown2, slider)


# Blocks layout holding the input widgets; the same component objects are also
# handed to the Interface defined below.
with gr.Blocks() as inputOutput:
    dropdown1 = gr.Dropdown(choices=period_names.tolist(), label="Period 1")
    dropdown2 = gr.Dropdown(choices=period_names.tolist(), label="Period 2")
    slider = gr.Slider(0, 1, step=0.1, label="Interpolation")
    # Target for the slider handler, so its returned image has somewhere to go.
    output_image = gr.Image(label="", height=250, width=250)

    # Listener arguments after fn are (inputs, outputs, ...); the widgets must be
    # passed together as the `inputs` list, and the handler's return value is
    # written to `outputs`.
    slider.change(
        update_image,
        inputs=[dropdown1, dropdown2, slider],
        outputs=output_image,
    )

# Define the Gradio interface.
# live=True re-runs generate_image whenever any input changes, so moving the
# slider updates the image immediately. The slider.change hook is attached to
# `inputOutput`, which this script never launches, so it cannot do that here.
iface = gr.Interface(
    fn=generate_image,
    inputs=[dropdown1, dropdown2, slider],
    outputs=gr.Image(label="", height=250, width=250),
    allow_flagging="never",
    live=True,
)


if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    iface.launch(share=True)