JonSnow1512 commited on
Commit
d8a1c8d
·
verified ·
1 Parent(s): 63f74c4

Upload 5 files

Browse files
Files changed (5) hide show
  1. config.py +5 -0
  2. label_encoder.joblib +3 -0
  3. model.py +26 -0
  4. predict.py +63 -0
  5. saved_model.joblib +3 -0
config.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ TRAIN_DIR = "C:\\Users\\shrey\\Downloads\\ucf_dataset\\Train"
2
+ TEST_DIR = "C:\\Users\\shrey\\Downloads\\ucf_dataset\\Test"
3
+ DEVICE = "cuda"
4
+ CLASSIFIER_TYPE = "mlp" # Options: "logistic", "random_forest", "mlp"
5
+ MODEL_SAVE_PATH = "saved_model.joblib"
label_encoder.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4efe2cb985d1d6f88344cad5d6c5a0120909dbd2fa24440578d2f19e12ab7c0e
3
+ size 1055
model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.linear_model import LogisticRegression
2
+ from sklearn.ensemble import RandomForestClassifier
3
+ from sklearn.neural_network import MLPClassifier
4
+ from sklearn.metrics import classification_report
5
+ import joblib
6
+ from config import MODEL_SAVE_PATH
7
+
8
+ def get_classifier(model_type):
9
+ if model_type == "logistic":
10
+ return LogisticRegression(max_iter=2000)
11
+ elif model_type == "random_forest":
12
+ return RandomForestClassifier(n_estimators=100)
13
+ elif model_type == "mlp":
14
+ return MLPClassifier(hidden_layer_sizes=(512, 256), max_iter=300)
15
+ else:
16
+ raise ValueError(f"Unknown model type: {model_type}")
17
+
18
+ def train_classifier(X_train, y_train, model_type):
19
+ clf = get_classifier(model_type)
20
+ clf.fit(X_train, y_train)
21
+ joblib.dump(clf, MODEL_SAVE_PATH)
22
+ return clf
23
+
24
+ def evaluate_classifier(model, X_test, y_test, label_encoder):
25
+ y_pred = model.predict(X_test)
26
+ print(classification_report(y_test, y_pred, target_names=label_encoder.classes_))
predict.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import joblib
4
+ import numpy as np
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ from config import DEVICE, MODEL_SAVE_PATH
7
+ from flask import Flask, request, jsonify
8
+ from flask_cors import CORS
9
+ import os
10
+
11
+ app = Flask(__name__)
12
+ CORS(app)
13
+
14
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
15
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
16
+
17
+ def predict_image(image_path):
18
+ image = Image.open(image_path).convert("RGB")
19
+ inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
20
+
21
+ with torch.no_grad():
22
+ image_features = clip_model.get_image_features(**inputs)
23
+ image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
24
+ features = image_features.cpu().numpy()
25
+
26
+ model = joblib.load(MODEL_SAVE_PATH)
27
+ label_encoder = joblib.load("label_encoder.joblib")
28
+ pred = model.predict(features)
29
+ label = label_encoder.inverse_transform(pred)
30
+ return label[0]
31
+
32
+ @app.route('/predict', methods=['POST'])
33
+ def predict():
34
+ if 'image' not in request.files:
35
+ return jsonify({'error': 'No image uploaded'}), 400
36
+
37
+ image = request.files['image']
38
+ if image.filename == '':
39
+ return jsonify({'error': 'No image selected'}), 400
40
+
41
+ try:
42
+ # Save the uploaded image temporarily
43
+ image_path = "temp_image.jpg"
44
+ image.save(image_path)
45
+
46
+ # Predict the image
47
+ prediction = predict_image(image_path)
48
+
49
+ # Remove the temporary image
50
+ os.remove(image_path)
51
+
52
+ return jsonify({'prediction': prediction})
53
+
54
+ except Exception as e:
55
+ return jsonify({'error': str(e)}), 500
56
+
57
+ @app.route('/healthcheck', methods=['GET'])
58
+ def healthcheck():
59
+ return jsonify({'status': 'ok'}), 200
60
+
61
+ if __name__ == '__main__':
62
+ port = int(os.environ.get('PORT', 5000))
63
+ app.run(debug=True, host='0.0.0.0', port=port)
saved_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6989da95b96c941b3a5560d3d2c16b4bf4c16082059ccba58d8e418a0b9a8274
3
+ size 6370936