Sadeep Sachintha commited on
Commit
eafecbb
·
1 Parent(s): d345116

feat: implement FastAPI service for Sinhala sentiment analysis with model integration and API tests

Browse files
Files changed (3) hide show
  1. app/main.py +11 -1
  2. app/model.py +11 -7
  3. tests/test_api.py +8 -4
app/main.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from app.model import predict_sentiment
4
  import logging
5
 
6
  logging.basicConfig(level=logging.INFO)
@@ -19,6 +19,16 @@ class SentimentResponse(BaseModel):
19
  label: str
20
  score: float
21
 
 
 
 
 
 
 
 
 
 
 
22
  @app.get("/")
23
  def read_root():
24
  return {"message": "Welcome to the Sinhala Sentiment Analysis API. Use POST /predict to analyze text."}
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from app.model import predict_sentiment, load_model
4
  import logging
5
 
6
  logging.basicConfig(level=logging.INFO)
 
19
  label: str
20
  score: float
21
 
22
+ @app.on_event("startup")
23
+ async def startup_event():
24
+ """Load the model when the app starts."""
25
+ try:
26
+ load_model()
27
+ logger.info("Model loaded successfully on startup")
28
+ except Exception as e:
29
+ logger.error(f"Failed to load model on startup: {e}")
30
+ raise
31
+
32
  @app.get("/")
33
  def read_root():
34
  return {"message": "Welcome to the Sinhala Sentiment Analysis API. Use POST /predict to analyze text."}
app/model.py CHANGED
@@ -5,14 +5,18 @@ logger = logging.getLogger(__name__)
5
 
6
  # Using a robust Sinhala sentiment analysis model from Hugging Face
7
  MODEL_NAME = "keshan/sinhala-sentiment-analysis"
 
8
 
9
- try:
10
- logger.info(f"Loading model {MODEL_NAME}...")
11
- sentiment_pipeline = pipeline("sentiment-analysis", model=MODEL_NAME)
12
- logger.info("Model loaded successfully.")
13
- except Exception as e:
14
- logger.error(f"Error loading model: {e}")
15
- sentiment_pipeline = None
 
 
 
16
 
17
  def predict_sentiment(text: str):
18
  if not sentiment_pipeline:
 
5
 
6
  # Using a robust Sinhala sentiment analysis model from Hugging Face
7
  MODEL_NAME = "keshan/sinhala-sentiment-analysis"
8
+ sentiment_pipeline = None
9
 
10
+ def load_model():
11
+ global sentiment_pipeline
12
+ if sentiment_pipeline is None:
13
+ try:
14
+ logger.info(f"Loading model {MODEL_NAME}...")
15
+ sentiment_pipeline = pipeline("sentiment-analysis", model=MODEL_NAME)
16
+ logger.info("Model loaded successfully.")
17
+ except Exception as e:
18
+ logger.error(f"Error loading model: {e}")
19
+ raise e
20
 
21
  def predict_sentiment(text: str):
22
  if not sentiment_pipeline:
tests/test_api.py CHANGED
@@ -1,20 +1,24 @@
1
  from fastapi.testclient import TestClient
2
  from app.main import app
 
3
 
4
- client = TestClient(app)
 
 
 
5
 
6
- def test_read_root():
7
  response = client.get("/")
8
  assert response.status_code == 200
9
  assert "Welcome" in response.json()["message"]
10
 
11
- def test_predict_sentiment_positive():
12
  # "This is a very good creation." in Sinhala
13
  response = client.post("/predict", json={"text": "මෙය ඉතා හොඳ නිර්මාණයක්."})
14
  assert response.status_code == 200
15
  assert "label" in response.json()
16
  assert "score" in response.json()
17
 
18
- def test_predict_sentiment_empty():
19
  response = client.post("/predict", json={"text": ""})
20
  assert response.status_code == 400
 
1
  from fastapi.testclient import TestClient
2
  from app.main import app
3
+ import pytest
4
 
5
+ @pytest.fixture(scope="module")
6
+ def client():
7
+ with TestClient(app) as c:
8
+ yield c
9
 
10
+ def test_read_root(client):
11
  response = client.get("/")
12
  assert response.status_code == 200
13
  assert "Welcome" in response.json()["message"]
14
 
15
+ def test_predict_sentiment_positive(client):
16
  # "This is a very good creation." in Sinhala
17
  response = client.post("/predict", json={"text": "මෙය ඉතා හොඳ නිර්මාණයක්."})
18
  assert response.status_code == 200
19
  assert "label" in response.json()
20
  assert "score" in response.json()
21
 
22
+ def test_predict_sentiment_empty(client):
23
  response = client.post("/predict", json={"text": ""})
24
  assert response.status_code == 400