Noveramaaz commited on
Commit
8d4a0f7
·
verified ·
1 Parent(s): 11d15c8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -33
main.py CHANGED
@@ -1,31 +1,3 @@
1
- from contextlib import asynccontextmanager
2
- from fastapi import FastAPI, HTTPException
3
- from pydantic import BaseModel, ValidationError
4
- from fastapi.encoders import jsonable_encoder
5
-
6
- # TEXT PREPROCESSING
7
- # --------------------------------------------------------------------
8
- import re
9
- import string
10
- import nltk
11
- nltk.download('punkt')
12
- nltk.download('wordnet')
13
- nltk.download('omw-1.4')
14
- from nltk.stem import WordNetLemmatizer
15
-
16
- # Function to remove URLs from text
17
- def remove_urls(text):
18
- return re.sub(r'http[s]?://\S+', '', text)
19
-
20
- # Function to remove punctuations from text
21
- def remove_punctuation(text):
22
- regular_punct = string.punctuation
23
- return str(re.sub(r'['+regular_punct+']', '', str(text)))
24
-
25
- # Function to convert the text into lower case
26
- def lower_case(text):
27
- return text.lower()
28
-
29
  # Function to lemmatize text
30
  def lemmatize(text):
31
  wordnet_lemmatizer = WordNetLemmatizer()
@@ -45,7 +17,7 @@ def preprocess_text(text):
45
  text = lemmatize(text)
46
  return text
47
 
48
- # Load the model using FastAPI lifespan event so that the model is loaded at the beginning for efficiency
49
  @asynccontextmanager
50
  async def lifespan(app: FastAPI):
51
  # Load the model from HuggingFace transformers library
@@ -56,8 +28,14 @@ async def lifespan(app: FastAPI):
56
  # Clean up the model and release the resources
57
  del sentiment_task
58
 
 
 
 
 
 
 
59
  # Initialize the FastAPI app
60
- app = FastAPI(lifespan=lifespan)
61
 
62
  # Define the input data model
63
  class TextInput(BaseModel):
@@ -66,13 +44,39 @@ class TextInput(BaseModel):
66
  # Define the welcome endpoint
67
  @app.get('/')
68
  async def welcome():
69
- # Redirect to the Swagger UI page
70
- return RedirectResponse(url="/docs")
71
-
72
 
73
  # Validate input text length
74
  MAX_TEXT_LENGTH = 1000
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # Define the sentiment analysis endpoint
77
  @app.post('/analyze/{text}')
78
  async def classify_text(text_input:TextInput):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Function to lemmatize text
2
  def lemmatize(text):
3
  wordnet_lemmatizer = WordNetLemmatizer()
 
17
  text = lemmatize(text)
18
  return text
19
 
20
+ # Load the model using FastAPI lifespan event so that teh model is loaded at the beginning for efficiency
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
  # Load the model from HuggingFace transformers library
 
28
  # Clean up the model and release the resources
29
  del sentiment_task
30
 
31
+ description = """
32
+ ## Text Classification API
33
+ This app shows the sentiment of the text (positive, negative, or neutral).
34
+ Check out the docs for the `/analyze/{text}` endpoint below to try it out!
35
+ """
36
+
37
  # Initialize the FastAPI app
38
+ app = FastAPI(lifespan=lifespan, docs_url="/", description=description)
39
 
40
  # Define the input data model
41
  class TextInput(BaseModel):
 
44
  # Define the welcome endpoint
45
  @app.get('/')
46
  async def welcome():
47
+ return "Welcome to our Text Classification API"
 
 
48
 
49
  # Validate input text length
50
  MAX_TEXT_LENGTH = 1000
51
 
52
+ # Define the sentiment analysis endpoint
53
+ @app.post('/analyze/{text}')
54
+ async def classify_text(text_input:TextInput):
55
+ try:
56
+ # Convert input data to JSON serializable dictionary
57
+ text_input_dict = jsonable_encoder(text_input)
58
+ # Validate input data using Pydantic model
59
+ text_data = TextInput(**text_input_dict) # Convert to Pydantic model
60
+
61
+ # Validate input text length
62
+ if len(text_input.text) > MAX_TEXT_LENGTH:
63
+ raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
64
+ elif len(text_input.text) == 0:
65
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
66
+ except ValidationError as e:
67
+ # Handle validation error
68
+ raise HTTPException(status_code=422, detail=str(e))
69
+
70
+ try:
71
+ # Perform text classification
72
+ return sentiment_task(preprocess_text(text_input.text))
73
+ except ValueError as ve:
74
+ # Handle value error
75
+ raise HTTPException(status_code=400, detail=str(ve))
76
+ except Exception as e:
77
+ # Handle other server errors
78
+ raise HTTPException(status_code=500, detail=str(e))
79
+
80
  # Define the sentiment analysis endpoint
81
  @app.post('/analyze/{text}')
82
  async def classify_text(text_input:TextInput):