Edupy commited on
Commit
5dfab75
·
1 Parent(s): 057bb80

Forzando sobrescritura de app.py y carga de pesos .pth

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -4,44 +4,45 @@ import gradio as gr
4
  import timm
5
  import torch
6
 
7
- # 1. Lista de categorías exacta
8
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
9
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
10
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
11
 
12
- # 2. Reconstruir el modelo con limpieza de llaves (keys)
13
  def load_pokemon_model(weights_path):
14
- model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
15
- state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
16
-
17
- if 'model' in state_dict:
18
- state_dict = state_dict['model']
19
-
20
- new_state_dict = {}
21
- for k, v in state_dict.items():
22
- name = k.replace('0.model.', '') if k.startswith('0.model.') else k
23
- new_state_dict[name] = v
24
-
25
- model.load_state_dict(new_state_dict, strict=False)
26
- model.eval()
27
- return model
28
 
29
- # 3. Preparación de Inferencia
30
  model = load_pokemon_model('checkpoint_1.pth')
31
 
 
32
  def predict(img):
33
- # Procesamiento manual de la imagen para evitar depender de DataLoaders complejos
34
- img = PILImage.create(img).resize((224, 224))
35
- img_tensor = Pipeline([ToTensor(), IntToFloatTensor()])(img)
36
-
37
- # Añadir dimensión de batch y pasar por el modelo
38
- with torch.no_grad():
39
- output = model(img_tensor.unsqueeze(0))
40
- probs = torch.softmax(output, dim=1)[0]
41
-
42
- return {categories[i]: float(probs[i]) for i in range(len(categories))}
 
 
 
 
 
 
43
 
44
- # 4. Interfaz de Gradio
45
  gr.Interface(
46
  fn=predict,
47
  inputs=gr.Image(),
 
4
  import timm
5
  import torch
6
 
7
+ # 1. Lista de categorías
8
  categories = ['Bug', 'Dark', 'Dragon', 'Electric', 'Fairy', 'Fighting',
9
  'Fire', 'Flying', 'Ghost', 'Grass', 'Ground', 'Ice',
10
  'Normal', 'Poison', 'Psychic', 'Rock', 'Steel', 'Water']
11
 
12
+ # 2. Carga del modelo
13
  def load_pokemon_model(weights_path):
14
+ try:
15
+ model = timm.create_model('convnext_tiny', pretrained=False, num_classes=len(categories))
16
+ state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
17
+ if 'model' in state_dict: state_dict = state_dict['model']
18
+ new_state_dict = {k.replace('0.model.', ''): v for k, v in state_dict.items()}
19
+ model.load_state_dict(new_state_dict, strict=False)
20
+ model.eval()
21
+ return model
22
+ except Exception as e:
23
+ return f"Error cargando pesos: {str(e)}"
 
 
 
 
24
 
 
25
  model = load_pokemon_model('checkpoint_1.pth')
26
 
27
+ # 3. Función de predicción con captura de errores
28
  def predict(img):
29
+ try:
30
+ if isinstance(model, str): return {"Error de carga": model} # Si falló el paso 2
31
+
32
+ img = PILImage.create(img).resize((224, 224))
33
+
34
+ # Transformación manual a tensor
35
+ img_tensor = cast(ToTensor()(img), TensorImage)
36
+ img_tensor = IntToFloatTensor()(img_tensor)
37
+
38
+ with torch.no_grad():
39
+ output = model(img_tensor.unsqueeze(0))
40
+ probs = torch.softmax(output, dim=1)[0]
41
+
42
+ return {categories[i]: float(probs[i]) for i in range(len(categories))}
43
+ except Exception as e:
44
+ return {"Error en predicción": str(e)}
45
 
 
46
  gr.Interface(
47
  fn=predict,
48
  inputs=gr.Image(),