oudayhn committed on
Commit
007d0d3
·
verified ·
1 Parent(s): 727a06f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -1,13 +1,22 @@
1
- import gradio as gr
2
  from fastapi import FastAPI
 
3
 
4
- app = FastAPI()
 
 
 
5
 
 
6
  def paraphrase(text):
7
  input_text = "paraphrase: " + text + " </s>"
8
  inputs = tokenizer([input_text], return_tensors="pt", padding=True)
9
  outputs = model.generate(**inputs, max_length=100, num_beams=5, num_return_sequences=1)
10
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
11
 
12
- io = gr.Interface(fn=paraphrase, inputs="text", outputs="text")
13
- app = gr.mount_gradio_app(app, io, path="/") # now accepts POST on /
 
 
 
 
 
1
# Third-party dependencies: Transformers supplies the seq2seq model,
# FastAPI the HTTP application, and Gradio the UI layer.
from fastapi import FastAPI
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import gradio as gr

# Pretrained paraphrasing checkpoint from the Hugging Face Hub.
# Tokenizer and model are loaded once at import time so that every
# request served by the app reuses the same in-memory instances.
model_name = "Vamsi/T5_Paraphrase_Paws"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
def paraphrase(text):
    """Return one paraphrase of *text* generated by the T5 model.

    The input is wrapped in the task prompt the checkpoint was trained
    with ("paraphrase: " prefix and an explicit " </s>" suffix), encoded
    as PyTorch tensors, and decoded back to plain text with special
    tokens stripped.
    """
    prompt = f"paraphrase: {text} </s>"
    encoded = tokenizer([prompt], return_tensors="pt", padding=True)
    # Beam search with 5 beams; only the single best sequence is kept.
    generated = model.generate(
        **encoded,
        max_length=100,
        num_beams=5,
        num_return_sequences=1,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
16
 
17
# Expose the paraphraser as a Gradio UI backed by a FastAPI application.
interface = gr.Interface(fn=paraphrase, inputs="text", outputs="text")

# Mounting the interface onto FastAPI at the root path means the app
# also accepts programmatic POST requests, not just browser traffic.
app = FastAPI()
app = gr.mount_gradio_app(app, interface, path="/")