AnnasBlackHat commited on
Commit
cebad5c
·
1 Parent(s): c49a9ad

enable multiple image outputs

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md CHANGED
@@ -9,7 +9,7 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- ## RUN Gradio Locally
13
  ```
14
  pip install gradio
15
  gradio app.py
 
9
  pinned: false
10
  ---
11
 
12
+ ## RUN Gradio Locally (With Auto Reload)
13
  ```
14
  pip install gradio
15
  gradio app.py
app.py CHANGED
@@ -2,34 +2,48 @@ import gradio as gr
2
  import requests
3
  import random
4
  from src.classification_model import ClassificationModel
 
5
 
6
  #only for dummy data
7
- response = requests.get("https://git.io/JJkYN")
8
- labels = response.text.split("\n")
9
 
10
  clf = ClassificationModel()
11
  model_names = clf.get_model_names()
12
  output_labels = []
 
 
13
 
14
- def predict(models, img_urls, img_files):
15
  print(f'model choosen: {models}')
16
  model_predictions = {}
17
 
18
  #set all labels visibility to false
19
- for i, name in enumerate(model_names):
20
- model_predictions[output_labels[i]] = gr.Label(label=f'# {name}', visible=False)
21
- print(f'id {i} invisible')
 
 
22
 
23
- for m in models:
24
- idx = model_names.index(m)
25
- print(f' {m} idx: ', idx)
26
- result = {labels[random.randrange(0, len(labels))]: random.uniform(0, 1.0) for i in range(5)}
27
- model_predictions[output_labels[idx]] = gr.Label(label=f'# {m}, 3 seconds', value=result, visible=True)
 
 
 
 
 
 
 
 
28
 
29
  return model_predictions
30
 
31
  with gr.Blocks() as demo:
32
  gr.Markdown("# Image Classification Benchmark")
 
33
 
34
  with gr.Row():
35
  with gr.Column(scale=1):
@@ -38,12 +52,14 @@ with gr.Blocks() as demo:
38
  img_files = gr.File(label='Upload Files',file_count='multiple', file_types=['image'])
39
  apply = gr.Button("Classify", variant='primary')
40
  with gr.Column(scale=1):
41
- for name in clf.get_model_names():
42
- output_labels.append(gr.Label(label=f'# {name}'))
 
 
43
 
44
  apply.click(fn=predict,
45
  inputs=[model, img_urls, img_files],
46
- outputs=output_labels)
47
 
48
 
49
  if __name__ == "__main__":
 
2
  import requests
3
  import random
4
  from src.classification_model import ClassificationModel
5
+ from src.util.extract import extract_image_urls
6
 
7
  #only for dummy data
8
+ # response = requests.get("https://git.io/JJkYN")
9
+ # labels = response.text.split("\n")
10
 
11
  clf = ClassificationModel()
12
  model_names = clf.get_model_names()
13
  output_labels = []
14
+ output_images = []
15
+ max_input_image = 10
16
 
17
+ def predict(models, img_url, img_files):
18
  print(f'model choosen: {models}')
19
  model_predictions = {}
20
 
21
  #set all labels visibility to false
22
+ for label in output_labels:
23
+ model_predictions[label] = gr.Label(label=f'# {name}', visible=False)
24
+ #set all images visibility yo hidden
25
+ for img in output_images:
26
+ model_predictions[img] = gr.Image(visible=False)
27
 
28
+ sources = extract_image_urls(img_url) + (img_files or [])
29
+ for i, source in enumerate(sources):
30
+ print(f'{i} type: {type(source)} --> {source}')
31
+ if i >= max_input_image: break
32
+
33
+ for j, m in enumerate(models):
34
+ results = clf.classify(m, source)
35
+ print(f'{m} --> {results}')
36
+
37
+ idx = j + (len(model_names)*i) #getting index of label
38
+ label_value = {raw.class_name: raw.confidence for raw in results}
39
+ model_predictions[output_labels[idx]] = gr.Label(label=f'# {m}, 3 seconds', value=label_value, visible=True)
40
+ model_predictions[output_images[i]] = gr.Image(visible=True, value=source, label=f'image {i}') # set image visibility to true
41
 
42
  return model_predictions
43
 
44
  with gr.Blocks() as demo:
45
  gr.Markdown("# Image Classification Benchmark")
46
+ gr.Markdown("You can input at maximum 10 images at once (urls or files)")
47
 
48
  with gr.Row():
49
  with gr.Column(scale=1):
 
52
  img_files = gr.File(label='Upload Files',file_count='multiple', file_types=['image'])
53
  apply = gr.Button("Classify", variant='primary')
54
  with gr.Column(scale=1):
55
+ for i in range(max_input_image):
56
+ output_images.append(gr.Image(interactive=False, visible= (i==0)))
57
+ for name in clf.get_model_names():
58
+ output_labels.append(gr.Label(label=f'# {name}', visible= (i==0)))
59
 
