Edupy commited on
Commit
fa901be
·
1 Parent(s): a739378

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -6,22 +6,23 @@ from PIL import Image
6
  from torchvision import transforms
7
  from fastai.vision.all import *
8
 
9
- # 1. Categorías (Asegúrate de que el orden sea ALFABÉTICO, que es el defecto de Fastai)
10
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
11
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
12
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
13
 
14
  def load_model_fastai_style(weights_path):
15
- # Creamos un Learner vacío para que Fastai construya la arquitectura COMPLETA
16
- # (Cuerpo de ConvNeXt + Cabeza de clasificación de Fastai)
17
- dls = DataLoaders.from_empty(categories)
 
 
 
 
18
  learn = vision_learner(dls, 'convnext_tiny', pretrained=False)
19
 
20
- # Cargamos el archivo .pth
21
- # Usamos torch.load directamente para mayor control
22
  state = torch.load(weights_path, map_location='cpu', weights_only=False)
23
-
24
- # Si guardaste con learn.save, los pesos están en state['model']
25
  if isinstance(state, dict) and 'model' in state:
26
  learn.model.load_state_dict(state['model'])
27
  else:
@@ -34,23 +35,24 @@ def load_model_fastai_style(weights_path):
34
  learn = load_model_fastai_style('checkpoint_1.pth')
35
 
36
  def predict(img):
37
- # Usamos PILImage de fastai para que el preprocesamiento sea IDÉNTICO a Colab
38
  img = PILImage.create(img)
39
-
40
- # IMPORTANTE: Forzamos el resize al tamaño de entrenamiento
41
  img = img.resize((126, 126))
42
 
43
- # Usamos el método predict del learner, que ya sabe normalizar
 
44
  pred, pred_idx, probs = learn.predict(img)
45
 
46
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
47
 
 
48
  demo = gr.Interface(
49
  fn=predict,
50
  inputs=gr.Image(type="pil"),
51
  outputs=gr.Label(num_top_classes=3),
52
- title="Pokemon Type Classifier (Sincronizado)",
53
- description="Inferencia con arquitectura completa de Fastai."
54
  )
55
 
56
  demo.launch()
 
6
  from torchvision import transforms
7
  from fastai.vision.all import *
8
 
9
+ # 1. Categorías (Asegúrate de que el orden sea el mismo que dls.vocab en Colab)
10
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
11
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
12
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
13
 
14
  def load_model_fastai_style(weights_path):
15
+ # En lugar de from_empty, creamos dls mínimos con Datasets vacíos
16
+ # Esto es universal y no depende de métodos de clase que cambian
17
+ empty_ds = Datasets([None], [[]])
18
+ dls = DataLoaders.from_dsets(empty_ds, empty_ds, path='.', bs=1, device='cpu')
19
+ dls.vocab = categories
20
+
21
+ # Construimos el Learner
22
  learn = vision_learner(dls, 'convnext_tiny', pretrained=False)
23
 
24
+ # Cargar pesos con control de errores
 
25
  state = torch.load(weights_path, map_location='cpu', weights_only=False)
 
 
26
  if isinstance(state, dict) and 'model' in state:
27
  learn.model.load_state_dict(state['model'])
28
  else:
 
35
  learn = load_model_fastai_style('checkpoint_1.pth')
36
 
37
  def predict(img):
38
+ # 1. Convertir a PILImage de Fastai
39
  img = PILImage.create(img)
40
+ # 2. Resize manual al tamaño de entrenamiento
 
41
  img = img.resize((126, 126))
42
 
43
+ # 3. Predicción
44
+ # El método predict gestiona internamente la normalización que el modelo espera
45
  pred, pred_idx, probs = learn.predict(img)
46
 
47
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
48
 
49
+ # Interfaz
50
  demo = gr.Interface(
51
  fn=predict,
52
  inputs=gr.Image(type="pil"),
53
  outputs=gr.Label(num_top_classes=3),
54
+ title="Pokemon Type Classifier",
55
+ description="Identifica el tipo principal de tu Pokémon."
56
  )
57
 
58
  demo.launch()