Edupy commited on
Commit
c0926f4
·
1 Parent(s): 657801f

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +38 -23
app.py CHANGED
@@ -1,49 +1,64 @@
1
 
2
  import gradio as gr
3
- from fastai.vision.all import *
 
4
  import timm
 
 
5
 
6
- # 1. Lista de categorías exacta
7
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
8
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
9
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
 
11
- def load_clean_model(weights_path):
12
- # Creamos la arquitectura ConvNeXt Tiny
 
13
  model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
14
 
15
- # FORMA COMPATIBLE: Creamos DataLoaders vacíos manualmente
16
- # Esto define que el vocabulario son nuestras categorías
17
- dls = DataLoaders(Datasets([], [[]]), vocab=categories)
18
-
19
- # Creamos el Learner
20
- learn = Learner(dls, model, loss_func=CrossEntropyLossFlat())
21
 
22
- # Cargamos los pesos (.pth)
23
- # Fastai busca el archivo weights_path + '.pth'
24
- learn.load(weights_path)
25
- return learn
 
 
 
 
26
 
27
- # 2. Inicializar el modelo
28
- learn = load_clean_model('checkpoint_1')
29
 
 
30
  def predict(img):
31
- # Preparar la imagen
32
- img = PILImage.create(img).resize((126, 126))
 
 
 
 
 
 
 
 
33
 
34
- # En inferencia pura (sin dblock), usamos el modelo directamente o learn.predict
35
- # learn.predict aplicará el preprocesamiento necesario
36
- pred, pred_idx, probs = learn.predict(img)
37
 
 
38
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
39
 
40
- # 3. Interfaz de Gradio
41
  demo = gr.Interface(
42
  fn=predict,
43
  inputs=gr.Image(type="pil"),
44
  outputs=gr.Label(num_top_classes=3),
45
  title="Pokemon Type Classifier",
46
- description="Inferencia robusta con Fastai y ConvNeXt Tiny."
47
  )
48
 
49
  demo.launch()
 
1
 
2
  import gradio as gr
3
+ import gradio as gr
4
+ import torch
5
  import timm
6
+ from PIL import Image
7
+ from torchvision import transforms
8
 
9
+ # 1. Lista de categorías (Asegúrate de que el orden sea el mismo que en tu vocab de Colab)
10
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
11
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
12
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
13
 
14
+ # 2. Carga del modelo directamente en PyTorch
15
+ def load_model(weights_path):
16
+ # Creamos la arquitectura exacta
17
  model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
18
 
19
+ # Cargamos los pesos del archivo .pth
20
+ # weights_only=False es necesario para cargar el estado guardado por fastai
21
+ state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
 
 
 
22
 
23
+ # Si los pesos vienen de un Learner de fastai, suelen estar dentro de una llave 'model'
24
+ if 'model' in state_dict:
25
+ state_dict = state_dict['model']
26
+
27
+ # Cargamos los pesos en la arquitectura
28
+ model.load_state_dict(state_dict, strict=False)
29
+ model.eval() # Modo evaluación
30
+ return model
31
 
32
+ # Cargamos el modelo (el archivo debe estar en el repo como checkpoint_1.pth)
33
+ model = load_model('checkpoint_1.pth')
34
 
35
+ # 3. Función de predicción manual
36
  def predict(img):
37
+ # Preprocesamiento: Resize(126) + Convertir a Tensor + Normalización ImageNet
38
+ # Esto es exactamente lo que hace fastai por debajo
39
+ transform = transforms.Compose([
40
+ transforms.Resize((126, 126)),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
43
+ ])
44
+
45
+ # Convertimos la imagen de Gradio a Tensor
46
+ img_tensor = transform(img).unsqueeze(0) # Añadir dimensión de batch
47
 
48
+ with torch.no_grad():
49
+ outputs = model(img_tensor)
50
+ probs = torch.nn.functional.softmax(outputs[0], dim=0)
51
 
52
+ # Retornamos diccionario con las probabilidades
53
  return {categories[i]: float(probs[i]) for i in range(len(categories))}
54
 
55
+ # 4. Interfaz de Gradio
56
  demo = gr.Interface(
57
  fn=predict,
58
  inputs=gr.Image(type="pil"),
59
  outputs=gr.Label(num_top_classes=3),
60
  title="Pokemon Type Classifier",
61
+ description="Inferencia directa usando PyTorch y ConvNeXt Tiny."
62
  )
63
 
64
  demo.launch()