the10or commited on
Commit
161324c
·
verified ·
1 Parent(s): 2debabe

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +2 -0
  2. inception_v3.keras +3 -0
  3. index.py +46 -0
  4. mn_model.keras +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ inception_v3.keras filter=lfs diff=lfs merge=lfs -text
37
+ mn_model.keras filter=lfs diff=lfs merge=lfs -text
inception_v3.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2d751d2d613deb54f2b8f8b97cff8eecd7c0f87ed7a548e8028af68b9024ab6
3
+ size 298791432
index.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import gradio as gr
3
+
4
+ from src.load_dataset.load_dataset import classes
5
+
6
+ nm_model = tf.keras.models.load_model("../models/mn_model.keras")
7
+
8
+ resnet_model = tf.keras.models.load_model("../models/newmodel.h5")
9
+
10
+ inception_model = tf.keras.models.load_model("../models/inception_v3.keras")
11
+
12
+ cifar10_labels = classes
13
+ models = ["ResNetBased Model", "MobileNetBased Model", "InceptionBased Model"]
14
+
15
+
16
+ def classify_image(input_image, model_name):
17
+ try:
18
+ input_image = tf.image.resize(input_image, (32, 32))
19
+ labels = cifar10_labels
20
+ model = get_model(model_name)
21
+ input_image = tf.expand_dims(input_image, axis=0)
22
+ predictions = model.predict(input_image).flatten()
23
+ top_indices = predictions.argsort()[-10:][::-1]
24
+ confidences = {labels[i]: float(predictions[i]) for i in top_indices}
25
+ return confidences
26
+ except Exception as e:
27
+ return {"error": str(e)}
28
+
29
+
30
+ def get_model(model_name):
31
+ if model_name == "MobileNetBased Model":
32
+ return nm_model
33
+ elif model_name == "ResNetBased Model":
34
+ return resnet_model
35
+ elif model_name == "InceptionBased Model":
36
+ return inception_model
37
+
38
+
39
+ interface = gr.Interface(
40
+ fn=classify_image,
41
+ inputs=[gr.Image(type="numpy", image_mode="RGB", label="Input Image"),
42
+ gr.Dropdown(models, label="Model Choice")],
43
+ outputs=gr.Label(num_top_classes=3, label="Predictions"),
44
+ )
45
+
46
+ interface.launch(debug=False, share=True)
mn_model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d1da46963093525e7af3ea4afc97eb2fe4f30144812cf8c4cac2fe998d03976
3
+ size 73589384