|
|
import gradio as gr |
|
|
from transformers import pipeline |
|
|
import torch |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# English image-captioning model (BLIP base); called by parse_input on the
# selected image, which reads result[0]["generated_text"].
pipe = pipeline(task="image-to-text", model="Salesforce/blip-image-captioning-base")

# Text translation model (NLLB-200, distilled 600M) used to translate the English
# caption. Weights are loaded as bfloat16 (2 bytes per value instead of
# float32's 4) to reduce memory use.
translator = pipeline(
    task="translation",
    model="facebook/nllb-200-distilled-600M",
    torch_dtype=torch.bfloat16,
)
|
|
|
|
|
|
|
|
def process_image(image, shouldConvert=False):
    """Optionally flatten an image onto an opaque white background.

    Transparent pixels are rendered as white rather than whatever colour a
    plain mode conversion would give them.

    Args:
        image: A PIL image.
        shouldConvert: When True, composite ``image`` onto a white RGB canvas
            of the same size. When False, ``image`` is returned unchanged.

    Returns:
        ``image`` itself, or a new RGB image when ``shouldConvert`` is True.
    """
    if not shouldConvert:
        return image

    # Normalise to RGBA so an alpha band always exists. Pasting with
    # ``mask=image`` only accepts "1", "L", "LA", "RGBA" or "RGBa" masks, so an
    # "RGB" or "P" input would raise "bad transparency mask". An "L" input
    # would be used as a luminance mask, making dark strokes transparent.
    # After convert("RGBA"), opaque images get a fully opaque alpha band.
    rgba = image.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    # With an RGBA mask, PIL uses its alpha band to blend onto the background.
    background.paste(rgba, (0, 0), mask=rgba)
    return background
|
|
|
|
|
def _select_input_image(image, sketchpad, tab_index):
    """Return the image from the selected input tab, or None if that tab is empty."""
    if tab_index == 0:
        # Upload tab: uploaded images are passed through unchanged.
        return None if image is None else process_image(image)
    if tab_index == 1 and sketchpad:
        # Sketch tab: flatten the drawing onto white before captioning.
        composite = sketchpad.get("composite")
        if composite is not None:
            return process_image(composite, True)
    return None


def parse_input(image, sketchpad, state):
    """Caption the active input in English and translate the caption.

    Args:
        image: PIL image from the "Upload" tab, or None if nothing was uploaded.
        sketchpad: Value of the "Sketch" tab; a dict whose "composite" entry
            holds the drawing. NOTE(review): presumably None/empty when
            nothing was drawn. Gradio may instead deliver a blank canvas, and
            that canvas would still be captioned. Confirm.
        state: Per-session dict. ``state["tab_index"]`` is the selected input
            tab (0 = Upload, 1 = Sketch), as written by ``tabs_select``.

    Returns:
        Tuple ``(english_caption, translated_caption)``.

    Raises:
        gr.Error: If the selected tab holds no image. The user then sees a
            message instead of the captioning pipeline failing on ``None``.
    """
    new_image = _select_input_image(image, sketchpad, state["tab_index"])
    if new_image is None:
        raise gr.Error("Please upload an image or draw a sketch first.")

    english_caption = str(pipe(new_image)[0]["generated_text"])

    # NOTE(review): in NLLB's FLORES-200 codes, "arz_Arab" is Egyptian Arabic;
    # Modern Standard Arabic is "arb_Arab". Confirm which one is intended.
    text_translated = translator(english_caption, src_lang="eng_Latn", tgt_lang="arz_Arab")
    return english_caption, text_translated[0]["translation_text"]
|
|
|
|
|
|
|
|
def tabs_select(e: gr.SelectData, _state):
    """Remember which input tab the user just selected.

    This is wired to the input Tabs' ``select`` event. ``e.index`` is the
    position of the clicked tab. It is stored in place in the per-session
    state dict under "tab_index", which ``parse_input`` reads to decide
    which input to caption.
    """
    _state.update(tab_index=e.index)
|
|
|
|
|
# Sample images shown under the examples section. gr.Examples takes one inner
# list per example, with one value per input component (here only the
# Upload image).
example_img_paths = [
    [url]
    for url in (
        "https://4.img-dpreview.com/files/p/E~TS590x0~articles/3925134721/0266554465.jpeg",
        "https://images4.alphacoders.com/688/688832.jpg",
    )
]
|
|
|
|
|
with gr.Blocks() as iface:
    # Header: logo, title and a one-line description of the demo.
    gr.HTML("""<p align="center"><img src="https://cdn-icons-png.flaticon.com/512/5853/5853758.png" style="height: 60px"/></p>""")
    gr.HTML("""<center><font size=8>Image Captioning Demo</font></center>""")
    gr.HTML("""<center><font size=3>In this space you can input either an image or draw a sketch of an object to receive an Arabic caption.</font></center>""")

    # Per-session record of the selected input tab (0 = Upload, 1 = Sketch).
    # It is updated by tabs_select and read by parse_input.
    state = gr.State({"tab_index": 0})

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            input_tabs.select(fn=tabs_select, inputs=[state])

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Example Prompts")
            # NOTE(review): examples only fill the Upload image. If the Sketch
            # tab is selected, Submit still captions the sketch.
            gr.Examples(
                example_img_paths,
                inputs=[input_image],
                cache_examples=False,
            )

    # The result boxes are created first so the Clear button can reset them.
    # They are rendered below the buttons (see the Row further down).
    english_output = gr.Textbox(label="English Result", render=False)
    # Arabic is written right-to-left.
    arabic_output = gr.Textbox(label="Arabic Result", rtl=True, render=False)

    with gr.Row():
        with gr.Column():
            clear_btn = gr.ClearButton(
                [input_image, input_sketchpad, english_output, arabic_output])
        with gr.Column():
            submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        english_output.render()
        arabic_output.render()

    submit_btn.click(
        fn=parse_input,
        inputs=[input_image, input_sketchpad, state],
        outputs=[english_output, arabic_output],
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server for the `iface` app only when this file is run
    # directly, not when it is imported.
    iface.launch()