cfoli commited on
Commit
dfc66bf
·
verified ·
1 Parent(s): 06378e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -29
app.py CHANGED
@@ -57,25 +57,6 @@ LABELS_MAP = ["Bat (baseball)", "Bat (mammal)",
57
 
58
  """
59
 
60
- # model_key = "CLIP-large"
61
-
62
- # Load model (cache for speed)
63
- if model_key not in MODEL_CACHE:
64
- MODEL_CACHE[model_key] = pipeline(task = "zero-shot-image-classification",
65
- model = MODEL_OPTIONS[model_key])
66
- classifier = MODEL_CACHE[model_key]
67
-
68
- BASE_DIR = '/content/drive/MyDrive/ML Projects/Zero-shot Image Classification/Images'
69
- image_path = os.path.join(BASE_DIR, 'Mouse1_2.png')
70
-
71
- output = classifier(
72
- image = image_path,
73
- candidate_labels = CANDIDATE_LABELS,
74
- hypothesis_template = "This image shows {}")
75
-
76
- # print("\n\n=============================================================================")
77
- # print(f"\nPrediction: This image shows {output[0]["label"]} | Confidence (probability): {100*output[0]["score"]: .1f}%")
78
-
79
  def run_classifer(model_key, image_path, prob_threshold = None):
80
  # model_key: name of backbone zero-shot-image-classification model to use
81
  # image_path: path to test image
@@ -107,16 +88,6 @@ def run_classifer(model_key, image_path, prob_threshold = None):
107
 
108
  return predicted_label_str, prob_dict
109
 
110
- # # example run
111
- # model_key = "CLIP-large"
112
- # BASE_DIR = '/content/drive/MyDrive/ML Projects/Zero-shot Image Classification/Images'
113
- # image_path = os.path.join(BASE_DIR, 'Nail2_1.png')
114
-
115
- # predicted_label_str, prob_dict = run_classifer(model_key, image_path, prob_threshold = 0.4)
116
- # print("\n\n=============================================================================")
117
- # # print(f"\nPrediction: {predicted_label_str} | Confidence (probability): {100*output[0]['score']:.1f}%")
118
- # print(f"\nPrediction: {predicted_label_str}")
119
-
120
  """### Gradio App
121
 
122
  ---
 
57
 
58
  """
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def run_classifer(model_key, image_path, prob_threshold = None):
61
  # model_key: name of backbone zero-shot-image-classification model to use
62
  # image_path: path to test image
 
88
 
89
  return predicted_label_str, prob_dict
90
 
 
 
 
 
 
 
 
 
 
 
91
  """### Gradio App
92
 
93
  ---