codeby-hp committed on
Commit
901117e
·
verified ·
1 Parent(s): 00a6941

Upload 5 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Stream logs immediately (no stdout buffering) and skip .pyc generation —
# both standard practice for containerized Python services.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

WORKDIR /app

# Install system dependencies needed to compile any wheels without prebuilt
# binaries; the apt cache is removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching: the pip layer is rebuilt only
# when requirements.txt changes, not on every source edit.
COPY fastapi_app/requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY fastapi_app /app/

# Create and switch to a non-root user for defense in depth.
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

EXPOSE 8000

# Health check using urllib (built-in) so no curl/wget needs to be installed.
# Probes the app's /health endpoint (see app.py).
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

# Run the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
fastapi_app/app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Sentiment-analysis FastAPI service: module-level imports, logging, and state."""

import logging
from contextlib import asynccontextmanager

import torch

# Kept ahead of the transformers import to preserve the original import order
# (get_model calls load_dotenv() at import time, so .env is loaded first).
from get_model import download_model_from_s3

from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Populated by the lifespan handler once the checkpoint has been fetched.
model = None
tokenizer = None
# Prefer the GPU when available; all input tensors are moved here before inference.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
17
+
18
+
19
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model before serving, log at shutdown.

    Fills in the module-level ``model`` and ``tokenizer`` globals. Any failure
    while downloading or loading is logged and re-raised so the server refuses
    to start with a half-initialized model.
    """
    global model, tokenizer

    try:
        logger.info("Starting model download from S3...")
        weights_dir = download_model_from_s3(local_dir="./model")

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(weights_dir)

        logger.info("Loading model...")
        model = AutoModelForSequenceClassification.from_pretrained(weights_dir)
        # Move onto the inference device and disable dropout / training behavior.
        model.to(device)
        model.eval()

        logger.info(f"Model loaded successfully on {device}")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise

    # Code above the yield runs at startup; code below runs at shutdown.
    yield

    logger.info("Shutting down...")
44
+
45
+
46
# Application object; the lifespan handler above loads the model on startup.
app = FastAPI(lifespan=lifespan, title="Sentiment Analysis API")

# Jinja2 templates (index.html) live in ./templates next to this module.
templates = Jinja2Templates(directory="templates")
49
+
50
+
51
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the input form page (no prediction context yet)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
55
+
56
+
57
@app.post("/predict")
async def predict(request: Request, text: str = Form(...)):
    """Run sentiment inference on submitted form text and render the result.

    Always renders index.html: with ``sentiment``/``confidence`` on success,
    or with ``error`` for blank input or any inference failure — the handler
    never propagates an exception to the client.
    """
    # Guard clause: reject whitespace-only submissions before touching the model.
    if not text.strip():
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Please enter some text to analyze"},
        )

    try:
        # Tokenize (truncated to the 512-token context) and move every tensor
        # onto the inference device.
        encoded = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512, padding=True
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits

        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1).item()
        score = probs[0][label_idx].item()

        # Binary classification head: index 0 = Negative, index 1 = Positive.
        label = {0: "Negative", 1: "Positive"}.get(label_idx, "Unknown")

        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "text": text,
                "sentiment": label,
                "confidence": round(score * 100, 2),
            },
        )

    except Exception as e:
        logger.error(f"Prediction error: {e}")
        return templates.TemplateResponse(
            "index.html", {"request": request, "error": f"An error occurred: {str(e)}"}
        )
97
+
98
+
99
@app.get("/health")
async def health_check():
    """Liveness probe consumed by the Docker HEALTHCHECK; reports model state."""
    status = {"status": "healthy"}
    status["model_loaded"] = model is not None
    status["device"] = str(device)
    return status
107
+
108
+
109
if __name__ == "__main__":
    # Local development entry point; in the container uvicorn is launched
    # directly by the Dockerfile CMD instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
