SakibAhmed commited on
Commit
3e9ac54
·
verified ·
1 Parent(s): 1405aab

Upload 2 files

Browse files
Files changed (2) hide show
  1. .env +5 -1
  2. app.py +82 -50
.env CHANGED
@@ -1 +1,5 @@
1
- MODEL_NAME=best_new_EP382.pt
 
 
 
 
 
1
+ # Name of the first model (e.g., your original classifier)
2
+ MODEL_1_NAME=best_88E.pt
3
+
4
+ # Name of the second model (for Tyre/Alloy classification)
5
+ MODEL_2_NAME=best_TA_377EP.pt
app.py CHANGED
@@ -18,41 +18,73 @@ CORS(app)
18
 
19
  # --- Configuration ---
20
  UPLOAD_FOLDER = 'static/uploads'
21
- MODELS_FOLDER = 'models' # New folder for models
22
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
23
 
24
- # Load model name from .env file, with a fallback default
25
- MODEL_NAME = os.getenv('MODEL_NAME', 'best.pt')
26
- MODEL_PATH = os.path.join(MODELS_FOLDER, MODEL_NAME)
 
 
 
27
 
28
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
29
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
30
- os.makedirs(MODELS_FOLDER, exist_ok=True) # Ensure models folder exists
31
- os.makedirs('templates', exist_ok=True) # Ensure templates folder exists
32
 
33
- # --- Determine Device and Load YOLO Model ---
34
- # Use CUDA if available, otherwise use CPU
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
  print(f"Using device: {device}")
37
 
38
- # Load the model once when the application starts for efficiency.
39
- model = None
 
 
40
  try:
41
- if not os.path.exists(MODEL_PATH):
42
- print(f"Error: Model file not found at {MODEL_PATH}")
43
- print("Please make sure the model file exists and the MODEL_NAME in your .env file is correct.")
44
  else:
45
- model = YOLO(MODEL_PATH)
46
- model.to(device) # Move model to the selected device
47
- print(f"Successfully loaded model '{MODEL_NAME}' on {device}.")
48
  except Exception as e:
