Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- aimodel.py +15 -6
- requirements.txt +2 -1
aimodel.py
CHANGED
|
@@ -6,6 +6,8 @@ import uuid
|
|
| 6 |
import re
|
| 7 |
from fastai.vision.all import *
|
| 8 |
from pathlib import Path
|
|
|
|
|
|
|
| 9 |
import random
|
| 10 |
import shutil
|
| 11 |
|
|
@@ -94,16 +96,23 @@ def train_model():
|
|
| 94 |
batch_tfms=Normalize.from_stats(*imagenet_stats)
|
| 95 |
)
|
| 96 |
dls = dblock.dataloaders(dataset_path)
|
| 97 |
-
|
|
|
|
| 98 |
|
| 99 |
if os.path.exists(utils.MODEL_PATH):
|
| 100 |
try:
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
except Exception as e:
|
| 104 |
-
logging.error(f"Error loading
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
|
| 108 |
learn.fine_tune(5)
|
| 109 |
learn.export(utils.MODEL_PATH)
|
|
|
|
| 6 |
import re
|
| 7 |
from fastai.vision.all import *
|
| 8 |
from pathlib import Path
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
import random
|
| 12 |
import shutil
|
| 13 |
|
|
|
|
| 96 |
batch_tfms=Normalize.from_stats(*imagenet_stats)
|
| 97 |
)
|
| 98 |
dls = dblock.dataloaders(dataset_path)
|
| 99 |
+
num_classes = len(dls.vocab)
|
| 100 |
+
learn = vision_learner(dls, resnet18, metrics=accuracy, pretrained=True) # Começar com pesos pré-treinados
|
| 101 |
|
| 102 |
if os.path.exists(utils.MODEL_PATH):
|
| 103 |
try:
|
| 104 |
+
# Carregar apenas os pesos do backbone (todas as camadas exceto a última linear)
|
| 105 |
+
state = torch.load(utils.MODEL_PATH)
|
| 106 |
+
# Remover a chave da camada de classificação (geralmente '1.fc' ou similar)
|
| 107 |
+
new_state = {k: v for k, v in state['model'].items() if not k.endswith('.weight') and not k.endswith('.bias')}
|
| 108 |
+
learn.model.load_state_dict(new_state, strict=False)
|
| 109 |
+
logging.info("Backbone weights loaded.")
|
| 110 |
except Exception as e:
|
| 111 |
+
logging.error(f"Error loading backbone weights: {e}. Training from scratch.")
|
| 112 |
+
|
| 113 |
+
# Substituir a camada de classificação
|
| 114 |
+
num_ftrs = learn.model.fc.in_features
|
| 115 |
+
learn.model.fc = nn.Linear(num_ftrs, num_classes)
|
| 116 |
|
| 117 |
learn.fine_tune(5)
|
| 118 |
learn.export(utils.MODEL_PATH)
|
requirements.txt
CHANGED
|
@@ -5,4 +5,5 @@ requests
|
|
| 5 |
numpy
|
| 6 |
opencv-python-headless
|
| 7 |
python-multipart
|
| 8 |
-
fastai
|
|
|
|
|
|
| 5 |
numpy
|
| 6 |
opencv-python-headless
|
| 7 |
python-multipart
|
| 8 |
+
fastai
|
| 9 |
+
torch
|