File size: 1,286 Bytes
44213d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c240e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gradio as gr
from transformers import pipeline

# English GENIUS pipeline, left disabled here: only the Chinese model is
# loaded at startup.
# pipeline_en = pipeline(task="text2text-generation", model="beyond/genius-large")

# Chinese GENIUS pipeline (sketch -> full text); predict_zh uses this one.
pipeline_zh = pipeline(
    "text2text-generation",
    model="beyond/genius-base-chinese",
)

def predict_en(sketch):
    """Expand an English sketch into full text with the GENIUS English model.

    If no module-level ``pipeline_en`` exists yet, it is created on first use
    and cached in that global. Previously this function referenced an
    undefined name and always raised ``NameError``.

    Args:
        sketch: The sketch string entered by the user.

    Returns:
        The ``generated_text`` field of the first generation result.
    """
    global pipeline_en
    try:
        generator = pipeline_en
    except NameError:
        # Build on demand so the English checkpoint is only downloaded and
        # loaded if this function is actually called.
        generator = pipeline_en = pipeline(
            task="text2text-generation", model="beyond/genius-large"
        )
    results = generator(sketch, num_beams=3, do_sample=True, max_length=200)
    return results[0]["generated_text"]

def predict_zh(sketch):
    """Expand a Chinese sketch into full text with the GENIUS Chinese model.

    Args:
        sketch: The sketch string entered by the user.

    Returns:
        The ``generated_text`` field of the first generation result.
    """
    results = pipeline_zh(
        sketch,
        num_beams=3,
        do_sample=True,
        max_length=200,
    )
    first = results[0]
    return first["generated_text"]
  
 
with gr.Blocks() as demo:
    # Page header: project link plus a note on the sketch mask token.
    # The blank line keeps the two sentences as separate Markdown paragraphs.
    gr.Markdown(
        """
        💡GENIUS – generating text using sketches! [Visit our github repo](https://github.com/beyondguo/genius)

        The English version uses `<mask>` as the mask token.
        """)

    with gr.Row():
        # Left column: sketch input and the action buttons.
        with gr.Column():
            model_input = gr.Textbox(lines=7, placeholder='Input your sketch', label='Input')
            with gr.Row():
                gen = gr.Button("Generate")
                clr = gr.Button("Clear")

        # Right side: generated text.
        outputs = gr.Textbox(lines=7, label='Output')

    # Generation is wired to the Chinese model only.
    gen.click(fn=predict_zh, inputs=[model_input], outputs=outputs)
    # Clearing needs no inputs; returning "" resets the input textbox.
    clr.click(fn=lambda: "", inputs=None, outputs=model_input)

demo.launch()