Update main.py
Browse files
main.py
CHANGED
|
@@ -49,15 +49,26 @@ def preprocess_text(text):
|
|
| 49 |
@asynccontextmanager
|
| 50 |
async def lifespan(app: FastAPI):
|
| 51 |
# Load the model from HuggingFace transformers library
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
from transformers import pipeline
|
| 53 |
global sentiment_task
|
| 54 |
-
|
|
|
|
| 55 |
yield
|
| 56 |
# Clean up the model and release the resources
|
| 57 |
del sentiment_task
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# Initialize the FastAPI app
|
| 60 |
-
app = FastAPI(lifespan=lifespan)
|
| 61 |
|
| 62 |
# Define the input data model
|
| 63 |
class TextInput(BaseModel):
|
|
@@ -66,7 +77,7 @@ class TextInput(BaseModel):
|
|
| 66 |
# Define the welcome endpoint
|
| 67 |
@app.get('/')
|
| 68 |
async def welcome():
|
| 69 |
-
return "Welcome to our
|
| 70 |
|
| 71 |
# Validate input text length
|
| 72 |
MAX_TEXT_LENGTH = 1000
|
|
|
|
| 49 |
@asynccontextmanager
|
| 50 |
async def lifespan(app: FastAPI):
|
| 51 |
# Load the model from HuggingFace transformers library
|
| 52 |
+
# from transformers import pipeline
|
| 53 |
+
|
| 54 |
+
# sentiment_task = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", tokenizer="cardiffnlp/twitter-roberta-base-sentiment-latest")
|
| 55 |
+
# Use a pipeline as a high-level helper
|
| 56 |
from transformers import pipeline
|
| 57 |
global sentiment_task
|
| 58 |
+
pipe = pipeline("text-classification", model="SamLowe/roberta-base-go_emotions","SamLowe/roberta-base-go_emotions")
|
| 59 |
+
|
| 60 |
yield
|
| 61 |
# Clean up the model and release the resources
|
| 62 |
del sentiment_task
|
| 63 |
|
| 64 |
+
description = """
|
| 65 |
+
## Text Classification API
|
| 66 |
+
Upon input to this app, It will show the sentiment of the text (positive, negative, or neutral).
|
| 67 |
+
Check out the docs for the `/analyze/{text}` endpoint below to try it out!
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
# Initialize the FastAPI app
|
| 71 |
+
app = FastAPI(lifespan=lifespan, docs_url="/", description=description)
|
| 72 |
|
| 73 |
# Define the input data model
|
| 74 |
class TextInput(BaseModel):
|
|
|
|
| 77 |
# Define the welcome endpoint
|
| 78 |
@app.get('/')
|
| 79 |
async def welcome():
|
| 80 |
+
return "Welcome to our First Emotion Classification API"
|
| 81 |
|
| 82 |
# Validate input text length
|
| 83 |
MAX_TEXT_LENGTH = 1000
|