fastapi_app/get_model.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Helper for pulling the fine-tuned model artifacts down from S3."""

import logging
import os

import boto3
from dotenv import load_dotenv

# Surface BUCKET_NAME (and any AWS credentials) from a local .env file.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
11
+
12
+
13
# Artifact set produced by Hugging Face ``save_pretrained`` for this checkpoint.
_DEFAULT_MODEL_FILES = (
    "config.json",
    "model.safetensors",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "vocab.txt",
)


def download_model_from_s3(
    local_dir="./model",
    s3_prefix="ml-models/tinybert-sentiment-analysis",
    file_names=None,
):
    """Download the fine-tuned model files from S3 into ``local_dir``.

    Parameters
    ----------
    local_dir : str
        Destination directory; created if it does not exist.
    s3_prefix : str
        Key prefix under the bucket where the model files live. An empty
        prefix fetches the files from the bucket root.
    file_names : iterable of str, optional
        File names to fetch; defaults to the standard Hugging Face artifact
        set for this checkpoint (``_DEFAULT_MODEL_FILES``).

    Returns
    -------
    str
        ``local_dir``, convenient for chaining into ``from_pretrained``.

    Raises
    ------
    ValueError
        If the ``BUCKET_NAME`` environment variable is not set.
    Exception
        Any per-file download error is logged and re-raised.
    """
    bucket_name = os.getenv("BUCKET_NAME")
    if not bucket_name:
        raise ValueError("BUCKET_NAME not found in .env file")

    os.makedirs(local_dir, exist_ok=True)

    s3_client = boto3.client("s3")
    model_files = list(file_names) if file_names is not None else list(_DEFAULT_MODEL_FILES)

    logger.info(f"Downloading model from S3 bucket: {bucket_name}/{s3_prefix}")

    for file_name in model_files:
        try:
            local_file_path = os.path.join(local_dir, file_name)

            # NOTE(review): presence is the only freshness check — a partial
            # file from an interrupted run is also skipped; delete the local
            # directory to force a clean re-download.
            if os.path.exists(local_file_path):
                logger.info(f"File {file_name} already exists, skipping...")
                continue

            s3_key = f"{s3_prefix}/{file_name}" if s3_prefix else file_name

            logger.info(f"Downloading {s3_key}...")
            s3_client.download_file(bucket_name, s3_key, local_file_path)
            logger.info(f"Successfully downloaded {file_name}")
        except Exception as e:
            logger.error(f"Error downloading {file_name}: {e}")
            raise

    logger.info("Model download completed successfully")
    return local_dir
fastapi_app/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.6
2
+ jinja2==3.1.5
3
+ boto3==1.34.149
4
+ python-dotenv==1.0.0
5
+ transformers==4.43.3
6
+ torch==2.3.1
7
+ uvicorn==0.34.0
8
+ python-multipart==0.0.18
fastapi_app/templates/index.html ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!doctype html>
<html lang="en">
  <!-- Jinja2 template rendered by app.py for "/" and "/predict".
       Context variables, all optional: text (echoed user input),
       sentiment ("Positive" or "Negative"), confidence (percentage,
       rounded to 2 decimals), error (message shown in the red banner). -->
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Sentiment Analysis - TinyBERT</title>
    <!-- Tailwind via CDN: convenient for a demo; use a built CSS bundle in production. -->
    <script src="https://cdn.tailwindcss.com"></script>
  </head>
  <body class="bg-gradient-to-br from-slate-50 to-slate-100 min-h-screen">
    <div class="container mx-auto px-4 py-12 max-w-3xl">
      <!-- Header -->
      <div class="text-center mb-12">
        <h1 class="text-4xl font-bold text-slate-800 mb-3">
          Sentiment Analysis
        </h1>
        <p class="text-slate-600">
          Powered by Fine-tuned TinyBERT
        </p>
      </div>

      <!-- Main Card -->
      <div class="bg-white rounded-2xl shadow-xl p-8 mb-6">
        <!-- Plain HTML form POST to /predict; the page fully re-renders with results. -->
        <form method="post" action="/predict" class="space-y-6">
          <!-- Text Input (previous submission is echoed back inside the textarea) -->
          <div>
            <label for="text" class="block text-sm font-medium text-slate-700 mb-2">
              Enter text to analyze
            </label>
            <textarea
              id="text"
              name="text"
              rows="5"
              class="w-full px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition resize-none"
              placeholder="Type or paste your text here..."
              required
            >{% if text %}{{ text }}{% endif %}</textarea>
          </div>

          <!-- Submit Button -->
          <button
            type="submit"
            class="w-full bg-blue-600 hover:bg-blue-700 text-white font-medium py-3 px-6 rounded-lg transition duration-200 shadow-md hover:shadow-lg"
          >
            Analyze Sentiment
          </button>
        </form>

        <!-- Error Message (blank input or inference failure) -->
        {% if error %}
        <div class="mt-6 bg-red-50 border-l-4 border-red-500 p-4 rounded">
          <p class="text-red-700">{{ error }}</p>
        </div>
        {% endif %}

        <!-- Results (only rendered after a successful prediction) -->
        {% if sentiment %}
        <div class="mt-8 border-t pt-6">
          <h2 class="text-xl font-semibold text-slate-800 mb-4">Results</h2>

          <div class="grid grid-cols-2 gap-4">
            <!-- Sentiment label with emoji; green for Positive, red otherwise -->
            <div class="bg-slate-50 rounded-lg p-4">
              <p class="text-sm text-slate-600 mb-1">Sentiment</p>
              <div class="flex items-center">
                <span class="text-2xl mr-2">
                  {% if sentiment == "Positive" %}
                  😊
                  {% else %}
                  😔
                  {% endif %}
                </span>
                <p class="text-2xl font-bold {% if sentiment == 'Positive' %}text-green-600{% else %}text-red-600{% endif %}">
                  {{ sentiment }}
                </p>
              </div>
            </div>

            <!-- Confidence as a percentage -->
            <div class="bg-slate-50 rounded-lg p-4">
              <p class="text-sm text-slate-600 mb-1">Confidence</p>
              <p class="text-2xl font-bold text-blue-600">
                {{ confidence }}%
              </p>
            </div>
          </div>

          <!-- Confidence Bar: width driven directly by the confidence value -->
          <div class="mt-4">
            <div class="w-full bg-slate-200 rounded-full h-3 overflow-hidden">
              <div
                class="h-full {% if sentiment == 'Positive' %}bg-green-500{% else %}bg-red-500{% endif %} transition-all duration-500"
                style="width: {{ confidence }}%"
              ></div>
            </div>
          </div>
        </div>
        {% endif %}
      </div>

      <!-- Footer -->
      <div class="text-center text-slate-500 text-sm">
        <p>Fine-tuned TinyBERT model for sentiment classification</p>
      </div>
    </div>
  </body>
</html>