Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,6 @@ def image_to_base64(image_path):
|
|
| 19 |
st.markdown("""
|
| 20 |
<style>
|
| 21 |
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
|
| 22 |
-
|
| 23 |
/* Apply the font to everything */
|
| 24 |
html, body, [class*="st"] {
|
| 25 |
font-family: 'Roboto', sans-serif;
|
|
@@ -130,7 +129,15 @@ if 'generate' not in st.session_state:
|
|
| 130 |
|
| 131 |
# Inizializza inference_tester solo una volta
|
| 132 |
if 'inference_tester' not in st.session_state:
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
# Usa inference_tester dalla sessione
|
| 136 |
inference_tester = st.session_state['inference_tester']
|
|
@@ -202,12 +209,18 @@ if st.session_state['step'] == 2:
|
|
| 202 |
|
| 203 |
# Pulsante per provare un esempio
|
| 204 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
if st.button("Try an example"):
|
| 206 |
st.session_state['step'] = 5 # Passa al passo 5
|
| 207 |
st.rerun()
|
| 208 |
|
| 209 |
# Pulsante per tornare all'inizio
|
| 210 |
-
with
|
| 211 |
if st.button("Return to the beginning"):
|
| 212 |
# Ripristina lo stato della sessione
|
| 213 |
st.session_state['step'] = 1
|
|
@@ -365,8 +378,79 @@ if st.session_state['step'] == 3:
|
|
| 365 |
st.rerun()
|
| 366 |
|
| 367 |
if st.session_state['step'] == 4:
|
| 368 |
-
|
| 369 |
-
st.session_state['generate']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
if st.button("Return to the beginning"):
|
| 372 |
# Ripristina lo stato della sessione
|
|
|
|
| 19 |
st.markdown("""
|
| 20 |
<style>
|
| 21 |
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
|
|
|
|
| 22 |
/* Apply the font to everything */
|
| 23 |
html, body, [class*="st"] {
|
| 24 |
font-family: 'Roboto', sans-serif;
|
|
|
|
| 129 |
|
| 130 |
# Inizializza inference_tester solo una volta
|
| 131 |
if 'inference_tester' not in st.session_state:
|
| 132 |
+
model_load_paths = ['CoDi_encoders.pth', 'CoDi_text_diffuser.pth', 'CoDi_video_diffuser_8frames.pth']
|
| 133 |
+
st.session_state['inference_tester'] = dani_model(model='thesis_model',
|
| 134 |
+
data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
|
| 135 |
+
pth=model_load_paths, load_weights=False)
|
| 136 |
+
inference_tester = st.session_state['inference_tester']
|
| 137 |
+
|
| 138 |
+
# Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
|
| 139 |
+
if 'weights_loaded' not in st.session_state:
|
| 140 |
+
st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
|
| 141 |
|
| 142 |
# Usa inference_tester dalla sessione
|
| 143 |
inference_tester = st.session_state['inference_tester']
|
|
|
|
| 209 |
|
| 210 |
# Pulsante per provare un esempio
|
| 211 |
with col1:
|
| 212 |
+
if st.button("Inference"):
|
| 213 |
+
st.session_state['step'] = 3 # Passa al passo 3
|
| 214 |
+
st.rerun()
|
| 215 |
+
|
| 216 |
+
# Pulsante per provare un esempio
|
| 217 |
+
with col2:
|
| 218 |
if st.button("Try an example"):
|
| 219 |
st.session_state['step'] = 5 # Passa al passo 5
|
| 220 |
st.rerun()
|
| 221 |
|
| 222 |
# Pulsante per tornare all'inizio
|
| 223 |
+
with col3:
|
| 224 |
if st.button("Return to the beginning"):
|
| 225 |
# Ripristina lo stato della sessione
|
| 226 |
st.session_state['step'] = 1
|
|
|
|
| 378 |
st.rerun()
|
| 379 |
|
| 380 |
if st.session_state['step'] == 4:
|
| 381 |
+
# Costruzione del prompt
|
| 382 |
+
if st.session_state['generate'] is True:
|
| 383 |
+
conditioning = []
|
| 384 |
+
for inp in st.session_state['inputs']:
|
| 385 |
+
if inp == 'frontal':
|
| 386 |
+
cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
|
| 387 |
+
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
|
| 388 |
+
encode_type='encode_vision').to(device)
|
| 389 |
+
conditioning.append(torch.cat([uim, cim]))
|
| 390 |
+
elif inp == 'lateral':
|
| 391 |
+
cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
|
| 392 |
+
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
|
| 393 |
+
encode_type='encode_vision').to(device)
|
| 394 |
+
conditioning.append(torch.cat([uim, cim]))
|
| 395 |
+
elif inp == 'text':
|
| 396 |
+
ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
|
| 397 |
+
utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
|
| 398 |
+
conditioning.append(torch.cat([utx, ctx]))
|
| 399 |
+
|
| 400 |
+
# Costruzione delle shapes
|
| 401 |
+
shapes = []
|
| 402 |
+
for out in st.session_state['outputs']:
|
| 403 |
+
if out == 'frontal' or out == 'lateral':
|
| 404 |
+
shape = [1, 4, 256 // 8, 256 // 8]
|
| 405 |
+
shapes.append(shape)
|
| 406 |
+
elif out == 'text':
|
| 407 |
+
shape = [1, 768]
|
| 408 |
+
shapes.append(shape)
|
| 409 |
+
|
| 410 |
+
progress_bar = st.progress(0)
|
| 411 |
+
|
| 412 |
+
# Inferenza
|
| 413 |
+
z, _ = inference_tester.sampler.sample(
|
| 414 |
+
steps=50,
|
| 415 |
+
shape=shapes,
|
| 416 |
+
condition=conditioning,
|
| 417 |
+
unconditional_guidance_scale=7.5,
|
| 418 |
+
xtype=st.session_state['outputs'],
|
| 419 |
+
condition_types=st.session_state['inputs'],
|
| 420 |
+
eta=1,
|
| 421 |
+
verbose=False,
|
| 422 |
+
mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
|
| 423 |
+
progress_bar=progress_bar)
|
| 424 |
+
|
| 425 |
+
# Decoder e visualizzazione dei risultati
|
| 426 |
+
output_cols = st.columns(len(st.session_state['outputs']))
|
| 427 |
+
|
| 428 |
+
# Definire due colonne per le immagini
|
| 429 |
+
col1, col2 = st.columns(2)
|
| 430 |
+
|
| 431 |
+
# Iterare sugli output e assegnare le immagini alle colonne corrispondenti
|
| 432 |
+
for i, out in enumerate(st.session_state['outputs']):
|
| 433 |
+
if out == 'frontal':
|
| 434 |
+
x = inference_tester.net.autokl_decode(z[i])
|
| 435 |
+
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
| 436 |
+
im = x[0].cpu().numpy()
|
| 437 |
+
with col1: # Mostrare la frontal image nella prima colonna
|
| 438 |
+
st.image(im, caption="Generated Frontal Image")
|
| 439 |
+
elif out == 'lateral':
|
| 440 |
+
x = inference_tester.net.autokl_decode(z[i])
|
| 441 |
+
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
| 442 |
+
im = x[0].cpu().numpy()
|
| 443 |
+
with col2: # Mostrare la lateral image nella seconda colonna
|
| 444 |
+
st.image(im, caption="Generated Lateral Image")
|
| 445 |
+
elif out == 'text':
|
| 446 |
+
x = inference_tester.net.optimus_decode(z[i], max_length=100)
|
| 447 |
+
x = [a.tolist() for a in x]
|
| 448 |
+
rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
|
| 449 |
+
rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
|
| 450 |
+
st.write(f"Generated Report: {rec_text}")
|
| 451 |
+
|
| 452 |
+
st.write("Generation completed successfully!")
|
| 453 |
+
st.session_state['generate'] = False
|
| 454 |
|
| 455 |
if st.button("Return to the beginning"):
|
| 456 |
# Ripristina lo stato della sessione
|