Spaces:
Runtime error
Runtime error
Forzando sobrescritura de app.py y carga de pesos .pth
Browse files
app.py
CHANGED
|
@@ -21,20 +21,23 @@ model = load_pokemon_model('checkpoint_1.pth')
|
|
| 21 |
|
| 22 |
def predict(img):
|
| 23 |
try:
|
| 24 |
-
# 1.
|
| 25 |
img = PILImage.create(img).resize((126, 126))
|
| 26 |
|
| 27 |
-
# 2. Convertir a tensor
|
| 28 |
-
#
|
| 29 |
-
timg =
|
|
|
|
| 30 |
|
| 31 |
-
# 3.
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
|
| 36 |
with torch.no_grad():
|
| 37 |
-
output = model(
|
| 38 |
probs = torch.softmax(output, dim=1)[0]
|
| 39 |
|
| 40 |
return {categories[i]: float(probs[i]) for i in range(len(categories))}
|
|
|
|
| 21 |
|
| 22 |
def predict(img):
|
| 23 |
try:
|
| 24 |
+
# 1. Procesar imagen PIL
|
| 25 |
img = PILImage.create(img).resize((126, 126))
|
| 26 |
|
| 27 |
+
# 2. Convertir a tensor manualmente para tener control total
|
| 28 |
+
# image2tensor devuelve un tensor de [3, 126, 126]
|
| 29 |
+
timg = image2tensor(img)
|
| 30 |
+
timg = timg.float()/255.0 # Normalizar a 0-1
|
| 31 |
|
| 32 |
+
# 3. Aplicar normalización de ImageNet (Media y Desviación estándar)
|
| 33 |
+
stats = imagenet_stats
|
| 34 |
+
timg = (timg - tensor(stats[0])[:,None,None]) / tensor(stats[1])[:,None,None]
|
| 35 |
+
|
| 36 |
+
# 4. Añadir dimensión de batch -> Resultado: [1, 3, 126, 126]
|
| 37 |
+
batch_img = timg.unsqueeze(0)
|
| 38 |
|
| 39 |
with torch.no_grad():
|
| 40 |
+
output = model(batch_img)
|
| 41 |
probs = torch.softmax(output, dim=1)[0]
|
| 42 |
|
| 43 |
return {categories[i]: float(probs[i]) for i in range(len(categories))}
|