Edupy commited on
Commit
190cb6d
·
1 Parent(s): 5c9db49

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -9,32 +9,29 @@ categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
  def load_clean_model(weights_path):
12
- # Creamos un DataBlock manualmente.
13
- # Esto define la estructura sin intentar leer ninguna lista.
14
- dblock = DataBlock(
15
- blocks=(ImageBlock, CategoryBlock(vocab=categories)),
16
- item_tfms=Resize(126)
17
- )
18
-
19
- # Creamos un DataLoader vacío que solo sirve para inicializar el Learner
20
- # Usamos dls de un objeto dummy que no requiere archivos reales
21
- dls = dblock.dataloaders([PILImage.create(np.zeros((126,126,3), dtype=np.uint8))],
22
- [categories[0]], bs=1)
23
-
24
  # Reconstruimos la arquitectura ConvNeXt Tiny
25
- learn = vision_learner(dls, 'convnext_tiny', pretrained=False)
 
 
 
 
 
26
 
27
  # Cargamos los pesos (.pth)
28
- # Fastai busca el archivo weights_path + '.pth'
29
  learn.load(weights_path)
30
  return learn
31
 
32
- # 2. Inicializar (Cargará checkpoint_1.pth)
33
  learn = load_clean_model('checkpoint_1')
34
 
35
  def predict(img):
 
36
  img = PILImage.create(img).resize((126, 126))
 
 
37
  pred, pred_idx, probs = learn.predict(img)
 
38
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
39
 
40
  # 3. Interfaz de Gradio
 
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
  def load_clean_model(weights_path):
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Reconstruimos la arquitectura ConvNeXt Tiny
13
+ # num_classes debe coincidir con el número de categorías
14
+ model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
15
+
16
+ # Creamos un Learner mínimo. Al usar dls.from_empty no hay procesos de carga de datos.
17
+ dls = DataLoaders.from_empty(categories)
18
+ learn = Learner(dls, model, loss_func=CrossEntropyLossFlat())
19
 
20
  # Cargamos los pesos (.pth)
21
+ # Fastai buscará el archivo weights_path + '.pth' (checkpoint_1.pth)
22
  learn.load(weights_path)
23
  return learn
24
 
25
+ # 2. Inicializar el modelo
26
  learn = load_clean_model('checkpoint_1')
27
 
28
  def predict(img):
29
+ # PILImage.create prepara la imagen para Fastai
30
  img = PILImage.create(img).resize((126, 126))
31
+
32
+ # El método predict de Learner aplicará las transformaciones necesarias
33
  pred, pred_idx, probs = learn.predict(img)
34
+
35
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
36
 
37
  # 3. Interfaz de Gradio