Edupy commited on
Commit
657801f
·
1 Parent(s): 190cb6d

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -9,16 +9,18 @@ categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
  def load_clean_model(weights_path):
12
- # Reconstruimos la arquitectura ConvNeXt Tiny
13
- # num_classes debe coincidir con el número de categorías
14
  model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
15
 
16
- # Creamos un Learner mínimo. Al usar dls.from_empty no hay procesos de carga de datos.
17
- dls = DataLoaders.from_empty(categories)
 
 
 
18
  learn = Learner(dls, model, loss_func=CrossEntropyLossFlat())
19
 
20
  # Cargamos los pesos (.pth)
21
- # Fastai buscará el archivo weights_path + '.pth' (checkpoint_1.pth)
22
  learn.load(weights_path)
23
  return learn
24
 
@@ -26,10 +28,11 @@ def load_clean_model(weights_path):
26
  learn = load_clean_model('checkpoint_1')
27
 
28
  def predict(img):
29
- # PILImage.create prepara la imagen para Fastai
30
  img = PILImage.create(img).resize((126, 126))
31
 
32
- # El método predict de Learner aplicará las transformaciones necesarias
 
33
  pred, pred_idx, probs = learn.predict(img)
34
 
35
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
@@ -40,7 +43,7 @@ demo = gr.Interface(
40
  inputs=gr.Image(type="pil"),
41
  outputs=gr.Label(num_top_classes=3),
42
  title="Pokemon Type Classifier",
43
- description="Modelo ConvNeXt Tiny cargado mediante pesos (.pth)"
44
  )
45
 
46
  demo.launch()
 
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
  def load_clean_model(weights_path):
12
+ # Creamos la arquitectura ConvNeXt Tiny
 
13
  model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
14
 
15
+ # FORMA COMPATIBLE: Creamos DataLoaders vacíos manualmente
16
+ # Esto define que el vocabulario son nuestras categorías
17
+ dls = DataLoaders(Datasets([], [[]]), vocab=categories)
18
+
19
+ # Creamos el Learner
20
  learn = Learner(dls, model, loss_func=CrossEntropyLossFlat())
21
 
22
  # Cargamos los pesos (.pth)
23
+ # Fastai busca el archivo weights_path + '.pth'
24
  learn.load(weights_path)
25
  return learn
26
 
 
28
  learn = load_clean_model('checkpoint_1')
29
 
30
  def predict(img):
31
+ # Preparar la imagen
32
  img = PILImage.create(img).resize((126, 126))
33
 
34
+ # En inferencia pura (sin dblock), usamos el modelo directamente o learn.predict
35
+ # learn.predict aplicará el preprocesamiento necesario
36
  pred, pred_idx, probs = learn.predict(img)
37
 
38
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
 
43
  inputs=gr.Image(type="pil"),
44
  outputs=gr.Label(num_top_classes=3),
45
  title="Pokemon Type Classifier",
46
+ description="Inferencia robusta con Fastai y ConvNeXt Tiny."
47
  )
48
 
49
  demo.launch()