File size: 1,027 Bytes
7373e7f
 
 
d60c505
7373e7f
74972a7
64e9fa2
7373e7f
 
0ece6f3
7373e7f
eb17928
 
d60c505
 
 
 
 
 
 
7373e7f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import logging
import re

import gradio as gr
import torch
from transformers import pipeline

# Build the model pipeline once at import time so every request reuses it.
# The task is inferred from the model's Hub metadata; the UI describes the
# model as GPT-2, so presumably this is a text-generation pipeline — confirm.
# max_new_tokens=60 limits how many tokens each call may generate.
# NOTE(review): this rebinds the imported `pipeline` factory name to the
# pipeline instance; `predict` looks the instance up under that name.
pipeline = pipeline(
    max_new_tokens=60,
    model="jeevana/GenerativeQnASystem",
)

# Matches the "Answer" label the model emits in front of its reply: either a
# label at the very start of the text (colon optional) or an explicit
# "answer:" label elsewhere. The ordinary word "answer" inside the reply is
# left alone. Compiled once instead of on every request.
ANSWER_LABEL_RE = re.compile(r'^\s*answer\b\s*:?|\banswer\s*:', re.IGNORECASE)

logger = logging.getLogger(__name__)


def predict(input):
    """Generate an answer for a question typed into the Gradio UI.

    Args:
        input: The question text from the "Question" textbox. (The name
            shadows the builtin but is kept for interface compatibility.)

    Returns:
        The generated continuation with the echoed prompt and any "Answer"
        label removed, stripped of surrounding whitespace. Blank input
        returns an empty string without calling the model.
    """
    # Nothing to generate from; avoid sending an empty prompt to the model.
    if not input or not input.strip():
        return ""

    logger.debug("pipeline object %s", pipeline)
    outputs = pipeline(input)
    # Default to "" so a missing key cannot crash the slice below.
    generated = outputs[0].get("generated_text", "")
    logger.debug("1::: %s", generated)

    # Assumes the output echoes the prompt before the continuation (the
    # text-generation default, return_full_text=True) — drop that prefix.
    answer = generated[len(input):]
    return ANSWER_LABEL_RE.sub('', answer).strip()


# Single-textbox question/answer UI backed by `predict`.
app = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label="Question", lines=3)],
    outputs=[gr.Textbox(label="Answer", lines=6)],
    title="Generative QnA System",
    description="Generative QnA with GPT2",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module for tests or tooling does not start a blocking server.
# share=True requests a public share link; debug=True keeps the process
# attached to the server.
if __name__ == "__main__":
    app.launch(share=True, debug=True)


# gr.load("models/jeevana/GenerativeQnASystem").launch()