| from contextlib import asynccontextmanager |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, ValidationError |
| from fastapi.encoders import jsonable_encoder |
|
|
| |
| |
| import re |
| import string |
| import nltk |
# Fetch the NLTK data used by lemmatize(): Punkt tokenizer models for
# nltk.word_tokenize, and WordNet (+ Open Multilingual WordNet) data for
# WordNetLemmatizer.
# Recent NLTK releases (3.8.2+) load Punkt from the 'punkt_tab' resource instead
# of the pickled 'punkt' one, so both are requested. On older releases the
# 'punkt_tab' download only reports "not found" and returns False (it does not
# raise by default), so requesting it is harmless there.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('omw-1.4')
| from nltk.stem import WordNetLemmatizer |
|
|
| |
def remove_urls(text):
    """Return *text* with every http:// or https:// URL deleted.

    A URL is taken to run from its scheme up to the next whitespace
    character; the whitespace around it is left untouched.
    """
    url_regex = re.compile(r'http[s]?://\S+')
    return url_regex.sub('', text)
|
|
| |
def remove_punctuation(text):
    """Return ``str(text)`` with every ASCII punctuation character removed.

    The characters removed are exactly those in ``string.punctuation``.
    The previous implementation spliced ``string.punctuation`` unescaped into
    a regex character class; the backslash in it escaped the following ']',
    so backslashes were never removed. A translation table has no escaping
    pitfalls and deletes all of them in one pass.
    """
    deletion_table = str.maketrans('', '', string.punctuation)
    return str(text).translate(deletion_table)
|
|
| |
def lower_case(text):
    """Return a lowercased copy of *text*."""
    lowered = text.lower()
    return lowered
|
|
| |
def lemmatize(text):
    """Tokenize *text* with NLTK and lemmatize each token with WordNet.

    Returns the lemmas, each followed by a single space (so a non-empty
    result ends with a trailing space, as before); input that produces no
    tokens yields ''.
    """
    wordnet_lemmatizer = WordNetLemmatizer()
    tokens = nltk.word_tokenize(text)
    # str.join instead of repeated += avoids quadratic string building while
    # producing byte-identical output (every lemma followed by one space).
    return ''.join(wordnet_lemmatizer.lemmatize(token) + ' ' for token in tokens)
|
|
def preprocess_text(text):
    """Clean *text* before it is sent to the sentiment model.

    The steps run in this order: strip http(s) URLs, drop ASCII punctuation,
    lowercase, then lemmatize token by token.
    """
    pipeline_steps = (remove_urls, remove_punctuation, lower_case, lemmatize)
    for step in pipeline_steps:
        text = step(text)
    return text
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the sentiment pipeline at startup and discard it at shutdown.

    The pipeline is bound to the module-level name ``sentiment_task``, which
    the request handler calls; that global is deleted once the app stops.
    """
    global sentiment_task
    # Imported here rather than at module top, so transformers is only
    # imported when the lifespan actually runs.
    from transformers import pipeline

    model_id = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    sentiment_task = pipeline("sentiment-analysis", model=model_id, tokenizer=model_id)
    yield
    del sentiment_task
|
|
# Markdown text passed as the FastAPI app's `description` when `app` is created
# below; it is rendered at the top of the generated API docs page.
description = """
## Text Classification API
This app shows the sentiment of the text (positive, negative, or neutral).
Check out the docs for the `/analyze/{text}` endpoint below to try it out!
"""
|
|
| |
# The ASGI application. The Swagger UI docs are served at the site root
# (docs_url="/"), and `lifespan` loads the sentiment model on startup.
# NOTE(review): FastAPI registers the docs route while the app is constructed,
# so it likely takes precedence over the GET "/" handler declared later in this
# file — confirm which page "/" is meant to serve.
app = FastAPI(
    description=description,
    docs_url="/",
    lifespan=lifespan,
)
|
|
| |
# Request body for the analyze endpoint: the raw text whose sentiment is
# requested. (Comments rather than a docstring, because pydantic copies a class
# docstring into the generated OpenAPI schema.)
class TextInput(BaseModel):
    # Raw user text; length/emptiness limits are checked in the endpoint.
    text: str
|
|
| |
# Redirect the site root to the interactive API docs.
# NOTE(review): the app is created with docs_url="/", so the docs are served at
# "/" (not "/docs"), and FastAPI registers that docs route when the app is
# constructed — before this handler. This handler is therefore likely never
# reached, and "/docs" would 404 if it were. Confirm the intended docs URL.
# (Comments rather than a docstring, because FastAPI publishes handler
# docstrings in the OpenAPI schema.)
@app.get('/')
async def welcome():
    # RedirectResponse was previously used without ever being imported, which
    # made this handler raise NameError; import it locally.
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")
| |
| |
# Maximum number of characters accepted in a request's text; classify_text
# rejects longer input with HTTP 400.
MAX_TEXT_LENGTH = 1000
|
|
| |
# POST /analyze/{text}: classify the sentiment of the JSON body's "text" field.
# NOTE(review): the {text} path segment is not a parameter of this handler, so
# its value is ignored; the route is kept unchanged for client compatibility.
# (Comments rather than a docstring, because FastAPI publishes handler
# docstrings in the OpenAPI schema.)
#
# Responses:
#   200 - the sentiment pipeline's output for the preprocessed text
#   400 - text longer than MAX_TEXT_LENGTH, empty or whitespace-only text, or a
#         ValueError raised while preprocessing/classifying
#   500 - any other failure; the exception is logged server-side and a generic
#         message is returned, so internal details are not exposed to clients
@app.post('/analyze/{text}')
async def classify_text(text_input: TextInput):
    import logging

    from fastapi.concurrency import run_in_threadpool

    # FastAPI has already parsed and validated the body as TextInput, so the
    # former jsonable_encoder round-trip (whose ValidationError branch could
    # never fire) is unnecessary.
    text = text_input.text
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail="Text length exceeds maximum allowed length")
    if not text.strip():
        # Whitespace-only input carries no content to classify.
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    def _run_model(raw):
        # Preprocess and classify one text (synchronous, CPU-bound work).
        return sentiment_task(preprocess_text(raw))

    try:
        # NLTK preprocessing and model inference are blocking calls; running
        # them in the threadpool keeps the event loop free for other requests.
        return await run_in_threadpool(_run_model, text)
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        logging.getLogger(__name__).exception("Sentiment analysis failed")
        raise HTTPException(status_code=500, detail="Internal server error") from e
|
|