Edupy commited on
Commit
005f111
·
1 Parent(s): 7b6064b

Forzando sobrescritura de app.py y carga de pesos .pth

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -21,24 +21,24 @@ model = load_pokemon_model('checkpoint_1.pth')
21
 
22
  def predict(img):
23
  try:
24
- # 1. Preparar la imagen al tamaño exacto de tu entrenamiento (126)
25
  img = PILImage.create(img).resize((126, 126))
26
 
27
- # 2. Transformar a Tensor y Normalizar (Estándar de ConvNeXt/ImageNet)
28
- # Esto soluciona el error de los canales (channels)
29
- timg = TensorImage(image2tensor(img))
30
- timg = IntToFloatTensor()(timg)
31
- timg = Normalize.from_stats(*imagenet_stats)(timg)
 
 
 
32
 
33
- # 3. Predicción
34
  with torch.no_grad():
35
- output = model(timg.unsqueeze(0))
36
  probs = torch.softmax(output, dim=1)[0]
37
 
38
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
39
  except Exception as e:
40
- # Devolvemos un string que Gradio pueda manejar en caso de error crítico
41
- print(f"Error en predicción: {e}")
42
  return {f"Error: {str(e)}": 0.0}
43
 
44
  gr.Interface(
 
21
 
22
  def predict(img):
23
  try:
24
+ # 1. Redimensionar
25
  img = PILImage.create(img).resize((126, 126))
26
 
27
+ # 2. Convertir a tensor y normalizar (forma directa de PyTorch)
28
+ # Usamos transforms estándar para asegurar las 4 dimensiones [1, 3, 126, 126]
29
+ timg = Pipeline([ToTensor(), IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)])(img)
30
+
31
+ # 3. Predicción (timg ya sale con el batch dimension si se usa Pipeline de fastai adecuadamente)
32
+ # Pero para estar seguros de que tiene 4D y no 5D:
33
+ if timg.ndim == 3: timg = timg.unsqueeze(0)
34
+ if timg.ndim == 5: timg = timg.squeeze(0) # Esto elimina la dimensión extra si existiera
35
 
 
36
  with torch.no_grad():
37
+ output = model(timg)
38
  probs = torch.softmax(output, dim=1)[0]
39
 
40
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
41
  except Exception as e:
 
 
42
  return {f"Error: {str(e)}": 0.0}
43
 
44
  gr.Interface(