Edupy commited on
Commit
1fa1da9
·
1 Parent(s): 1d6d003

Forzando sobrescritura de app.py y carga de pesos .pth

Browse files
Files changed (1) hide show
  1. app.py +31 -14
app.py CHANGED
@@ -1,30 +1,47 @@
 
1
  from fastai.vision.all import *
2
  import gradio as gr
3
  import timm
4
 
 
5
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
6
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
7
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
8
 
9
- def load_model_with_weights(weights_path):
10
- # Inicialización mínima
11
- dls = ImageDataLoaders.from_lists('.', ['dummy.jpg'], [categories[0]],
12
- item_tfms=Resize(460),
13
- batch_tfms=aug_transforms(size=126))
14
 
15
- # IMPORTANTE: Usamos la arquitectura exacta
16
- learn = vision_learner(dls, 'convnext_tiny', metrics=accuracy)
 
17
 
18
- # Cargar pesos (.pth). Fastai busca el archivo sin la extensión .pth internamente a veces,
19
- # pero si le pasas el path completo suele funcionar.
20
- load_model(weights_path, learn.model, learn.opt)
21
- return learn
 
 
 
 
22
 
23
  # Cargamos el modelo
24
- learn = load_model_with_weights('checkpoint_1.pth')
25
 
 
26
  def predict(img):
27
- img = PILImage.create(img)
 
 
 
 
 
 
 
 
 
28
  pred, pred_idx, probs = learn.predict(img)
29
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
30
 
@@ -32,5 +49,5 @@ gr.Interface(
32
  fn=predict,
33
  inputs=gr.Image(),
34
  outputs=gr.Label(num_top_classes=3),
35
- title="Detector de Tipos Pokémon (Carga por Pesos)"
36
  ).launch()
 
1
+
2
  from fastai.vision.all import *
3
  import gradio as gr
4
  import timm
5
 
6
+ # 1. Lista de categorías exacta
7
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
8
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
+ # 2. Reconstruir el modelo sin necesidad de DataLoaders ni imágenes dummy
12
+ def load_pokemon_model(weights_path):
13
+ # Creamos el modelo usando timm (la arquitectura exacta que entrenaste)
14
+ # num_classes debe coincidir con el número de tipos de tu dls.vocab
15
+ model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
16
 
17
+ # Cargar los pesos directamente al modelo de PyTorch
18
+ # load_model de fastai espera un archivo .pth generado con learn.save
19
+ state_dict = torch.load(weights_path, map_location='cpu')
20
 
21
+ # Si guardaste con learn.save, los pesos están en la llave 'model'
22
+ if 'model' in state_dict:
23
+ model.load_state_dict(state_dict['model'])
24
+ else:
25
+ model.load_state_dict(state_dict)
26
+
27
+ model.eval()
28
+ return model
29
 
30
  # Cargamos el modelo
31
+ model = load_pokemon_model('checkpoint_1.pth')
32
 
33
+ # Función de predicción usando el modelo de PyTorch directamente
34
  def predict(img):
35
+ img = PILImage.create(img).resize((126, 126)) # El tamaño 'size' que usaste en batch_tfms
36
+ # Convertir imagen a tensor y normalizar
37
+ img_tensor = pipeline(img) # fastai maneja esto internamente con predict,
38
+ # pero aquí lo simplificamos:
39
+
40
+ # Usamos un Learner vacío solo para aprovechar el método predict de fastai
41
+ # sin errores de inicialización
42
+ empty_dls = DataLoaders.from_empty(categories)
43
+ learn = Learner(empty_dls, model, loss_func=CrossEntropyLossFlat())
44
+
45
  pred, pred_idx, probs = learn.predict(img)
46
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
47
 
 
49
  fn=predict,
50
  inputs=gr.Image(),
51
  outputs=gr.Label(num_top_classes=3),
52
+ title="Detector de Tipos Pokémon"
53
  ).launch()