49
- print(f"Error loading YOLO model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def allowed_file(filename):
52
  """Checks if a file's extension is in the ALLOWED_EXTENSIONS set."""
53
  return '.' in filename and \
54
  filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @app.route('/')
57
  def home():
58
  """Serve the main HTML page."""
@@ -61,59 +93,59 @@ def home():
61
  @app.route('/predict', methods=['POST'])
62
  def predict():
63
  """
64
- Endpoint to receive an image, run YOLO classification, and return the single best prediction.
65
  """
66
- if model is None:
67
- return jsonify({"error": "Model could not be loaded. Please check server logs."}), 500
68
-
69
  # 1. --- File Validation ---
70
  if 'file' not in request.files:
71
  return jsonify({"error": "No file part in the request"}), 400
72
-
73
  file = request.files['file']
74
  if file.filename == '':
75
  return jsonify({"error": "No selected file"}), 400
76
-
77
  if not file or not allowed_file(file.filename):
78
  return jsonify({"error": "File type not allowed"}), 400
79
 
 
 
 
80
  # 2. --- Save the File Temporarily ---
81
  filename = secure_filename(file.filename)
82
  filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
83
  file.save(filepath)
84
 
85
- # 3. --- Perform Inference ---
86
  try:
87
- # Run the YOLO model on the uploaded image. The model is already on the correct device.
88
- results = model(filepath)
89
-
90
- # 4. --- Process Results to Get ONLY the Top Prediction ---
91
- # Get the first result object from the list
92
- result = results[0]
93
-
94
- # Access the probabilities object
95
- probs = result.probs
96
-
97
- # Get the index and confidence of the top prediction
98
- top1_index = probs.top1
99
- top1_confidence = float(probs.top1conf) # Convert tensor to Python float
100
-
101
- # Get the class name from the model's 'names' dictionary
102
- class_name = model.names[top1_index]
103
-
104
- # Create the final prediction object
105
- prediction = {
106
- "class": class_name,
107
- "confidence": top1_confidence
108
- }
109
-
110
- # Return the single prediction object as JSON
111
- return jsonify(prediction)
 
 
112
 
113
  except Exception as e:
114
  return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
115
  finally:
116
- # 5. --- Cleanup ---
117
  if os.path.exists(filepath):
118
  os.remove(filepath)
119
 
 
18
 
19
  # --- Configuration ---
20
  UPLOAD_FOLDER = 'static/uploads'
21
+ MODELS_FOLDER = 'models'
22
  ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
23
 
24
+ # --- NEW: Load model names from .env file, with fallback defaults ---
25
+ MODEL_1_NAME = os.getenv('MODEL_1_NAME', 'best.pt')
26
+ MODEL_2_NAME = os.getenv('MODEL_2_NAME', 'tyre_alloy.pt') # New model for Tyre/Alloy
27
+
28
+ MODEL_1_PATH = os.path.join(MODELS_FOLDER, MODEL_1_NAME)
29
+ MODEL_2_PATH = os.path.join(MODELS_FOLDER, MODEL_2_NAME)
30
 
31
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
32
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
33
+ os.makedirs(MODELS_FOLDER, exist_ok=True)
34
+ os.makedirs('templates', exist_ok=True)
35
 
36
+ # --- Determine Device ---
 
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
38
  print(f"Using device: {device}")
39
 
40
+ # --- NEW: Load multiple YOLO Models ---
41
+ model1, model2 = None, None
42
+
43
+ # Load Model 1
44
  try:
45
+ if not os.path.exists(MODEL_1_PATH):
46
+ print(f"Warning: Model file not found at {MODEL_1_PATH}")
 
47
  else:
48
+ model1 = YOLO(MODEL_1_PATH)
49
+ model1.to(device)
50
+ print(f"Successfully loaded model '{MODEL_1_NAME}' on {device}.")
51
  except Exception as e:
52
+ print(f"Error loading Model 1 ({MODEL_1_NAME}): {e}")
53
+
54
+ # Load Model 2
55
+ try:
56
+ if not os.path.exists(MODEL_2_PATH):
57
+ print(f"Warning: Model file not found at {MODEL_2_PATH}")
58
+ else:
59
+ model2 = YOLO(MODEL_2_PATH)
60
+ model2.to(device)
61
+ print(f"Successfully loaded model '{MODEL_2_NAME}' on {device}.")
62
+ except Exception as e:
63
+ print(f"Error loading Model 2 ({MODEL_2_NAME}): {e}")
64
+
65
 
66
  def allowed_file(filename):
67
  """Checks if a file's extension is in the ALLOWED_EXTENSIONS set."""
68
  return '.' in filename and \
69
  filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
70
 
71
+ def run_inference(model, filepath):
72
+ """Helper function to run inference and format the result."""
73
+ if model is None:
74
+ return None # Return None if the model isn't loaded
75
+
76
+ results = model(filepath)
77
+ result = results[0]
78
+ probs = result.probs
79
+ top1_index = probs.top1
80
+ top1_confidence = float(probs.top1conf)
81
+ class_name = model.names[top1_index]
82
+
83
+ return {
84
+ "class": class_name,
85
+ "confidence": top1_confidence
86
+ }
87
+
88
  @app.route('/')
89
  def home():
90
  """Serve the main HTML page."""
 
93
  @app.route('/predict', methods=['POST'])
94
  def predict():
95
  """
96
+ Endpoint to receive an image and run classification based on the requested model type.
97
  """
 
 
 
98
  # 1. --- File Validation ---
99
  if 'file' not in request.files:
100
  return jsonify({"error": "No file part in the request"}), 400
 
101
  file = request.files['file']
102
  if file.filename == '':
103
  return jsonify({"error": "No selected file"}), 400
 
104
  if not file or not allowed_file(file.filename):
105
  return jsonify({"error": "File type not allowed"}), 400
106
 
107
+ # --- NEW: Get the model type from the form data ---
108
+ model_type = request.form.get('model_type', 'model1') # default to model1
109
+
110
  # 2. --- Save the File Temporarily ---
111
  filename = secure_filename(file.filename)
112
  filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
113
  file.save(filepath)
114
 
115
+ # 3. --- Perform Inference based on model_type ---
116
  try:
117
+ if model_type == 'model1':
118
+ if model1 is None:
119
+ return jsonify({"error": f"Model '{MODEL_1_NAME}' is not loaded. Check server logs."}), 500
120
+ prediction = run_inference(model1, filepath)
121
+ return jsonify(prediction)
122
+
123
+ elif model_type == 'model2':
124
+ if model2 is None:
125
+ return jsonify({"error": f"Model '{MODEL_2_NAME}' is not loaded. Check server logs."}), 500
126
+ prediction = run_inference(model2, filepath)
127
+ return jsonify(prediction)
128
+
129
+ elif model_type == 'combined':
130
+ if model1 is None or model2 is None:
131
+ return jsonify({"error": "One or more models required for combined mode are not loaded. Check server logs."}), 500
132
+
133
+ pred1 = run_inference(model1, filepath)
134
+ pred2 = run_inference(model2, filepath)
135
+
136
+ combined_prediction = {
137
+ "model1_result": pred1,
138
+ "model2_result": pred2
139
+ }
140
+ return jsonify(combined_prediction)
141
+
142
+ else:
143
+ return jsonify({"error": "Invalid model type specified"}), 400
144
 
145
  except Exception as e:
146
  return jsonify({"error": f"An error occurred during inference: {str(e)}"}), 500
147
  finally:
148
+ # 4. --- Cleanup ---
149
  if os.path.exists(filepath):
150
  os.remove(filepath)
151