codeby-hp committed on
Commit
a94e340
·
verified ·
1 Parent(s): 901117e

Update fastapi_app/app.py

Browse files
Files changed (1) hide show
  1. fastapi_app/app.py +110 -112
fastapi_app/app.py CHANGED
@@ -1,112 +1,110 @@
1
- import torch
2
- import logging
3
- from get_model import download_model_from_s3
4
- from contextlib import asynccontextmanager
5
- from fastapi import FastAPI, Request, Form
6
- from fastapi.responses import HTMLResponse
7
- from fastapi.templating import Jinja2Templates
8
- from fastapi.staticfiles import StaticFiles
9
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
- model = None
15
- tokenizer = None
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
-
19
- @asynccontextmanager
20
- async def lifespan(app: FastAPI):
21
- """Load model on startup and cleanup on shutdown"""
22
- global model, tokenizer
23
-
24
- try:
25
- logger.info("Starting model download from S3...")
26
- model_dir = download_model_from_s3(local_dir="./model")
27
-
28
- logger.info("Loading tokenizer...")
29
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
30
-
31
- logger.info("Loading model...")
32
- model = AutoModelForSequenceClassification.from_pretrained(model_dir)
33
- model.to(device)
34
- model.eval()
35
-
36
- logger.info(f"Model loaded successfully on {device}")
37
- except Exception as e:
38
- logger.error(f"Error loading model: {e}")
39
- raise
40
-
41
- yield
42
-
43
- logger.info("Shutting down...")
44
-
45
-
46
- app = FastAPI(title="Sentiment Analysis API", lifespan=lifespan)
47
-
48
- templates = Jinja2Templates(directory="templates")
49
-
50
-
51
- @app.get("/", response_class=HTMLResponse)
52
- async def home(request: Request):
53
- """Render the home page"""
54
- return templates.TemplateResponse("index.html", {"request": request})
55
-
56
-
57
- @app.post("/predict")
58
- async def predict(request: Request, text: str = Form(...)):
59
- """Predict sentiment for the given text"""
60
- if not text.strip():
61
- return templates.TemplateResponse(
62
- "index.html",
63
- {"request": request, "error": "Please enter some text to analyze"},
64
- )
65
-
66
- try:
67
- inputs = tokenizer(
68
- text, return_tensors="pt", truncation=True, max_length=512, padding=True
69
- )
70
- inputs = {k: v.to(device) for k, v in inputs.items()}
71
-
72
- with torch.no_grad():
73
- outputs = model(**inputs)
74
- logits = outputs.logits
75
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
76
- predicted_class = torch.argmax(probabilities, dim=-1).item()
77
- confidence = probabilities[0][predicted_class].item()
78
-
79
- sentiment_map = {0: "Negative", 1: "Positive"}
80
- sentiment = sentiment_map.get(predicted_class, "Unknown")
81
-
82
- return templates.TemplateResponse(
83
- "index.html",
84
- {
85
- "request": request,
86
- "text": text,
87
- "sentiment": sentiment,
88
- "confidence": round(confidence * 100, 2),
89
- },
90
- )
91
-
92
- except Exception as e:
93
- logger.error(f"Prediction error: {e}")
94
- return templates.TemplateResponse(
95
- "index.html", {"request": request, "error": f"An error occurred: {str(e)}"}
96
- )
97
-
98
-
99
- @app.get("/health")
100
- async def health_check():
101
- """Health check endpoint"""
102
- return {
103
- "status": "healthy",
104
- "model_loaded": model is not None,
105
- "device": str(device),
106
- }
107
-
108
-
109
- if __name__ == "__main__":
110
- import uvicorn
111
-
112
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import torch
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+ from fastapi import FastAPI, Request, Form
5
+ from fastapi.responses import HTMLResponse
6
+ from fastapi.templating import Jinja2Templates
7
+ from fastapi.staticfiles import StaticFiles
8
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
9
+
10
# Configure root logging once at import time; use a module-level logger per convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Globals populated by the FastAPI lifespan handler before any request is served.
model = None      # sequence-classification model once loaded
tokenizer = None  # matching tokenizer once loaded
# Prefer GPU when available; model weights and input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+
18
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the sentiment model on startup and log on shutdown.

    Loads the fine-tuned checkpoint from the Hugging Face Hub into the
    module-level ``model``/``tokenizer`` globals, moves the model to the
    selected device, and switches it to eval mode. Any load failure is
    re-raised so the server refuses to start half-initialized.
    """
    global model, tokenizer

    try:
        model_id = "codeby-hp/FinetuneTinybert-SentimentClassification"

        # Lazy %-style args: formatting is skipped if the level is disabled.
        logger.info("Loading tokenizer from %s...", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        logger.info("Loading model from %s...", model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.to(device)
        model.eval()  # disable dropout etc. for inference

        logger.info("Model loaded successfully on %s", device)
    except Exception:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error loading model")
        raise

    yield  # application serves requests here

    logger.info("Shutting down...")
42
+
43
+
44
# Wire the lifespan handler so the model is loaded before requests are served.
app = FastAPI(title="Sentiment Analysis API", lifespan=lifespan)

# Jinja2 templates rendered by the HTML endpoints below.
templates = Jinja2Templates(directory="templates")
47
+
48
+
49
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
53
+
54
+
55
def _run_inference(text):
    """Tokenize *text*, run the classifier, and return (label, confidence).

    Assumes the module-level ``tokenizer``/``model`` globals are loaded.
    ``confidence`` is a fraction in [0, 1].
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    )
    encoded = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(**encoded)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()

    # Label-index → name mapping (0 = Negative, 1 = Positive); anything else
    # falls back to "Unknown" rather than raising.
    sentiment_map = {0: "Negative", 1: "Positive"}
    return sentiment_map.get(predicted_class, "Unknown"), confidence


@app.post("/predict")
async def predict(request: Request, text: str = Form(...)):
    """Predict sentiment for the submitted form text and re-render the page.

    Renders ``index.html`` with either the prediction context
    (``text``/``sentiment``/``confidence`` in percent) or an ``error`` message.
    """
    if not text.strip():
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Please enter some text to analyze"},
        )

    # Guard against requests racing model startup: without this, the broad
    # except below surfaces a cryptic NoneType error to the user.
    if model is None or tokenizer is None:
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Model is not loaded yet; please try again shortly."},
        )

    try:
        sentiment, confidence = _run_inference(text)

        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "text": text,
                "sentiment": sentiment,
                "confidence": round(confidence * 100, 2),
            },
        )
    except Exception as e:
        # logger.exception preserves the traceback for debugging.
        logger.exception("Prediction error")
        return templates.TemplateResponse(
            "index.html", {"request": request, "error": f"An error occurred: {str(e)}"}
        )
95
+
96
+
97
@app.get("/health")
async def health_check():
    """Report service status, whether the model is loaded, and the device."""
    payload = {"status": "healthy"}
    payload["model_loaded"] = model is not None
    payload["device"] = str(device)
    return payload
105
+
106
+
107
# Allow running directly (`python app.py`) without an external ASGI runner.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)