MarioPrzBasto commited on
Commit
cb804a9
·
verified ·
1 Parent(s): 91ae6f3

Upload 2 files

Browse files
Files changed (2) hide show
  1. aimodel.py +15 -6
  2. requirements.txt +2 -1
aimodel.py CHANGED
@@ -6,6 +6,8 @@ import uuid
6
  import re
7
  from fastai.vision.all import *
8
  from pathlib import Path
 
 
9
  import random
10
  import shutil
11
 
@@ -94,16 +96,23 @@ def train_model():
94
  batch_tfms=Normalize.from_stats(*imagenet_stats)
95
  )
96
  dls = dblock.dataloaders(dataset_path)
97
- learn = vision_learner(dls, resnet18, metrics=accuracy) # Criar um novo Learner
 
98
 
99
  if os.path.exists(utils.MODEL_PATH):
100
  try:
101
- learn.load(utils.MODEL_PATH.replace(".pkl", "")) # Carregar apenas os pesos (sem o otimizador e outros estados)
102
- logging.info(f"Existing model weights loaded from {utils.MODEL_PATH}")
 
 
 
 
103
  except Exception as e:
104
- logging.error(f"Error loading existing model weights: {e}. Training from scratch.")
105
- else:
106
- logging.info("No existing model found. Training from scratch.")
 
 
107
 
108
  learn.fine_tune(5)
109
  learn.export(utils.MODEL_PATH)
 
6
  import re
7
  from fastai.vision.all import *
8
  from pathlib import Path
9
+ import torch
10
+ import torch.nn as nn
11
  import random
12
  import shutil
13
 
 
96
  batch_tfms=Normalize.from_stats(*imagenet_stats)
97
  )
98
  dls = dblock.dataloaders(dataset_path)
99
+ num_classes = len(dls.vocab)
100
+ learn = vision_learner(dls, resnet18, metrics=accuracy, pretrained=True) # Começar com pesos pré-treinados
101
 
102
  if os.path.exists(utils.MODEL_PATH):
103
  try:
104
+ # Carregar apenas os pesos do backbone (todas as camadas exceto a última linear)
105
+ state = torch.load(utils.MODEL_PATH)
106
+ # Remover a chave da camada de classificação (geralmente '1.fc' ou similar)
107
+ new_state = {k: v for k, v in state['model'].items() if not k.endswith('.weight') and not k.endswith('.bias')}
108
+ learn.model.load_state_dict(new_state, strict=False)
109
+ logging.info("Backbone weights loaded.")
110
  except Exception as e:
111
+ logging.error(f"Error loading backbone weights: {e}. Training from scratch.")
112
+
113
+ # Substituir a camada de classificação
114
+ num_ftrs = learn.model.fc.in_features
115
+ learn.model.fc = nn.Linear(num_ftrs, num_classes)
116
 
117
  learn.fine_tune(5)
118
  learn.export(utils.MODEL_PATH)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ requests
5
  numpy
6
  opencv-python-headless
7
  python-multipart
8
- fastai
 
 
5
  numpy
6
  opencv-python-headless
7
  python-multipart
8
+ fastai
9
+ torch