60
  apply.click(fn=predict,
61
  inputs=[model, img_urls, img_files],
62
+ outputs=output_images+output_labels)
63
 
64
 
65
  if __name__ == "__main__":
src/__pycache__/classification_model.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/classification_model.cpython-312.pyc and b/src/__pycache__/classification_model.cpython-312.pyc differ
 
src/__pycache__/model_data.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/model_data.cpython-312.pyc and b/src/__pycache__/model_data.cpython-312.pyc differ
 
src/classification_model.py CHANGED
@@ -1,4 +1,10 @@
1
- from .model_data import ModelData
 
 
 
 
 
 
2
 
3
  class ClassificationModel:
4
  """
@@ -6,7 +12,7 @@ class ClassificationModel:
6
  """
7
 
8
  def __init__(self):
9
- self.models = self.initialize_models()
10
 
11
  def get_model_names(self):
12
  return [model.name for model in self.models]
@@ -17,14 +23,24 @@ class ClassificationModel:
17
  return model
18
  raise Exception(f'Model {model_name} not found')
19
 
20
- def initialize_models(self):
21
- return [
22
- ModelData('clip-vit-base-patch32'),
23
- ModelData('mobilenet_v3')
24
  ]
25
 
26
- def load_model(self):
27
- """
28
- Loads the model from the model path.
29
- """
30
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from urllib.request import urlopen
3
+ from PIL import Image
4
+ from .data.model_data import ModelData
5
+ from .models.mobilenet_v3 import MobilenetV3
6
+ from .models.clip_vit import ClipVit
7
+ from .data.classification_result import ClassificationResult
8
 
9
  class ClassificationModel:
10
  """
 
12
  """
13
 
14
  def __init__(self):
15
+ self.load_model()
16
 
17
  def get_model_names(self):
18
  return [model.name for model in self.models]
 
23
  return model
24
  raise Exception(f'Model {model_name} not found')
25
 
26
+ def load_model(self):
27
+ self.models = [
28
+ ModelData('clip-vit-base-patch32', model_class=ClipVit()),
29
+ ModelData('mobilenet_v3', model_class=MobilenetV3())
30
  ]
31
 
32
+ def classify(self, model_name, image) -> List[ClassificationResult]:
33
+ #print type of image
34
+ print('>> image type -->',type(image))
35
+
36
+ #convert image to pil
37
+ img = self.image_to_pil(image)
38
+
39
+ model = self.get_model_data(model_name)
40
+ return model.model_class.classify_image(img)
41
+
42
+ def image_to_pil(self, image):
43
+ #if image is starts with https (means url), then download it
44
+ if image.startswith('https'):
45
+ return Image.open(urlopen(image))
46
+ return Image.open(image)
src/data/classification_result.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass
4
+ class ClassificationResult:
5
+ class_name: str
6
+ confidence: float
src/{model_data.py → data/model_data.py} RENAMED
@@ -1,5 +1,7 @@
1
  from dataclasses import dataclass
 
2
 
3
  @dataclass
4
  class ModelData:
5
  name: str
 
 
1
  from dataclasses import dataclass
2
+ from src.interface import ModelInterface
3
 
4
  @dataclass
5
  class ModelData:
6
  name: str
7
+ model_class: ModelInterface
src/interface.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .data.classification_result import ClassificationResult
2
+ from abc import ABC, abstractmethod
3
+ from typing import List
4
+
5
+ class ModelInterface(ABC):
6
+ @abstractmethod
7
+ def classify_image(self, image) -> List[ClassificationResult]:
8
+ pass
src/models/clip_vit.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src.interface import ModelInterface
3
+ from src.data.classification_result import ClassificationResult
4
+
5
+ class ClipVit(ModelInterface):
6
+ def __init__(self):
7
+ print('init... vlip vit model')
8
+
9
+ def classify_image(self, image) -> List[ClassificationResult]:
10
+ class_name = "Example Result"
11
+ confidence = 0.85
12
+ return [ClassificationResult(class_name=class_name, confidence=confidence)]
src/models/mobilenet_v3.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import random
3
+ from src.interface import ModelInterface
4
+ from src.data.classification_result import ClassificationResult
5
+
6
+ class MobilenetV3(ModelInterface):
7
+
8
+ def __init__(self):
9
+ print('init... mobilenet v3 model')
10
+
11
+ def classify_image(self, image) -> List[ClassificationResult]:
12
+ results = [ClassificationResult(class_name=f'example class ({i+1})', confidence=random.uniform(0, 1.0)) for i in range(5)]
13
+ return results
src/util/extract.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ def extract_image_urls(text):
4
+ # Regular expression to match image URLs
5
+ pattern = re.compile(r'https?://[^\s]+\.jpg|https?://[^\s]+\.jpeg|https?://[^\s]+\.png|https?://[^\s]+\.gif|https?://[^\s]+\.bmp|https?://[^\s]+\.webp')
6
+
7
+ # Find all matches in the input text
8
+ matches = pattern.findall(text)
9
+
10
+ return matches