Edupy commited on
Commit
7b6064b
·
1 Parent(s): 5dfab75

Forzando sobrescritura de app.py y carga de pesos .pth

Browse files
Files changed (1) hide show
  1. app.py +19 -21
app.py CHANGED
@@ -4,44 +4,42 @@ import gradio as gr
4
  import timm
5
  import torch
6
 
7
- # 1. Lista de categorías
8
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
9
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
10
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
11
 
12
- # 2. Carga del modelo
13
  def load_pokemon_model(weights_path):
14
- try:
15
- model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
16
- state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
17
- if 'model' in state_dict: state_dict = state_dict['model']
18
- new_state_dict = {k.replace('0.model.', ''): v for k, v in state_dict.items()}
19
- model.load_state_dict(new_state_dict, strict=False)
20
- model.eval()
21
- return model
22
- except Exception as e:
23
- return f"Error cargando pesos: {str(e)}"
24
 
25
  model = load_pokemon_model('checkpoint_1.pth')
26
 
27
- # 3. Función de predicción con captura de errores
28
  def predict(img):
29
  try:
30
- if isinstance(model, str): return {"Error de carga": model} # Si falló el paso 2
31
-
32
- img = PILImage.create(img).resize((224, 224))
33
 
34
- # Transformación manual a tensor
35
- img_tensor = cast(ToTensor()(img), TensorImage)
36
- img_tensor = IntToFloatTensor()(img_tensor)
 
 
37
 
 
38
  with torch.no_grad():
39
- output = model(img_tensor.unsqueeze(0))
40
  probs = torch.softmax(output, dim=1)[0]
41
 
42
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
43
  except Exception as e:
44
- return {"Error en predicción": str(e)}
 
 
45
 
46
  gr.Interface(
47
  fn=predict,
 
4
  import timm
5
  import torch
6
 
 
7
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
8
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
 
11
  def load_pokemon_model(weights_path):
12
+ model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
13
+ state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
14
+ if 'model' in state_dict: state_dict = state_dict['model']
15
+ new_state_dict = {k.replace('0.model.', ''): v for k, v in state_dict.items()}
16
+ model.load_state_dict(new_state_dict, strict=False)
17
+ model.eval()
18
+ return model
 
 
 
19
 
20
  model = load_pokemon_model('checkpoint_1.pth')
21
 
 
22
  def predict(img):
23
  try:
24
+ # 1. Preparar la imagen al tamaño exacto de tu entrenamiento (126)
25
+ img = PILImage.create(img).resize((126, 126))
 
26
 
27
+ # 2. Transformar a Tensor y Normalizar (Estándar de ConvNeXt/ImageNet)
28
+ # Esto soluciona el error de los canales (channels)
29
+ timg = TensorImage(image2tensor(img))
30
+ timg = IntToFloatTensor()(timg)
31
+ timg = Normalize.from_stats(*imagenet_stats)(timg)
32
 
33
+ # 3. Predicción
34
  with torch.no_grad():
35
+ output = model(timg.unsqueeze(0))
36
  probs = torch.softmax(output, dim=1)[0]
37
 
38
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
39
  except Exception as e:
40
+ # Devolvemos un string que Gradio pueda manejar en caso de error crítico
41
+ print(f"Error en predicción: {e}")
42
+ return {f"Error: {str(e)}": 0.0}
43
 
44
  gr.Interface(
45
  fn=predict,