File size: 839 Bytes
01b4830
4aca552
 
 
 
4ad1def
4aca552
 
 
07c0361
7e57d2d
1e46059
07c0361
4aca552
 
 
207da04
 
4aca552
2ec7e44
4aca552
 
01b4830
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import gradio as gr
import torch
from transformers import pipeline

# Build the Hugging Face generation pipeline once, at import time. No `task=`
# is passed, so transformers resolves it from the model repo's metadata.
# NOTE: this rebinds the name `pipeline` (imported from transformers above)
# to the pipeline instance; `predict` calls it through this module-level name.
pipeline = pipeline(
    model="jeevana/EmailSubjectLineGeneration",
    max_new_tokens=20,
)

# Separator appended to the email body to form the prompt; the text the model
# produces after it is treated as the subject line.
SUBJECT_MARKER = '\n@subject\n'


def predict(input):
    """Generate a subject line for an email body.

    The email text is suffixed with ``SUBJECT_MARKER`` and passed to the
    module-level ``pipeline``. The completion that follows the marker in the
    generated text is returned as the subject.

    Args:
        input: Raw email body from the Gradio textbox. (The name shadows the
            ``input`` builtin but is kept so existing callers are unaffected.)

    Returns:
        The generated subject with surrounding whitespace stripped; an empty
        string if the pipeline returned no text.
    """
    prompt = input + SUBJECT_MARKER
    outputs = pipeline(prompt)
    # `or ""` guards against a missing/None "generated_text" entry, which
    # previously crashed on slicing.
    generated = outputs[0].get("generated_text") or ""

    if generated.startswith(prompt):
        completion = generated[len(prompt):]
    else:
        # The decoded output may not reproduce the prompt character-for-character
        # (e.g. tokenizer whitespace clean-up), in which case slicing by
        # len(prompt) cuts at the wrong offset. Locate the marker instead.
        _, sep, completion = generated.partition(SUBJECT_MARKER)
        if not sep:
            # Marker not found at all: fall back to the original offset slice.
            completion = generated[len(prompt):]

    return completion.strip()


# Gradio UI: a multi-line textbox for the email body as input, a textbox for
# the generated subject as output, wired to `predict`.
app = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label="Email", lines=15)],
    outputs=[gr.Textbox(label="Subject", lines=15)],
    title="EmailSubjectLineGeneration",
    description="EmailSubjectLineGeneration with GPT2",
)

# Launch only when executed as a script, so importing this module (e.g. to
# reuse `predict`) does not start a server as a side effect.
if __name__ == "__main__":
    # share=True requests a public Gradio share link and debug=True keeps the
    # call blocking. NOTE(review): confirm a public link is actually intended.
    app.launch(share=True, debug=True)