Spaces:
Runtime error
Runtime error
Usando from_pretrained_fastai para carga limpia
Browse files
app.py
CHANGED
|
@@ -9,21 +9,19 @@ categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
|
|
| 9 |
'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
|
| 10 |
|
| 11 |
def load_clean_model(weights_path):
|
| 12 |
-
#
|
| 13 |
-
# Esto
|
| 14 |
-
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
| 18 |
-
dls = DataLoaders.from_empty(categories)
|
| 19 |
-
learn = Learner(dls, body, loss_func=CrossEntropyLossFlat())
|
| 20 |
|
| 21 |
-
# Cargamos los pesos (.pth)
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
return learn
|
| 25 |
|
| 26 |
-
#
|
| 27 |
learn = load_clean_model('checkpoint_1')
|
| 28 |
|
| 29 |
def predict(img):
|
|
@@ -31,13 +29,13 @@ def predict(img):
|
|
| 31 |
pred, pred_idx, probs = learn.predict(img)
|
| 32 |
return {categories[i]: float(probs[i]) for i in range(len(categories))}
|
| 33 |
|
| 34 |
-
# Interfaz
|
| 35 |
demo = gr.Interface(
|
| 36 |
fn=predict,
|
| 37 |
inputs=gr.Image(type="pil"),
|
| 38 |
outputs=gr.Label(num_top_classes=3),
|
| 39 |
-
title="
|
| 40 |
-
description="
|
| 41 |
)
|
| 42 |
|
| 43 |
demo.launch()
|
|
|
|
| 9 |
'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
|
| 10 |
|
| 11 |
def load_clean_model(weights_path):
|
| 12 |
+
# En lugar de from_empty, creamos dls mínimos compatibles
|
| 13 |
+
# Esto define la estructura necesaria para que learn.predict funcione
|
| 14 |
+
dls = ImageDataLoaders.from_lists('.', [], [], vocab=categories, item_tfms=Resize(126))
|
| 15 |
|
| 16 |
+
# Reconstruimos la arquitectura ConvNeXt Tiny
|
| 17 |
+
learn = vision_learner(dls, 'convnext_tiny', pretrained=False)
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
# Cargamos solo los pesos (.pth)
|
| 20 |
+
# Fastai busca el archivo weights_path + '.pth'
|
| 21 |
+
learn.load(weights_path)
|
| 22 |
return learn
|
| 23 |
|
| 24 |
+
# 2. Inicializar (Cargará checkpoint_1.pth)
|
| 25 |
learn = load_clean_model('checkpoint_1')
|
| 26 |
|
| 27 |
def predict(img):
|
|
|
|
| 29 |
pred, pred_idx, probs = learn.predict(img)
|
| 30 |
return {categories[i]: float(probs[i]) for i in range(len(categories))}
|
| 31 |
|
| 32 |
+
# 3. Interfaz de Gradio
|
| 33 |
demo = gr.Interface(
|
| 34 |
fn=predict,
|
| 35 |
inputs=gr.Image(type="pil"),
|
| 36 |
outputs=gr.Label(num_top_classes=3),
|
| 37 |
+
title="Pokemon Type Classifier",
|
| 38 |
+
description="Modelo ConvNeXt Tiny cargado mediante pesos (.pth)"
|
| 39 |
)
|
| 40 |
|
| 41 |
demo.launch()
|