Edupy commited on
Commit
7cdd69c
·
1 Parent(s): 919ec96

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -9,21 +9,19 @@ categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
  def load_clean_model(weights_path):
12
- # Creamos la arquitectura ConvNeXt Tiny directamente desde timm
13
- # Esto no busca ni get_x, ni get_y, ni archivos dummy
14
- body = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
15
 
16
- # Creamos un Learner "vacío" solo para usar su método predict
17
- # DataLoaders.from_empty es la clave: no busca archivos en el disco
18
- dls = DataLoaders.from_empty(categories)
19
- learn = Learner(dls, body, loss_func=CrossEntropyLossFlat())
20
 
21
- # Cargamos los pesos (.pth)
22
- # Si tu archivo se llama checkpoint_1.pth, aquí ponemos 'checkpoint_1'
23
- load_model(weights_path, learn.model, learn.opt)
24
  return learn
25
 
26
- # Cargamos el modelo (asegúrate de que checkpoint_1.pth esté en el repo)
27
  learn = load_clean_model('checkpoint_1')
28
 
29
  def predict(img):
@@ -31,13 +29,13 @@ def predict(img):
31
  pred, pred_idx, probs = learn.predict(img)
32
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
33
 
34
- # Interfaz
35
  demo = gr.Interface(
36
  fn=predict,
37
  inputs=gr.Image(type="pil"),
38
  outputs=gr.Label(num_top_classes=3),
39
- title="Detector Pokémon (Sin Dependencias)",
40
- description="Carga directa de arquitectura y pesos."
41
  )
42
 
43
  demo.launch()
 
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
  def load_clean_model(weights_path):
12
+ # En lugar de from_empty, creamos dls mínimos compatibles
13
+ # Esto define la estructura necesaria para que learn.predict funcione
14
+ dls = ImageDataLoaders.from_lists('.', [], [], vocab=categories, item_tfms=Resize(126))
15
 
16
+ # Reconstruimos la arquitectura ConvNeXt Tiny
17
+ learn = vision_learner(dls, 'convnext_tiny', pretrained=False)
 
 
18
 
19
+ # Cargamos solo los pesos (.pth)
20
+ # Fastai busca el archivo weights_path + '.pth'
21
+ learn.load(weights_path)
22
  return learn
23
 
24
+ # 2. Inicializar (Cargará checkpoint_1.pth)
25
  learn = load_clean_model('checkpoint_1')
26
 
27
  def predict(img):
 
29
  pred, pred_idx, probs = learn.predict(img)
30
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
31
 
32
+ # 3. Interfaz de Gradio
33
  demo = gr.Interface(
34
  fn=predict,
35
  inputs=gr.Image(type="pil"),
36
  outputs=gr.Label(num_top_classes=3),
37
+ title="Pokemon Type Classifier",
38
+ description="Modelo ConvNeXt Tiny cargado mediante pesos (.pth)"
39
  )
40
 
41
  demo.launch()