saifisvibinn committed on
Commit
6f2b9f4
·
1 Parent(s): ac19fe4

Replace user registration API with lung cancer prediction API

Files changed (8)
  1. Dockerfile +3 -0
  2. FASTAPI_README.md +0 -195
  3. README.md +49 -15
  4. best_lung_cancer_model.joblib +0 -0
  5. main.py +322 -175
  6. model_loader.py +141 -0
  7. requirements.txt +5 -1
  8. scaler.joblib +0 -0
Dockerfile CHANGED
@@ -12,6 +12,9 @@ RUN pip install --no-cache-dir -r requirements.txt
 
  # Copy application code
  COPY main.py .
+ COPY model_loader.py .
+ COPY best_lung_cancer_model.joblib .
+ COPY scaler.joblib .
  COPY start.sh .
 
  # Make startup script executable
FASTAPI_README.md DELETED
@@ -1,195 +0,0 @@
1
- # FastAPI Backend - User Registration API
2
-
3
- A clean, production-ready FastAPI backend with user registration functionality.
4
-
5
- ## Features
6
-
7
- - ✅ RESTful API endpoints
8
- - ✅ Automatic Swagger/OpenAPI documentation
9
- - ✅ Pydantic models for request validation
10
- - ✅ Age validation (18+)
11
- - ✅ Clean, readable, production-ready code
12
- - ✅ Comprehensive error handling
13
- - ✅ Type hints throughout
14
-
15
- ## Installation
16
-
17
- 1. **Install dependencies:**
18
- ```bash
19
- pip install -r requirements.txt
20
- ```
21
-
22
- ## Running the API
23
-
24
- ### Development Mode (with auto-reload)
25
- ```bash
26
- uvicorn main:app --reload
27
- ```
28
-
29
- ### Production Mode
30
- ```bash
31
- uvicorn main:app --host 0.0.0.0 --port 8000
32
- ```
33
-
34
- The API will be available at:
35
- - **API**: http://localhost:8000
36
- - **Swagger UI**: http://localhost:8000/docs
37
- - **ReDoc**: http://localhost:8000/redoc
38
-
39
- ## API Endpoints
40
-
41
- ### 1. GET /status
42
- Check API status.
43
-
44
- **Response:**
45
- ```json
46
- {
47
- "status": "API is running"
48
- }
49
- ```
50
-
51
- **Example:**
52
- ```bash
53
- curl http://localhost:8000/status
54
- ```
55
-
56
- ### 2. POST /register
57
- Register a new user.
58
-
59
- **Request Body:**
60
- ```json
61
- {
62
- "name": "John Doe",
63
- "email": "john.doe@example.com",
64
- "age": 25
65
- }
66
- ```
67
-
68
- **Success Response (201 Created):**
69
- ```json
70
- {
71
- "success": true,
72
- "message": "User registered successfully",
73
- "user": {
74
- "name": "John Doe",
75
- "email": "john.doe@example.com",
76
- "age": 25
77
- }
78
- }
79
- ```
80
-
81
- **Error Response (400 Bad Request):**
82
- ```json
83
- {
84
- "success": false,
85
- "error": "User must be at least 18",
86
- "status_code": 400
87
- }
88
- ```
89
-
90
- **Example:**
91
- ```bash
92
- curl -X POST http://localhost:8000/register \
93
- -H "Content-Type: application/json" \
94
- -d '{
95
- "name": "John Doe",
96
- "email": "john.doe@example.com",
97
- "age": 25
98
- }'
99
- ```
100
-
101
- ## Validation Rules
102
-
103
- - **name**: Required, 1-100 characters, cannot be empty
104
- - **email**: Required, must be valid email format
105
- - **age**: Required, must be 18 or older
106
-
107
- ## API Documentation
108
-
109
- FastAPI automatically generates interactive API documentation:
110
-
111
- - **Swagger UI**: http://localhost:8000/docs
112
- - **ReDoc**: http://localhost:8000/redoc
113
-
114
- ## Testing
115
-
116
- ### Test with cURL
117
-
118
- **Status endpoint:**
119
- ```bash
120
- curl http://localhost:8000/status
121
- ```
122
-
123
- **Register endpoint (valid):**
124
- ```bash
125
- curl -X POST http://localhost:8000/register \
126
- -H "Content-Type: application/json" \
127
- -d '{"name": "Jane Smith", "email": "jane@example.com", "age": 25}'
128
- ```
129
-
130
- **Register endpoint (invalid age):**
131
- ```bash
132
- curl -X POST http://localhost:8000/register \
133
- -H "Content-Type: application/json" \
134
- -d '{"name": "Young User", "email": "young@example.com", "age": 16}'
135
- ```
136
-
137
- ### Test with Python
138
-
139
- ```python
140
- import requests
141
-
142
- # Test status endpoint
143
- response = requests.get("http://localhost:8000/status")
144
- print(response.json())
145
-
146
- # Test register endpoint
147
- response = requests.post(
148
- "http://localhost:8000/register",
149
- json={
150
- "name": "John Doe",
151
- "email": "john@example.com",
152
- "age": 25
153
- }
154
- )
155
- print(response.json())
156
- ```
157
-
158
- ## Project Structure
159
-
160
- ```
161
- .
162
- ├── main.py # FastAPI application
163
- ├── requirements.txt # Python dependencies
164
- └── FASTAPI_README.md # This file
165
- ```
166
-
167
- ## Code Quality
168
-
169
- - ✅ Type hints throughout
170
- - ✅ Comprehensive docstrings
171
- - ✅ Pydantic models for validation
172
- - ✅ Proper HTTP status codes
173
- - ✅ Error handling
174
- - ✅ Clean, readable code structure
175
- - ✅ Production-ready patterns
176
-
177
- ## Next Steps
178
-
179
- To make this production-ready, consider adding:
180
-
181
- 1. **Database Integration**: Store users in a database (PostgreSQL, MongoDB, etc.)
182
- 2. **Authentication**: Add JWT or OAuth2 authentication
183
- 3. **Password Hashing**: If adding passwords, use bcrypt or similar
184
- 4. **Email Verification**: Send confirmation emails
185
- 5. **Rate Limiting**: Prevent abuse
186
- 6. **Logging**: Add structured logging
187
- 7. **Testing**: Add unit and integration tests
188
- 8. **Docker**: Containerize the application
189
- 9. **Environment Variables**: Use .env for configuration
190
- 10. **CORS**: Configure CORS if needed for frontend integration
191
-
192
- ## License
193
-
194
- This project is provided as-is for educational purposes.
195
-
README.md CHANGED
@@ -1,6 +1,14 @@
1
- # FastAPI User Registration Backend
2
 
3
- A clean, production-ready FastAPI backend with user registration functionality.
 
 
 
 
 
 
 
 
4
 
5
  ## Quick Start
6
 
@@ -24,21 +32,47 @@ See `QUICK_DEPLOY.md` for quick deployment instructions (Railway recommended).
24
 
25
  For detailed deployment options, see `DEPLOYMENT.md`.
26
 
27
- ## Documentation
28
-
29
- - **Quick Start Guide:** `FASTAPI_README.md`
30
- - **Deployment Guide:** `DEPLOYMENT.md`
31
- - **Quick Deploy:** `QUICK_DEPLOY.md`
32
-
33
  ## API Endpoints
34
 
 
35
  - `GET /status` - Check API status
36
- - `POST /register` - Register a new user (requires name, email, age 18+)
37
 
38
- ## Features
39
 
40
- - ✅ RESTful API endpoints
41
- - ✅ Automatic Swagger/OpenAPI documentation
42
- - ✅ Pydantic models for request validation
43
- - ✅ Age validation (18+)
44
- ✅ Clean, production-ready code
1
+ # Lung Cancer Prediction API
2
 
3
+ A FastAPI-based REST API for predicting lung cancer risk based on patient symptoms and characteristics.
4
+
5
+ ## Features
6
+
7
+ - ✅ RESTful API endpoints
8
+ - ✅ Automatic Swagger/OpenAPI documentation
9
+ - ✅ Pydantic models for request validation
10
+ - ✅ CORS support for web applications
11
+ - ✅ Production-ready with error handling
12
 
13
  ## Quick Start
14
 
 
32
 
33
  For detailed deployment options, see `DEPLOYMENT.md`.
34
 
 
 
 
 
 
 
35
  ## API Endpoints
36
 
37
+ - `GET /` - API information
38
  - `GET /status` - Check API status
39
+ - `POST /predict` - Predict lung cancer risk
40
 
41
+ ## Request Format
42
 
43
+ ```json
44
+ {
45
+ "gender": "M",
46
+ "age": 65,
47
+ "smoking": "YES",
48
+ "yellow_fingers": "NO",
49
+ "anxiety": "NO",
50
+ "peer_pressure": "NO",
51
+ "chronic_disease": "YES",
52
+ "fatigue": "YES",
53
+ "allergy": "NO",
54
+ "wheezing": "YES",
55
+ "alcohol": "NO",
56
+ "coughing": "YES",
57
+ "shortness_of_breath": "YES",
58
+ "swallowing_difficulty": "NO",
59
+ "chest_pain": "YES"
60
+ }
61
+ ```
62
+
63
+ ## Response Format
64
+
65
+ ```json
66
+ {
67
+ "success": true,
68
+ "prediction": "YES",
69
+ "probability": 87.5,
70
+ "message": "Prediction: YES (Confidence: 87.50%)"
71
+ }
72
+ ```
73
+
74
+ ## Notes
75
+
76
+ - This application is for educational/research purposes only
77
+ - Medical predictions should always be verified by healthcare professionals
78
+ - The model accuracy depends on the quality of the training data
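The request and response formats documented in the new README could be exercised with a small client. The following is a minimal sketch using only the Python standard library. The base URL `http://localhost:8000` and the helper names `build_predict_request` and `send` are assumptions for illustration and are not part of this commit.

```python
import json
import urllib.request

# Example payload copied from the README's "Request Format" section.
EXAMPLE_PATIENT = {
    "gender": "M",
    "age": 65,
    "smoking": "YES",
    "yellow_fingers": "NO",
    "anxiety": "NO",
    "peer_pressure": "NO",
    "chronic_disease": "YES",
    "fatigue": "YES",
    "allergy": "NO",
    "wheezing": "YES",
    "alcohol": "NO",
    "coughing": "YES",
    "shortness_of_breath": "YES",
    "swallowing_difficulty": "NO",
    "chest_pain": "YES",
}


def build_predict_request(base_url: str, patient: dict) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for the /predict endpoint."""
    body = json.dumps(patient).encode("utf-8")
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request) -> dict:
    """Send the request and decode the JSON response body."""
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage, assuming the API is running locally (hypothetical URL):
#   result = send(build_predict_request("http://localhost:8000", EXAMPLE_PATIENT))
#   print(result["prediction"], result["probability"])
```

Separating request construction from sending keeps the payload shape checkable without a running server.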
best_lung_cancer_model.joblib ADDED
Binary file (59.7 kB).
 
main.py CHANGED
@@ -1,114 +1,226 @@
1
  """
2
- FastAPI Backend Application
3
- A simple REST API with user registration functionality.
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException, status, Request
7
  from fastapi.responses import JSONResponse
8
  from fastapi.exceptions import RequestValidationError
9
- from pydantic import BaseModel, EmailStr, Field, field_validator
10
- import uvicorn
 
 
 
11
  import os
 
 
 
12
 
13
  # Initialize FastAPI application
14
- # This automatically enables Swagger UI at /docs and ReDoc at /redoc
15
  app = FastAPI(
16
- title="User Registration API",
17
- description="A simple API for user registration with validation",
18
  version="1.0.0",
19
- docs_url="/docs", # Swagger UI documentation
20
- redoc_url="/redoc" # Alternative API documentation
21
  )
22
 
23
 
24
  # ============================================================================
25
  # Pydantic Models for Request/Response Validation
26
  # ============================================================================
27
 
28
- class RegisterRequest(BaseModel):
29
  """
30
- Request model for user registration.
31
- Validates name, email, and age fields.
32
  """
33
- name: str = Field(
34
- ...,
35
- min_length=1,
36
- max_length=100,
37
- description="User's full name",
38
- examples=["John Doe"]
39
- )
40
- email: EmailStr = Field(
41
- ...,
42
- description="User's email address",
43
- examples=["john.doe@example.com"]
44
- )
45
- age: int = Field(
46
- ...,
47
- description="User's age (must be 18 or older)",
48
- examples=[25]
49
- )
50
 
51
- @field_validator('name')
52
  @classmethod
53
- def validate_name(cls, v: str) -> str:
54
- """Validate that name is not empty after stripping whitespace."""
55
- if not v.strip():
56
- raise ValueError("Name cannot be empty")
57
- return v.strip()
 
58
 
59
- @field_validator('age')
 
 
 
60
  @classmethod
61
- def validate_age(cls, v: int) -> int:
62
- """Validate that age is 18 or older."""
63
- if v < 18:
64
- raise ValueError("User must be at least 18")
 
65
  return v
66
-
67
- class Config:
68
- """Pydantic configuration."""
69
- json_schema_extra = {
70
- "example": {
71
- "name": "John Doe",
72
- "email": "john.doe@example.com",
73
- "age": 25
74
- }
75
- }
76
 
77
 
78
- class RegisterResponse(BaseModel):
79
  """
80
- Response model for successful registration.
81
  """
82
- success: bool = Field(
83
- ...,
84
- description="Indicates if registration was successful",
85
- examples=[True]
86
- )
87
- message: str = Field(
88
- ...,
89
- description="Confirmation message",
90
- examples=["User registered successfully"]
91
- )
92
- user: dict = Field(
93
- ...,
94
- description="Registered user information",
95
- examples=[{
96
- "name": "John Doe",
97
- "email": "john.doe@example.com",
98
- "age": 25
99
- }]
100
- )
101
 
102
 
103
  class StatusResponse(BaseModel):
104
  """
105
  Response model for status endpoint.
106
  """
107
- status: str = Field(
108
- ...,
109
- description="API status message",
110
- examples=["API is running"]
111
- )
112
 
113
 
114
  # ============================================================================
@@ -118,24 +230,19 @@ class StatusResponse(BaseModel):
118
  @app.get(
119
  "/",
120
  summary="API Root",
121
- description="Root endpoint with API information and available endpoints",
122
  tags=["Info"]
123
  )
124
  async def root():
125
- """
126
- Root endpoint that provides API information.
127
-
128
- Returns:
129
- dict: API information and available endpoints
130
- """
131
  return {
132
- "message": "Welcome to the User Registration API",
133
  "version": "1.0.0",
134
  "docs": "/docs",
135
  "redoc": "/redoc",
136
  "endpoints": {
137
  "GET /status": "Check API status",
138
- "POST /register": "Register a new user (requires name, email, age 18+)"
139
  }
140
  }
141
 
@@ -144,7 +251,7 @@ async def root():
144
  "/status",
145
  response_model=StatusResponse,
146
  summary="Check API Status",
147
- description="Returns the current status of the API",
148
  tags=["Health"]
149
  )
150
  async def get_status():
@@ -152,90 +259,154 @@ async def get_status():
152
  Health check endpoint.
153
 
154
  Returns:
155
- JSONResponse: Status message indicating the API is running
156
-
157
- Example Response:
158
- {
159
- "status": "API is running"
160
- }
161
  """
162
- return StatusResponse(status="API is running")
 
 
 
 
 
 
163
 
164
 
165
  @app.post(
166
- "/register",
167
- response_model=RegisterResponse,
168
- status_code=status.HTTP_201_CREATED,
169
- summary="Register a New User",
170
- description="Register a new user with name, email, and age. Age must be 18 or older.",
171
- tags=["Users"]
172
  )
173
- async def register_user(user_data: RegisterRequest):
174
  """
175
- Register a new user endpoint.
176
-
177
- This endpoint accepts user registration data and validates:
178
- - Name: Must be non-empty string (1-100 characters)
179
- - Email: Must be a valid email format
180
- - Age: Must be 18 or older
181
 
182
  Args:
183
- user_data (RegisterRequest): User registration data
184
 
185
  Returns:
186
- RegisterResponse: Success confirmation with user data
187
 
188
  Raises:
189
- HTTPException: 400 Bad Request if validation fails
190
- HTTPException: 422 Unprocessable Entity if request format is invalid
191
-
192
- Example Request:
193
- {
194
- "name": "John Doe",
195
- "email": "john.doe@example.com",
196
- "age": 25
197
- }
198
-
199
- Example Response:
200
- {
201
- "success": true,
202
- "message": "User registered successfully",
203
- "user": {
204
- "name": "John Doe",
205
- "email": "john.doe@example.com",
206
- "age": 25
207
- }
208
- }
209
  """
210
- # Age validation is handled by Pydantic field_validator
211
- # In a real application, you would:
212
- # 1. Check if email already exists in database
213
- # 2. Hash password if included
214
- # 3. Save user to database
215
- # 4. Send confirmation email
216
- # For now, we'll just return a success response
217
 
218
- return RegisterResponse(
219
- success=True,
220
- message="User registered successfully",
221
- user={
222
- "name": user_data.name,
223
- "email": user_data.email,
224
- "age": user_data.age
225
- }
226
- )
227
 
228
 
229
  # ============================================================================
230
- # Custom Exception Handlers
231
  # ============================================================================
232
 
233
  @app.exception_handler(HTTPException)
234
- async def http_exception_handler(request, exc: HTTPException):
235
- """
236
- Custom handler for HTTP exceptions.
237
- Returns consistent error response format.
238
- """
239
  return JSONResponse(
240
  status_code=exc.status_code,
241
  content={
@@ -248,29 +419,8 @@ async def http_exception_handler(request, exc: HTTPException):
248
 
249
  @app.exception_handler(RequestValidationError)
250
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
251
- """
252
- Custom handler for Pydantic validation errors.
253
- Converts validation errors to 400 Bad Request with custom message for age.
254
- """
255
  errors = exc.errors()
256
-
257
- # Check if the error is related to age validation
258
- for error in errors:
259
- error_loc = error.get("loc", [])
260
- error_msg = str(error.get("msg", ""))
261
-
262
- # Check if this is an age validation error
263
- if "age" in error_loc and ("User must be at least 18" in error_msg or "18" in error_msg):
264
- return JSONResponse(
265
- status_code=status.HTTP_400_BAD_REQUEST,
266
- content={
267
- "success": False,
268
- "error": "User must be at least 18",
269
- "status_code": 400
270
- }
271
- )
272
-
273
- # For other validation errors, return standard format
274
  return JSONResponse(
275
  status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
276
  content={
@@ -287,12 +437,10 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
287
  # ============================================================================
288
 
289
  if __name__ == "__main__":
290
- # Run the application using uvicorn
291
  # Get port from environment variable (for deployment) or default to 8000
292
  port = int(os.environ.get("PORT", 8000))
293
 
294
  # --reload enables auto-reload on code changes (development only)
295
- # In production, reload should be False
296
  reload = os.environ.get("ENVIRONMENT", "development") == "development"
297
 
298
  uvicorn.run(
@@ -301,4 +449,3 @@ if __name__ == "__main__":
301
  port=port,
302
  reload=reload
303
  )
304
-
 
1
  """
2
+ FastAPI Lung Cancer Prediction API
3
+ A RESTful API for predicting lung cancer risk based on patient symptoms and characteristics.
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException, status, Request
7
  from fastapi.responses import JSONResponse
8
  from fastapi.exceptions import RequestValidationError
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel, Field, field_validator
11
+ import numpy as np
12
+ import sys
13
+ import warnings
14
  import os
15
+ import uvicorn
16
+
17
+ warnings.filterwarnings('ignore')
18
 
19
  # Initialize FastAPI application
 
20
  app = FastAPI(
21
+ title="Lung Cancer Prediction API",
22
+ description="A RESTful API for predicting lung cancer risk based on patient symptoms",
23
  version="1.0.0",
24
+ docs_url="/docs",
25
+ redoc_url="/redoc"
26
  )
27
 
28
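+ # Enable CORS for all origins (development convenience; list explicit origins in production, since wildcard origins should not be combined with credentials)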
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"],
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ # ============================================================================
38
+ # Model Loading with Compatibility Handling
39
+ # ============================================================================
40
+
41
+ model = None
42
+ scaler = None
43
+
44
+ # Try to load using the robust loader
45
+ try:
46
+ import sklearn
47
+ print(f"scikit-learn version: {sklearn.__version__}")
48
+
49
+ # First, patch the missing EuclideanDistance name, preferring EuclideanDistance64 over the 32-bit variant
50
+ try:
51
+ import sklearn.metrics._dist_metrics as dist_metrics
52
+
53
+ # Patch EuclideanDistance if missing - prioritize 64-bit version
54
+ if not hasattr(dist_metrics, 'EuclideanDistance'):
55
+ print("Attempting to patch EuclideanDistance...")
56
+
57
+ # Try option 1: Use EuclideanDistance64 (model uses 64-bit)
58
+ if hasattr(dist_metrics, 'EuclideanDistance64'):
59
+ EuclideanDistance64 = dist_metrics.EuclideanDistance64
60
+ dist_metrics.EuclideanDistance = EuclideanDistance64
61
+ setattr(dist_metrics, 'EuclideanDistance', EuclideanDistance64)
62
+
63
+ # Update in sys.modules - CRITICAL for unpickling
64
+ mod_name = 'sklearn.metrics._dist_metrics'
65
+ if mod_name in sys.modules:
66
+ setattr(sys.modules[mod_name], 'EuclideanDistance', EuclideanDistance64)
67
+
68
+ if hasattr(dist_metrics, '__dict__'):
69
+ dist_metrics.__dict__['EuclideanDistance'] = EuclideanDistance64
70
+
71
+ print("[OK] Patched EuclideanDistance using EuclideanDistance64")
72
+
73
+ # Fallback: Use EuclideanDistance32
74
+ elif hasattr(dist_metrics, 'EuclideanDistance32'):
75
+ EuclideanDistance32 = dist_metrics.EuclideanDistance32
76
+ dist_metrics.EuclideanDistance = EuclideanDistance32
77
+ setattr(dist_metrics, 'EuclideanDistance', EuclideanDistance32)
78
+
79
+ mod_name = 'sklearn.metrics._dist_metrics'
80
+ if mod_name in sys.modules:
81
+ setattr(sys.modules[mod_name], 'EuclideanDistance', EuclideanDistance32)
82
+
83
+ if hasattr(dist_metrics, '__dict__'):
84
+ dist_metrics.__dict__['EuclideanDistance'] = EuclideanDistance32
85
+
86
+ print("[OK] Patched EuclideanDistance using EuclideanDistance32")
87
+
88
+ # Ensure patch is in sys.modules
89
+ if 'sklearn.metrics._dist_metrics' in sys.modules and hasattr(dist_metrics, 'EuclideanDistance'):
90
+ if not hasattr(sys.modules['sklearn.metrics._dist_metrics'], 'EuclideanDistance'):
91
+ setattr(sys.modules['sklearn.metrics._dist_metrics'], 'EuclideanDistance', dist_metrics.EuclideanDistance)
92
+
93
+ except Exception as patch_error:
94
+ print(f"Warning: Could not apply pre-patch: {patch_error}")
95
+ import traceback
96
+ traceback.print_exc()
97
+
98
+ # Now try to load the model
99
+ try:
100
+ print("Loading model...")
101
+ import joblib
102
+
103
+ # Try standard loading first
104
+ try:
105
+ model = joblib.load('best_lung_cancer_model.joblib')
106
+ scaler = joblib.load('scaler.joblib')
107
+ print("[OK] Model and scaler loaded successfully!")
108
+ except (AttributeError, ModuleNotFoundError, KeyError) as e:
109
+ if 'EuclideanDistance' in str(e) or 'EuclideanDistance' in repr(e):
110
+ print("Compatibility issue detected. Trying alternative loading method...")
111
+
112
+ # Try using the model_loader
113
+ try:
114
+ from model_loader import load_sklearn_model_safe
115
+ model, scaler = load_sklearn_model_safe('best_lung_cancer_model.joblib', 'scaler.joblib')
116
+ print("[OK] Model and scaler loaded successfully using compatibility loader!")
117
+ except Exception as e2:
118
+ print(f"Compatibility loader also failed: {e2}")
119
+ raise e # Raise original error
120
+ else:
121
+ raise
122
+
123
+ # Print model info if available
124
+ if hasattr(model, 'feature_names_in_'):
125
+ print(f"Model expects {len(model.feature_names_in_)} features")
126
+ print(f"Features: {list(model.feature_names_in_)}")
127
+ if hasattr(model, 'classes_'):
128
+ print(f"Model classes: {model.classes_}")
129
+ if scaler and hasattr(scaler, 'n_features_in_'):
130
+ print(f"Scaler expects {scaler.n_features_in_} features")
131
+
132
+ except Exception as e:
133
+ error_msg = str(e)
134
+ print("\n" + "="*70)
135
+ print("MODEL LOADING ERROR")
136
+ print("="*70)
137
+ print(f"\nError: {error_msg}")
138
+ print("\nTroubleshooting steps:")
139
+ print("\n1. Try installing a compatible scikit-learn version:")
140
+ print(" pip uninstall scikit-learn")
141
+ print(" pip install scikit-learn==1.2.2")
142
+ print("\n2. If that doesn't work, try using Python 3.10 or 3.11")
143
+ print(" (Python 3.12 may have compatibility issues)")
144
+ print("\n3. Alternative: Install scikit-learn with pre-built wheels:")
145
+ print(" pip install --only-binary :all: scikit-learn==1.2.2")
146
+ print("\n4. Check that both model files exist:")
147
+ print(" - best_lung_cancer_model.joblib")
148
+ print(" - scaler.joblib")
149
+ print("="*70 + "\n")
150
+ import traceback
151
+ traceback.print_exc()
152
+ model = None
153
+ scaler = None
154
+
155
+ except Exception as e:
156
+ print(f"Critical error during initialization: {e}")
157
+ import traceback
158
+ traceback.print_exc()
159
+ model = None
160
+ scaler = None
161
+
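The patching logic in the block above follows one pattern: if a name is missing from a module, alias it to the first available fallback. A minimal sketch of that pattern is shown here against a stand-in module rather than scikit-learn itself. The helper name `alias_missing_attr` and the `fake_dist_metrics` module are hypothetical and exist only for illustration.

```python
import types


def alias_missing_attr(module, name, candidates):
    """Expose the first available candidate attribute under `name`.

    Returns the candidate name that was used, or None if `name` already
    exists on the module or no candidate is present.
    """
    if hasattr(module, name):
        return None  # nothing to patch
    for candidate in candidates:
        if hasattr(module, candidate):
            setattr(module, name, getattr(module, candidate))
            return candidate
    return None


# Stand-in for sklearn.metrics._dist_metrics (hypothetical, for illustration):
# it only offers the 64-bit and 32-bit class names.
fake_dist_metrics = types.ModuleType("fake_dist_metrics")
fake_dist_metrics.EuclideanDistance64 = type("EuclideanDistance64", (), {})
fake_dist_metrics.EuclideanDistance32 = type("EuclideanDistance32", (), {})

used = alias_missing_attr(
    fake_dist_metrics,
    "EuclideanDistance",
    ["EuclideanDistance64", "EuclideanDistance32"],  # prefer the 64-bit class
)
print(used)
```

An imported module is normally the same object as its `sys.modules` entry, so a single `setattr` is usually enough. The repeated assignments in the commit appear to be defensive rather than required.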
162
 
163
  # ============================================================================
164
  # Pydantic Models for Request/Response Validation
165
  # ============================================================================
166
 
167
+ class PredictionRequest(BaseModel):
168
  """
169
+ Request model for lung cancer prediction.
 
170
  """
171
+ gender: str = Field(..., description="Patient gender", examples=["M"])
172
+ age: float = Field(..., ge=1, le=150, description="Patient age", examples=[65])
173
+ smoking: str = Field(..., description="Smoking status", examples=["YES"])
174
+ yellow_fingers: str = Field(..., description="Yellow fingers symptom", examples=["NO"])
175
+ anxiety: str = Field(..., description="Anxiety symptom", examples=["NO"])
176
+ peer_pressure: str = Field(..., description="Peer pressure", examples=["NO"])
177
+ chronic_disease: str = Field(..., description="Chronic disease", examples=["YES"])
178
+ fatigue: str = Field(..., description="Fatigue symptom", examples=["YES"])
179
+ allergy: str = Field(..., description="Allergy", examples=["NO"])
180
+ wheezing: str = Field(..., description="Wheezing symptom", examples=["YES"])
181
+ alcohol: str = Field(..., description="Alcohol consumption", examples=["NO"])
182
+ coughing: str = Field(..., description="Coughing symptom", examples=["YES"])
183
+ shortness_of_breath: str = Field(..., description="Shortness of breath", examples=["YES"])
184
+ swallowing_difficulty: str = Field(..., description="Swallowing difficulty", examples=["NO"])
185
+ chest_pain: str = Field(..., description="Chest pain symptom", examples=["YES"])
 
 
186
 
187
+ @field_validator('gender')
188
  @classmethod
189
+ def validate_gender(cls, v: str) -> str:
190
+ """Validate gender is M or F."""
191
+ v = v.upper()
192
+ if v not in ['M', 'F']:
193
+ raise ValueError('gender must be "M" or "F"')
194
+ return v
195
 
196
+ @field_validator('smoking', 'yellow_fingers', 'anxiety', 'peer_pressure',
197
+ 'chronic_disease', 'fatigue', 'allergy', 'wheezing',
198
+ 'alcohol', 'coughing', 'shortness_of_breath',
199
+ 'swallowing_difficulty', 'chest_pain')
200
  @classmethod
201
+ def validate_yes_no(cls, v: str) -> str:
202
+ """Validate YES/NO fields."""
203
+ v = v.upper()
204
+ if v not in ['YES', 'NO']:
205
+ raise ValueError('must be "YES" or "NO"')
206
  return v
 
 
 
 
 
 
 
 
 
 
207
 
208
 
209
+ class PredictionResponse(BaseModel):
210
  """
211
+ Response model for prediction.
212
  """
213
+ success: bool = Field(..., description="Indicates if prediction was successful")
214
+ prediction: str = Field(..., description="Prediction result: YES or NO")
215
+ probability: float = Field(..., description="Confidence percentage")
216
+ message: str = Field(..., description="Human-readable message")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
 
219
  class StatusResponse(BaseModel):
220
  """
221
  Response model for status endpoint.
222
  """
223
+ status: str = Field(..., description="API status message")
 
 
 
 
224
 
225
 
226
  # ============================================================================
 
230
  @app.get(
231
  "/",
232
  summary="API Root",
233
+ description="Root endpoint with API information",
234
  tags=["Info"]
235
  )
236
  async def root():
237
+ """Root endpoint that provides API information."""
 
 
 
 
 
238
  return {
239
+ "message": "Welcome to the Lung Cancer Prediction API",
240
  "version": "1.0.0",
241
  "docs": "/docs",
242
  "redoc": "/redoc",
243
  "endpoints": {
244
  "GET /status": "Check API status",
245
+ "POST /predict": "Predict lung cancer risk"
246
  }
247
  }
248
 
 
251
  "/status",
252
  response_model=StatusResponse,
253
  summary="Check API Status",
254
+ description="Returns the current status of the API and model loading status",
255
  tags=["Health"]
256
  )
257
  async def get_status():
 
259
  Health check endpoint.
260
 
261
  Returns:
262
+ StatusResponse: Status message indicating if API and model are ready
 
 
 
 
 
263
  """
264
+ if model is None or scaler is None:
265
+ raise HTTPException(
266
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
267
+ detail="Model or scaler not loaded"
268
+ )
269
+
270
+ return StatusResponse(status="API is running and model is loaded")
271
 
272
 
273
  @app.post(
274
+ "/predict",
275
+ response_model=PredictionResponse,
276
+ summary="Predict Lung Cancer Risk",
277
+ description="Predict lung cancer risk based on patient symptoms and characteristics",
278
+ tags=["Prediction"]
 
279
  )
280
+ async def predict(data: PredictionRequest):
281
  """
282
+ Predict lung cancer risk based on patient data.
 
 
 
 
 
283
 
284
  Args:
285
+ data: PredictionRequest containing patient information
286
 
287
  Returns:
288
+ PredictionResponse: Prediction result with confidence score
289
 
290
  Raises:
291
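+ HTTPException: 500 if the model or scaler is not loaded, 400 if feature processing or prediction fails
+ (request validation errors are returned as 422 by the validation handler)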
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  """
293
+ if model is None or scaler is None:
294
+ raise HTTPException(
295
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
296
+ detail="Model or scaler not loaded. Please check server logs for details."
297
+ )
 
 
298
 
299
+ try:
300
+ # Convert YES/NO to numeric (YES=2, NO=1)
301
+ smoking = 2 if data.smoking == 'YES' else 1
302
+ yellow_fingers = 2 if data.yellow_fingers == 'YES' else 1
303
+ anxiety = 2 if data.anxiety == 'YES' else 1
304
+ peer_pressure = 2 if data.peer_pressure == 'YES' else 1
305
+ chronic_disease = 2 if data.chronic_disease == 'YES' else 1
306
+ fatigue = 2 if data.fatigue == 'YES' else 1
307
+ allergy = 2 if data.allergy == 'YES' else 1
308
+ wheezing = 2 if data.wheezing == 'YES' else 1
309
+ alcohol = 2 if data.alcohol == 'YES' else 1
310
+ coughing = 2 if data.coughing == 'YES' else 1
311
+ shortness_of_breath = 2 if data.shortness_of_breath == 'YES' else 1
312
+ swallowing_difficulty = 2 if data.swallowing_difficulty == 'YES' else 1
313
+ chest_pain = 2 if data.chest_pain == 'YES' else 1
314
+
315
+ # Try different gender encodings
316
+ # Pattern 1: M=1, F=0 (binary)
317
+ gender_encoded = 1 if data.gender == 'M' else 0
318
+
319
+ # Create feature array
320
+ features_v1 = np.array([[
321
+ gender_encoded, # Gender: M=1, F=0
322
+ data.age,
323
+ smoking,
324
+ yellow_fingers,
325
+ anxiety,
326
+ peer_pressure,
327
+ chronic_disease,
328
+ fatigue,
329
+ allergy,
330
+ wheezing,
331
+ alcohol,
332
+ coughing,
333
+ shortness_of_breath,
334
+ swallowing_difficulty,
335
+ chest_pain
336
+ ]], dtype=np.float64)
337
+
338
+ # Try alternative: gender as M=2, F=1
339
+ gender_encoded_v2 = 2 if data.gender == 'M' else 1
340
+ features_v2 = np.array([[
341
+ gender_encoded_v2, # Gender: M=2, F=1
342
+ data.age,
343
+ smoking,
344
+ yellow_fingers,
345
+ anxiety,
346
+ peer_pressure,
347
+ chronic_disease,
348
+ fatigue,
349
+ allergy,
350
+ wheezing,
351
+ alcohol,
352
+ coughing,
353
+ shortness_of_breath,
354
+ swallowing_difficulty,
355
+ chest_pain
356
+ ]], dtype=np.float64)
357
+
358
+ # Try to make prediction with first encoding
359
+ try:
360
+ features_scaled = scaler.transform(features_v1)
361
+ prediction = model.predict(features_scaled)[0]
362
+ prediction_proba = model.predict_proba(features_scaled)[0]
363
+ except Exception:
364
+ # If the first encoding raises an error, retry with the second encoding
365
+ try:
366
+ features_scaled = scaler.transform(features_v2)
367
+ prediction = model.predict(features_scaled)[0]
368
+ prediction_proba = model.predict_proba(features_scaled)[0]
369
+ except Exception as e:
370
+ raise HTTPException(
371
+ status_code=status.HTTP_400_BAD_REQUEST,
372
+ detail=f"Error processing features: {str(e)}"
373
+ )
+
+        # Get probability and result
+        # Model classes are [0, 1] where 0=NO, 1=YES
+        if prediction == 1:
+            result = "YES"
+            probability = prediction_proba[1] * 100 if len(prediction_proba) > 1 else (1 - prediction_proba[0]) * 100
+        else:
+            result = "NO"
+            probability = prediction_proba[0] * 100
+
+        return PredictionResponse(
+            success=True,
+            prediction=result,
+            probability=round(probability, 2),
+            message=f'Prediction: {result} (Confidence: {probability:.2f}%)'
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"Prediction error: {error_details}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f'Prediction failed: {str(e)}'
+        )
 
 
 # ============================================================================
+# Exception Handlers
 # ============================================================================
 
 @app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    """Custom handler for HTTP exceptions."""
     return JSONResponse(
         status_code=exc.status_code,
         content={

 
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Custom handler for validation errors."""
     errors = exc.errors()
     return JSONResponse(
         status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
         content={

 # ============================================================================
 
 if __name__ == "__main__":
     # Get port from environment variable (for deployment) or default to 8000
     port = int(os.environ.get("PORT", 8000))
 
     # --reload enables auto-reload on code changes (development only)
     reload = os.environ.get("ENVIRONMENT", "development") == "development"
 
     uvicorn.run(

         port=port,
         reload=reload
     )
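
Review note: the prediction path above builds two almost identical feature arrays and reads `predict_proba` by fixed index. A minimal sketch of an alternative follows, assuming the gender encoding used in training is known up front and that the class labels are available (scikit-learn classifiers expose them as `model.classes_`). The helper names `build_feature_row` and `probability_for_label`, and the placeholder symptom values, are hypothetical and not part of this commit.

```python
from typing import Mapping, Sequence

# Symptom column order copied from the feature arrays in the diff above.
FEATURE_ORDER = [
    "smoking", "yellow_fingers", "anxiety", "peer_pressure",
    "chronic_disease", "fatigue", "allergy", "wheezing", "alcohol",
    "coughing", "shortness_of_breath", "swallowing_difficulty", "chest_pain",
]


def build_feature_row(gender: str, age: float, symptoms: Mapping[str, float],
                      male_code: int = 1, female_code: int = 0) -> list:
    """Build one row: gender code, age, then the 13 symptom columns in order."""
    gender_code = male_code if gender == "M" else female_code
    return [float(gender_code), float(age)] + [
        float(symptoms[name]) for name in FEATURE_ORDER
    ]


def probability_for_label(classes: Sequence, proba: Sequence[float], label) -> float:
    """Return the percentage for `label`, located by its position in `classes`."""
    index = list(classes).index(label)
    return float(proba[index]) * 100.0


# Placeholder inputs (assumption: every symptom encoded as 1).
symptoms = {name: 1 for name in FEATURE_ORDER}
row = build_feature_row("M", 61, symptoms)
yes_percent = probability_for_label([0, 1], [0.25, 0.75], 1)
```

With this shape, switching to the M=2, F=1 encoding is a matter of passing `male_code=2, female_code=1` once, rather than relying on an exception to select the second array.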
 
model_loader.py ADDED
@@ -0,0 +1,141 @@
+"""
+Robust model loader with compatibility fixes for scikit-learn version mismatches.
+"""
+import joblib
+import pickle
+import sys
+import warnings
+
+class SklearnCompatibilityUnpickler(pickle.Unpickler):
+    """Custom unpickler that handles scikit-learn compatibility issues."""
+
+    def find_class(self, module, name):
+        # Handle EuclideanDistance compatibility issue
+        if module == 'sklearn.metrics._dist_metrics' and name == 'EuclideanDistance':
+            try:
+                # Try to import and patch the module
+                import sklearn.metrics._dist_metrics as dist_metrics
+
+                # Check if EuclideanDistance exists
+                if not hasattr(dist_metrics, 'EuclideanDistance'):
+                    # Try to create it from available classes
+                    if hasattr(dist_metrics, 'EuclideanDistance32'):
+                        # Create a class that acts like EuclideanDistance
+                        class EuclideanDistanceWrapper(dist_metrics.EuclideanDistance32):
+                            pass
+                        dist_metrics.EuclideanDistance = EuclideanDistanceWrapper
+                    elif hasattr(dist_metrics, 'EuclideanDistance64'):
+                        class EuclideanDistanceWrapper(dist_metrics.EuclideanDistance64):
+                            pass
+                        dist_metrics.EuclideanDistance = EuclideanDistanceWrapper
+                    else:
+                        # Last resort: try to find it in neighbors module
+                        try:
+                            from sklearn.neighbors._dist_metrics import EuclideanDistance as ED
+                            dist_metrics.EuclideanDistance = ED
+                        except Exception:
+                            # Create a minimal stub class
+                            class EuclideanDistanceStub:
+                                def __init__(self, *args, **kwargs):
+                                    pass
+                            dist_metrics.EuclideanDistance = EuclideanDistanceStub
+
+                return getattr(dist_metrics, 'EuclideanDistance')
+            except Exception as e:
+                warnings.warn(f"Could not patch EuclideanDistance: {e}")
+                # Fallback: return a stub class
+                class EuclideanDistanceStub:
+                    def __init__(self, *args, **kwargs):
+                        pass
+                return EuclideanDistanceStub
+
+        # For all other classes, use default behavior
+        return super().find_class(module, name)
+
+
+def load_model_with_compatibility(model_path):
+    """
+    Load a joblib model with compatibility fixes.
+
+    Args:
+        model_path: Path to the .joblib model file
+
+    Returns:
+        Loaded model object
+    """
+    try:
+        # First, try to patch the module before loading
+        try:
+            import sklearn.metrics._dist_metrics as dist_metrics
+            if not hasattr(dist_metrics, 'EuclideanDistance'):
+                if hasattr(dist_metrics, 'EuclideanDistance32'):
+                    dist_metrics.EuclideanDistance = dist_metrics.EuclideanDistance32
+                elif hasattr(dist_metrics, 'EuclideanDistance64'):
+                    dist_metrics.EuclideanDistance = dist_metrics.EuclideanDistance64
+        except Exception:
+            pass
+
+        # Try standard loading first
+        try:
+            return joblib.load(model_path)
+        except (AttributeError, ModuleNotFoundError) as e:
+            if 'EuclideanDistance' in str(e):
+                # Try with custom unpickler
+                warnings.warn("Using compatibility mode to load model...")
+                try:
+                    # Read the raw file with the custom unpickler
+                    # (a local "import joblib.numpy_pickle" here would shadow the module-level joblib)
+
+                    # Open the file
+                    with open(model_path, 'rb') as f:
+                        # Unpickle with the EuclideanDistance-aware find_class
+                        unpickler = SklearnCompatibilityUnpickler(f)
+                        try:
+                            return unpickler.load()
+                        except Exception:
+                            # If that doesn't work, try monkey-patching more aggressively
+                            # Re-import after patching
+                            import importlib
+                            import sklearn.metrics._dist_metrics
+                            importlib.reload(sklearn.metrics._dist_metrics)
+
+                            # Patch again after reload
+                            dist_metrics = sklearn.metrics._dist_metrics
+                            if not hasattr(dist_metrics, 'EuclideanDistance'):
+                                if hasattr(dist_metrics, 'EuclideanDistance32'):
+                                    # Create a proper alias
+                                    dist_metrics.EuclideanDistance = type('EuclideanDistance',
+                                        (dist_metrics.EuclideanDistance32,), {})
+
+                            # Try loading again
+                            return joblib.load(model_path)
+                except Exception as e2:
+                    raise RuntimeError(f"Failed to load model even with compatibility mode: {e2}")
+            else:
+                raise
+    except Exception as e:
+        raise RuntimeError(f"Error loading model from {model_path}: {e}")
+
+
+def load_sklearn_model_safe(model_path, scaler_path=None):
+    """
+    Safely load sklearn model and scaler with compatibility fixes.
+
+    Args:
+        model_path: Path to model .joblib file
+        scaler_path: Path to scaler .joblib file (optional)
+
+    Returns:
+        Tuple of (model, scaler) or (model, None) if scaler_path not provided
+    """
+    model = load_model_with_compatibility(model_path)
+    scaler = None
+
+    if scaler_path:
+        try:
+            scaler = load_model_with_compatibility(scaler_path)
+        except Exception as e:
+            warnings.warn(f"Could not load scaler: {e}")
+
+    return model, scaler
+
requirements.txt CHANGED
@@ -2,4 +2,8 @@
 fastapi>=0.104.0
 uvicorn[standard]>=0.24.0
 pydantic>=2.0.0
-email-validator>=2.0.0
+
+# Machine Learning dependencies
+scikit-learn>=1.2.0
+joblib>=1.3.0
+numpy>=1.24.0
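
Review note: `scikit-learn>=1.2.0` is an open-ended lower bound, so the installed release may differ from the one that produced the `.joblib` files, which is the mismatch `model_loader.py` works around. One possible guard is to compare release series at startup. The sketch below is stdlib-only; the function names and the example version strings are assumptions, and the version the model was actually trained with is not stated in this commit.

```python
import re


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '1.2.0' or '1.4rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def same_release_series(trained_with: str, installed: str) -> bool:
    """True when both versions share the same major.minor series."""
    return major_minor(trained_with) == major_minor(installed)


# Example values only; neither is taken from the repository.
matches = same_release_series("1.2.0", "1.2.2")
differs = same_release_series("1.2.0", "1.4rc1")
```

In the application, `installed` could be read from `sklearn.__version__`, and a failed check could be logged as a warning before any model is loaded.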
scaler.joblib ADDED
Binary file (1.52 kB). View file