Inoue1 commited on
Commit
4418807
·
verified ·
1 Parent(s): e6eb87f

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +87 -0
model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from PIL import Image
7
+ from torchvision import models, transforms
8
+
9
+ BASE_DIR = os.path.dirname(__file__)
10
+
11
+ MODELS_DIR = os.path.join(BASE_DIR, "models")
12
+ LABELS_DIR = os.path.join(BASE_DIR, "labels")
13
+
14
+ # ================= IMAGE PREPROCESS =================
15
+
16
+ def preprocess_pytorch(img, size=224):
17
+ transform = transforms.Compose([
18
+ transforms.Resize((size, size)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(
21
+ mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225]
23
+ )
24
+ ])
25
+ return transform(img).unsqueeze(0)
26
+
27
+ def preprocess_keras(img, size=224):
28
+ img = img.resize((size, size))
29
+ arr = np.array(img) / 255.0
30
+ return np.expand_dims(arr, axis=0)
31
+
32
+ # ================= MODEL LOADERS =================
33
+
34
+ PYTORCH_MODELS = {}
35
+ KERAS_MODELS = {}
36
+ LABELS = {}
37
+
38
+ def load_models():
39
+ for file in os.listdir(MODELS_DIR):
40
+ name, ext = os.path.splitext(file)
41
+ model_path = os.path.join(MODELS_DIR, file)
42
+
43
+ # Load labels
44
+ with open(os.path.join(LABELS_DIR, f"{name}.json")) as f:
45
+ LABELS[name] = json.load(f)
46
+
47
+ if ext == ".pth":
48
+ num_classes = len(LABELS[name])
49
+ model = models.resnet18(weights=None)
50
+ model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
51
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
52
+ model.eval()
53
+ PYTORCH_MODELS[name] = model
54
+
55
+ elif ext == ".keras":
56
+ KERAS_MODELS[name] = tf.keras.models.load_model(model_path)
57
+
58
+ # Load once
59
+ load_models()
60
+
61
+ # ================= PREDICT =================
62
+
63
+ def predict(image, crop_name):
64
+ crop_name = crop_name.lower()
65
+
66
+ if crop_name in PYTORCH_MODELS:
67
+ model = PYTORCH_MODELS[crop_name]
68
+ labels = LABELS[crop_name]
69
+
70
+ tensor = preprocess_pytorch(image)
71
+ with torch.no_grad():
72
+ output = model(tensor)
73
+ probs = torch.softmax(output[0], dim=0)
74
+ idx = probs.argmax().item()
75
+ return labels[idx], float(probs[idx])
76
+
77
+ elif crop_name in KERAS_MODELS:
78
+ model = KERAS_MODELS[crop_name]
79
+ labels = LABELS[crop_name]
80
+
81
+ arr = preprocess_keras(image)
82
+ preds = model.predict(arr)[0]
83
+ idx = np.argmax(preds)
84
+ return labels[idx], float(preds[idx])
85
+
86
+ else:
87
+ raise ValueError(f"No model found for crop: {crop_name}")