Edupy commited on
Commit
919ec96
·
1 Parent(s): 593614a

Usando from_pretrained_fastai para carga limpia

Browse files
Files changed (1) hide show
  1. app.py +30 -33
app.py CHANGED
@@ -1,46 +1,43 @@
1
 
2
  import gradio as gr
3
  from fastai.vision.all import *
4
- from huggingface_hub import from_pretrained_fastai
5
- import torch, os
6
- import __main__
7
-
8
- # 1. REDEFINIMOS LAS FUNCIONES AQUÍ MISMO
9
- # Estas son las funciones que el modelo "recuerda" de tu entrenamiento
10
- def get_x(i): return None
11
- def get_y(i): return None
12
-
13
- # 2. LAS ASIGNAMOS AL MÓDULO PRINCIPAL
14
- # Esto es lo que permite que load_learner las encuentre
15
- __main__.get_x = get_x
16
- __main__.get_y = get_y
17
-
18
- # Configuraciones de rendimiento para el servidor
19
- os.environ.setdefault("OMP_NUM_THREADS", "1")
20
- torch.set_num_threads(1)
21
-
22
- # 3. CARGAMOS EL MODELO
23
- # Ahora from_pretrained_fastai no dará el error "res"
24
- learn = from_pretrained_fastai("Edupy/pokemon-1class-classifier-26")
25
-
26
- # Pasar a FP32 por si el modelo se guardó en semi-precisión (ahorra errores en CPU)
27
- try: learn.to_fp32()
28
- except: pass
29
-
30
- labels = learn.dls.vocab
31
 
32
  def predict(img):
33
- img = PILImage.create(img)
34
  pred, pred_idx, probs = learn.predict(img)
35
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
36
 
37
- # 4. INTERFAZ DE GRADIO
38
  demo = gr.Interface(
39
  fn=predict,
40
  inputs=gr.Image(type="pil"),
41
  outputs=gr.Label(num_top_classes=3),
42
- title="Detector de Tipos Pokémon",
43
- description="Sube una imagen de un Pokémon para predecir su tipo principal."
44
  )
45
 
46
- demo.queue(max_size=8).launch(show_error=True, debug=True)
 
1
 
2
  import gradio as gr
3
  from fastai.vision.all import *
4
+ import timm
5
+
6
+ # 1. Lista de categorías exacta
7
+ categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
8
+ 'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
9
+ 'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
10
+
11
+ def load_clean_model(weights_path):
12
+ # Creamos la arquitectura ConvNeXt Tiny directamente desde timm
13
+ # Esto no busca ni get_x, ni get_y, ni archivos dummy
14
+ body = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
15
+
16
+ # Creamos un Learner "vacío" solo para usar su método predict
17
+ # DataLoaders.from_empty es la clave: no busca archivos en el disco
18
+ dls = DataLoaders.from_empty(categories)
19
+ learn = Learner(dls, body, loss_func=CrossEntropyLossFlat())
20
+
21
+ # Cargamos los pesos (.pth)
22
+ # Si tu archivo se llama checkpoint_1.pth, aquí ponemos 'checkpoint_1'
23
+ load_model(weights_path, learn.model, learn.opt)
24
+ return learn
25
+
26
+ # Cargamos el modelo (asegúrate de que checkpoint_1.pth esté en el repo)
27
+ learn = load_clean_model('checkpoint_1')
 
 
 
28
 
29
  def predict(img):
30
+ img = PILImage.create(img).resize((126, 126))
31
  pred, pred_idx, probs = learn.predict(img)
32
+ return {categories[i]: float(probs[i]) for i in range(len(categories))}
33
 
34
+ # Interfaz
35
  demo = gr.Interface(
36
  fn=predict,
37
  inputs=gr.Image(type="pil"),
38
  outputs=gr.Label(num_top_classes=3),
39
+ title="Detector Pokémon (Sin Dependencias)",
40
+ description="Carga directa de arquitectura y pesos."
41
  )
42
 
43
+ demo.launch()