Edupy commited on
Commit
a739378
·
1 Parent(s): c0926f4

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +27 -35
app.py CHANGED
@@ -1,64 +1,56 @@
1
 
2
- import gradio as gr
3
  import gradio as gr
4
  import torch
5
  import timm
6
  from PIL import Image
7
  from torchvision import transforms
 
8
 
9
- # 1. Lista de categorías (Asegúrate de que el orden sea el mismo que en tu vocab de Colab)
10
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
11
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
12
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
13
 
14
- # 2. Carga del modelo directamente en PyTorch
15
- def load_model(weights_path):
16
- # Creamos la arquitectura exacta
17
- model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
 
18
 
19
- # Cargamos los pesos del archivo .pth
20
- # weights_only=False es necesario para cargar el estado guardado por fastai
21
- state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
22
 
23
- # Si los pesos vienen de un Learner de fastai, suelen estar dentro de una llave 'model'
24
- if 'model' in state_dict:
25
- state_dict = state_dict['model']
 
 
26
 
27
- # Cargamos los pesos en la arquitectura
28
- model.load_state_dict(state_dict, strict=False)
29
- model.eval() # Modo evaluación
30
- return model
31
 
32
- # Cargamos el modelo (el archivo debe estar en el repo como checkpoint_1.pth)
33
- model = load_model('checkpoint_1.pth')
34
 
35
- # 3. Función de predicción manual
36
  def predict(img):
37
- # Preprocesamiento: Resize(126) + Convertir a Tensor + Normalización ImageNet
38
- # Esto es exactamente lo que hace fastai por debajo
39
- transform = transforms.Compose([
40
- transforms.Resize((126, 126)),
41
- transforms.ToTensor(),
42
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
43
- ])
44
 
45
- # Convertimos la imagen de Gradio a Tensor
46
- img_tensor = transform(img).unsqueeze(0) # Añadir dimensión de batch
47
 
48
- with torch.no_grad():
49
- outputs = model(img_tensor)
50
- probs = torch.nn.functional.softmax(outputs[0], dim=0)
51
 
52
- # Retornamos diccionario con las probabilidades
53
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
54
 
55
- # 4. Interfaz de Gradio
56
  demo = gr.Interface(
57
  fn=predict,
58
  inputs=gr.Image(type="pil"),
59
  outputs=gr.Label(num_top_classes=3),
60
- title="Pokemon Type Classifier",
61
- description="Inferencia directa usando PyTorch y ConvNeXt Tiny."
62
  )
63
 
64
  demo.launch()
 
1
 
 
2
  import gradio as gr
3
  import torch
4
  import timm
5
  from PIL import Image
6
  from torchvision import transforms
7
+ from fastai.vision.all import *
8
 
9
+ # 1. Categorías (Asegúrate de que el orden sea ALFABÉTICO, que es el defecto de Fastai)
10
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
11
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
12
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
13
 
14
+ def load_model_fastai_style(weights_path):
15
+ # Creamos un Learner vacío para que Fastai construya la arquitectura COMPLETA
16
+ # (Cuerpo de ConvNeXt + Cabeza de clasificación de Fastai)
17
+ dls = DataLoaders.from_empty(categories)
18
+ learn = vision_learner(dls, 'convnext_tiny', pretrained=False)
19
 
20
+ # Cargamos el archivo .pth
21
+ # Usamos torch.load directamente para mayor control
22
+ state = torch.load(weights_path, map_location='cpu', weights_only=False)
23
 
24
+ # Si guardaste con learn.save, los pesos están en state['model']
25
+ if isinstance(state, dict) and 'model' in state:
26
+ learn.model.load_state_dict(state['model'])
27
+ else:
28
+ learn.model.load_state_dict(state)
29
 
30
+ learn.model.eval()
31
+ return learn
 
 
32
 
33
+ # Cargamos el modelo
34
+ learn = load_model_fastai_style('checkpoint_1.pth')
35
 
 
36
  def predict(img):
37
+ # Usamos PILImage de fastai para que el preprocesamiento sea IDÉNTICO a Colab
38
+ img = PILImage.create(img)
 
 
 
 
 
39
 
40
+ # IMPORTANTE: Forzamos el resize al tamaño de entrenamiento
41
+ img = img.resize((126, 126))
42
 
43
+ # Usamos el método predict del learner, que ya sabe normalizar
44
+ pred, pred_idx, probs = learn.predict(img)
 
45
 
 
46
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
47
 
 
48
  demo = gr.Interface(
49
  fn=predict,
50
  inputs=gr.Image(type="pil"),
51
  outputs=gr.Label(num_top_classes=3),
52
+ title="Pokemon Type Classifier (Sincronizado)",
53
+ description="Inferencia con arquitectura completa de Fastai."
54
  )
55
 
56
  demo.launch()