File size: 3,200 Bytes
a294799
 
 
 
 
 
 
 
 
7bea74a
a294799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bea74a
a294799
 
 
78b4a3d
a294799
7bea74a
a294799
 
 
58a455e
 
04cb248
7bea74a
a294799
 
58a455e
a294799
7bea74a
a294799
 
 
 
 
 
 
 
042c428
7bea74a
 
04cb248
 
7abcea1
04cb248
7bea74a
042c428
a294799
 
 
 
 
 
 
 
 
78b4a3d
04cb248
a294799
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
from transformers import pipeline
import torch 
from PIL import Image

# Load models
# BLIP produces an English caption from an image; NLLB-200 translates the
# caption into Arabic (Egyptian Arabic target is requested at call time).
# NOTE(review): both pipelines download weights from the Hugging Face Hub on
# first run, so initial startup needs network access and may be slow.
pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")  
translator = pipeline(task="translation", model="facebook/nllb-200-distilled-600M", torch_dtype=torch.bfloat16)

# Normalize sketchpad input before captioning.
def process_image(image, shouldConvert=False):
    """Return *image*, optionally flattened onto an opaque white background.

    When ``shouldConvert`` is true (used for sketchpad drawings, which carry
    an alpha channel), the picture is pasted onto a white RGB canvas of the
    same size so the captioning model receives a plain RGB image. Otherwise
    the input is returned unchanged.
    """
    if not shouldConvert:
        return image
    canvas = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
    canvas.paste(image, (0, 0), mask=image)
    return canvas

def parse_input(image, sketchpad, state):
    """Caption the active tab's image in English and translate it to Arabic.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Image from the "Upload" tab (may be None if nothing was uploaded).
    sketchpad : dict or None
        Sketchpad payload; the flattened drawing is under the "composite" key.
    state : dict
        Per-session state; ``state["tab_index"]`` selects which input to use
        (0 = Upload, 1 = Sketch).

    Returns
    -------
    tuple[str, str]
        (English caption, Arabic translation). When no image was supplied, a
        friendly prompt is returned in both slots instead of crashing.
    """
    current_tab_index = state["tab_index"]
    new_image = None

    # Tab 0: uploaded image is already RGB-safe.
    if current_tab_index == 0:
        if image is not None:
            new_image = process_image(image)
    # Tab 1: sketch carries an alpha channel — flatten onto white.
    elif current_tab_index == 1:
        if sketchpad and sketchpad["composite"]:
            new_image = process_image(sketchpad["composite"], True)

    # Bug fix: the original passed None straight into the pipeline when the
    # user hit Submit without providing an image, which raised an exception.
    if new_image is None:
        msg = "Please upload an image or draw a sketch first."
        return msg, msg

    # BLIP caption, then NLLB English -> Egyptian Arabic (arz_Arab) translation.
    eng_result = pipe(new_image)
    english_caption = str(eng_result[0]['generated_text'])
    text_translated = translator(english_caption, src_lang="eng_Latn", tgt_lang="arz_Arab")
    return eng_result[0]['generated_text'], text_translated[0]['translation_text']

# Remember which input tab (0 = Upload, 1 = Sketch) the user last selected.
def tabs_select(e: gr.SelectData, _state):
    """Store the selected tab's index in the per-session state dict."""
    _state.update(tab_index=e.index)

# Example image URLs shown under the upload tab (one image per example row).
example_img_paths = [["https://4.img-dpreview.com/files/p/E~TS590x0~articles/3925134721/0266554465.jpeg"], 
                     ["https://images4.alphacoders.com/688/688832.jpg"]]

# Build the Gradio UI: an Upload/Sketch tab pair feeding the caption +
# translation pipeline, with example images and Clear/Submit controls.
with gr.Blocks() as iface:
    gr.HTML("""<p align="center"><img src="https://cdn-icons-png.flaticon.com/512/5853/5853758.png" style="height: 60px"/><p>""")
    gr.HTML("""<center><font size=8>Image Captioning Demo</center>""")
    # Typo fix: "recieve" -> "receive"; "sketch of object" -> "a sketch of an object".
    gr.HTML("""<center><font size=3>In this space you can input either an image or draw a sketch of an object to receive an Arabic caption.</center>""")
    # Per-session state tracking which input tab is active (0 = Upload, 1 = Sketch).
    state = gr.State({"tab_index": 0})

    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab("Upload"):
                    input_image = gr.Image(type="pil", label="Upload")
                with gr.Tab("Sketch"):
                    input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
            # Keep state["tab_index"] in sync with the visible tab.
            input_tabs.select(fn=tabs_select, inputs=[state])

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Example Prompts")
                    gr.Examples(
                        example_img_paths,
                        inputs=[input_image],
                        cache_examples=False)
            
            with gr.Row():
                with gr.Column():
                    clear_btn = gr.ClearButton(
                        [input_image, input_sketchpad])
                with gr.Column():
                    submit_btn = gr.Button("Submit", variant="primary")
        # Name the result textboxes so they can also be wired to the clear
        # button (the original left them anonymous, so Clear never reset them).
        output_en = gr.Textbox(label="English Result")
        output_ar = gr.Textbox(label="Arabic Result")
        submit_btn.click(
            fn=parse_input,
            inputs=[input_image, input_sketchpad, state],
            outputs=[output_en, output_ar])
        clear_btn.add([output_en, output_ar])

# Launch the interface only when this file is executed directly
# (e.g. `python app.py`); importing the module just defines `iface`.
if __name__ == "__main__":
    iface.launch()