Rarsta commited on
Commit
38bb4ea
·
verified ·
1 Parent(s): 3004d69

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ from tensorflow import keras
5
+ from PIL import Image
6
+ import io
7
+ import base64
8
+
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+
11
+ PATH_MODEL = "./autoencoder.keras"
12
+ PATH_DB = "./mnist_train_small.csv"
13
+
14
+ # ── Cargar modelo y datos al iniciar ─────────────────────────────────────────
15
+ model = keras.models.load_model(PATH_MODEL)
16
+ encoder = model.get_layer("encoder")
17
+ decoder = model.get_layer("decoder")
18
+
19
+ data = pd.read_csv(PATH_DB, header=None)
20
+ X_ref = data.iloc[:, 1:].values.astype("float32") / 255
21
+ X_latent = encoder.predict(X_ref, verbose=0)
22
+
23
+ LATENT_DIM = 32
24
+
25
+ # ── Helper: imagen subida → array (1, 784) ────────────────────────────────────
26
+ def image_to_array(canva):
27
+ img = canva['composite'].convert("L")
28
+ img = img.resize((28, 28))
29
+ arr = 1 - np.array(img, dtype="float32") / 255
30
+ return arr.reshape(1, 784)
31
+
32
+ def find_similar(img, top_k):
33
+ X = image_to_array(img)
34
+ query_vec = encoder.predict(X, verbose=0)
35
+ sims = cosine_similarity(query_vec, X_latent)[0]
36
+ top_idx = np.argsort(sims)[::-1][:int(top_k)]
37
+
38
+ best_arr = (X_ref[top_idx[0]].reshape(28, 28) * 255).astype(np.uint8)
39
+ best_img = Image.fromarray(best_arr)
40
+
41
+ table = [[int(i), round(float(sims[i]), 4)] for i in top_idx]
42
+
43
+ gallery_imgs = [
44
+ Image.fromarray((X_ref[i].reshape(28, 28) * 255).astype(np.uint8))
45
+ for i in top_idx
46
+ ]
47
+ return table, best_img, gallery_imgs
48
+
49
+ with gr.Blocks() as demo:
50
+
51
+ with gr.Tab("Búsqueda"):
52
+
53
+ gr.Markdown("## Búsqueda en espacio latente")
54
+ with gr.Row():
55
+ with gr.Column():
56
+ canvas = gr.Sketchpad(label="Dibuja", type='pil')
57
+ with gr.Column():
58
+ topk = gr.Slider(1, 50, value=10, step=1, label="top_k")
59
+ btn = gr.Button("Buscar similares")
60
+
61
+ gallery = gr.Gallery(label="Imágenes similares", columns=5, object_fit="contain")
62
+
63
+ with gr.Tab("Metadatos"):
64
+
65
+ results = gr.Dataframe(
66
+ headers=["index", "cosine_similarity"],
67
+ datatype=["number", "number"],
68
+ label="Ranking",
69
+ interactive=False
70
+ )
71
+
72
+ btn.click(find_similar, inputs=[canvas, topk], outputs=[results, best, gallery])
73
+ demo.launch(server_port=7860)