gabrielmotablima commited on
Commit
aefcd17
·
verified ·
1 Parent(s): 6c815da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -43
app.py CHANGED
@@ -1,55 +1,48 @@
1
  import requests
2
- from PIL import Image, UnidentifiedImageError
3
  from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel
4
  import gradio as gr
5
  import os
 
6
 
7
  # Load the model, tokenizer, and image processor with error handling
8
  def load_model_and_components(model_name):
9
- try:
10
- model = VisionEncoderDecoderModel.from_pretrained(model_name)
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- image_processor = AutoImageProcessor.from_pretrained(model_name)
13
- return model, tokenizer, image_processor
14
- except Exception as e:
15
- raise RuntimeError(f"Error loading model components: {e}")
16
 
17
- # Preload both models
18
  def preload_models():
19
  models = {}
20
- models["laicsiifes/swin-distilbertimbau"] = load_model_and_components("laicsiifes/swin-distilbertimbau")
21
- models["laicsiifes/swin-gportuguese-2"] = load_model_and_components("laicsiifes/swin-gportuguese-2")
 
 
 
22
  return models
23
 
24
  models = preload_models()
25
- current_model_name = "laicsiifes/swin-distilbertimbau"
26
- model, tokenizer, image_processor = models[current_model_name]
27
 
28
  # Function to process the image and generate a caption
29
  def generate_caption(image, model_name):
30
- try:
31
- model, tokenizer, image_processor = models[model_name]
32
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
33
- generated_ids = model.generate(pixel_values)
34
- caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
35
- return caption
36
- except Exception:
37
- return "Please upload a valid image."
38
 
39
  # Predefined images for selection
40
  image_folder = "images"
41
  predefined_images_paths = [
42
- os.path.join(image_folder, fname) for fname in os.listdir(image_folder) if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
43
  ]
44
 
45
  # Gradio app
46
  def app(image=None, model_name="laicsiifes/swin-distilbertimbau"):
47
- try:
48
- if image is None:
49
- return "Please upload a valid image."
50
- return generate_caption(image, model_name)
51
- except Exception:
52
  return "Please upload a valid image."
 
53
 
54
  # Define UI
55
  with gr.Blocks() as interface:
@@ -62,27 +55,34 @@ with gr.Blocks() as interface:
62
  """)
63
  with gr.Row():
64
  with gr.Column():
65
- model_selector = gr.Dropdown(choices=list(models.keys()),
66
- value="laicsiifes/swin-distilbertimbau",
67
- label="Select Model")
68
- image_display = gr.Image(type="pil", label="Image Preview", interactive=False)
 
 
 
69
  upload_button = gr.File(label="Upload an Image", file_types=["image"], type="filepath")
70
  examples = gr.Examples(predefined_images_paths, inputs=[upload_button], label="Examples")
71
-
 
72
  with gr.Column():
73
  output_text = gr.Textbox(label="Generated Caption")
74
-
75
  # Define logic
76
- def handle_uploaded_image(image, selected_model):
77
- try:
78
- if image is None:
79
- return None, "Please upload a valid image."
80
- pil_image = Image.open(image).convert("RGB")
81
- return pil_image, generate_caption(pil_image, selected_model)
82
- except Exception:
83
- return None, "Please upload a valid image."
 
 
84
 
85
  model_selector.change(fn=lambda _: (None, None, None), inputs=[model_selector], outputs=[image_display, upload_button, output_text])
86
- upload_button.change(fn=handle_uploaded_image, inputs=[upload_button, model_selector], outputs=[image_display, output_text])
 
87
 
88
- interface.launch(share=False)
 
1
  import requests
2
+ from PIL import Image
3
  from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel
4
  import gradio as gr
5
  import os
6
+ from concurrent.futures import ThreadPoolExecutor
7
 
8
  # Load the model, tokenizer, and image processor with error handling
9
  def load_model_and_components(model_name):
10
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ image_processor = AutoImageProcessor.from_pretrained(model_name)
13
+ return model, tokenizer, image_processor
 
 
 
14
 
15
+ # Preload both models in parallel
16
  def preload_models():
17
  models = {}
18
+ model_names = ["laicsiifes/swin-distilbertimbau", "laicsiifes/swin-gportuguese-2"]
19
+ with ThreadPoolExecutor() as executor:
20
+ results = executor.map(load_model_and_components, model_names)
21
+ for name, result in zip(model_names, results):
22
+ models[name] = result
23
  return models
24
 
25
  models = preload_models()
 
 
26
 
27
  # Function to process the image and generate a caption
28
  def generate_caption(image, model_name):
29
+ model, tokenizer, image_processor = models[model_name]
30
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
31
+ generated_ids = model.generate(pixel_values, max_length=30, num_beams=2)
32
+ caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
33
+ return caption
 
 
 
34
 
35
  # Predefined images for selection
36
  image_folder = "images"
37
  predefined_images_paths = [
38
+ os.path.join(image_folder, fname) for fname in os.listdir(image_folder) if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
39
  ]
40
 
41
  # Gradio app
42
  def app(image=None, model_name="laicsiifes/swin-distilbertimbau"):
43
+ if image is None:
 
 
 
 
44
  return "Please upload a valid image."
45
+ return generate_caption(image, model_name)
46
 
47
  # Define UI
48
  with gr.Blocks() as interface:
 
55
  """)
56
  with gr.Row():
57
  with gr.Column():
58
+ model_selector = gr.Dropdown(
59
+ choices=list(models.keys()),
60
+ value="laicsiifes/swin-distilbertimbau",
61
+ label="Select Model"
62
+ )
63
+ with gr.Row():
64
+ with gr.Column():
65
  upload_button = gr.File(label="Upload an Image", file_types=["image"], type="filepath")
66
  examples = gr.Examples(predefined_images_paths, inputs=[upload_button], label="Examples")
67
+ image_display = gr.Image(type="pil", label="Image Preview", interactive=False)
68
+ generate_button = gr.Button("Generate")
69
  with gr.Column():
70
  output_text = gr.Textbox(label="Generated Caption")
71
+
72
  # Define logic
73
+ def handle_uploaded_image(image):
74
+ if image is None:
75
+ return None
76
+ pil_image = Image.open(image).convert("RGB")
77
+ return pil_image
78
+
79
+ def handle_generate_button(image, selected_model):
80
+ if image is None:
81
+ return "Please upload an image to generate a caption."
82
+ return generate_caption(image, selected_model)
83
 
84
  model_selector.change(fn=lambda _: (None, None, None), inputs=[model_selector], outputs=[image_display, upload_button, output_text])
85
+ upload_button.change(fn=handle_uploaded_image, inputs=upload_button, outputs=image_display)
86
+ generate_button.click(fn=handle_generate_button, inputs=[image_display, model_selector], outputs=output_text)
87
 
88
+ interface.launch(share=False)