lahiruchamika27 commited on
Commit
303983a
·
verified ·
1 Parent(s): 712792c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -40
app.py CHANGED
@@ -1,49 +1,12 @@
1
  from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from transformers import T5ForConditionalGeneration, T5Tokenizer
4
- import torch
5
- import nest_asyncio
6
  import uvicorn
7
 
8
  # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
- # Load pre-trained T5 model and tokenizer for paraphrasing
12
- model_name = "t5-small" # You can use "t5-base" or other variants for better performance
13
- tokenizer = T5Tokenizer.from_pretrained(model_name)
14
- model = T5ForConditionalGeneration.from_pretrained(model_name)
15
-
16
- # Define request body model
17
- class ParaphraseRequest(BaseModel):
18
- text: str
19
-
20
- # Define response model
21
- class ParaphraseResponse(BaseModel):
22
- paraphrased_text: str
23
-
24
- # Function to generate paraphrased text
25
- def paraphrase_text(input_text: str) -> str:
26
- try:
27
- # Prepare input for T5 model (prepend "paraphrase:" to the input text)
28
- input_ids = tokenizer.encode("paraphrase: " + input_text, return_tensors="pt", max_length=512, truncation=True)
29
-
30
- # Generate paraphrased text
31
- outputs = model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
32
-
33
- # Decode the generated text
34
- paraphrased_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
- return paraphrased_text
36
- except Exception as e:
37
- raise HTTPException(status_code=500, detail=f"Error during paraphrasing: {str(e)}")
38
-
39
- # API endpoint for paraphrasing
40
- @app.post("/paraphrase", response_model=ParaphraseResponse)
41
- async def paraphrase(request: ParaphraseRequest):
42
- paraphrased_text = paraphrase_text(request.text)
43
- return {"paraphrased_text": paraphrased_text}
44
-
45
- # Apply nest_asyncio to allow nested event loops
46
- nest_asyncio.apply()
47
 
48
  # Run the app with Uvicorn (use this command in terminal: uvicorn your_script_name:app --reload)
49
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI, HTTPException
 
 
 
 
2
  import uvicorn
3
 
4
  # Initialize FastAPI app
5
  app = FastAPI()
6
 
7
+ @app.route('/')
8
+ def index():
9
+ return "Hello"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Run the app with Uvicorn (use this command in terminal: uvicorn your_script_name:app --reload)
12
  if __name__ == "__main__":