Omkar1872 commited on
Commit
a813020
·
verified ·
1 Parent(s): d7a2701

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -4,15 +4,13 @@ from fastapi.templating import Jinja2Templates
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import torch
7
- from torchvision import models, transforms
8
  from PIL import Image
9
  import io
10
- import torch.nn as nn
11
 
12
- # FastAPI app setup
13
  app = FastAPI()
14
 
15
- # CORS middleware (optional for frontend access)
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
@@ -28,33 +26,37 @@ templates = Jinja2Templates(directory="templates")
28
  # Device
29
  device = torch.device("cpu")
30
 
31
- # Model architecture
32
  model = models.resnet18(weights=None)
33
- model.fc = nn.Linear(model.fc.in_features, 2) # Binary classification
34
  model.load_state_dict(torch.load("pneumonia_weights.pth", map_location=device))
 
35
  model.eval()
36
 
37
- # Image transform
38
  transform = transforms.Compose([
39
  transforms.Resize((224, 224)),
40
  transforms.ToTensor(),
41
  ])
42
 
43
- # Homepage route
44
  @app.get("/", response_class=HTMLResponse)
45
  async def home(request: Request):
46
  return templates.TemplateResponse("index.html", {"request": request})
47
 
48
- # Prediction route
49
  @app.post("/predict")
50
  async def predict(file: UploadFile = File(...)):
51
- contents = await file.read()
52
- image = Image.open(io.BytesIO(contents)).convert("RGB")
53
- image = transform(image).unsqueeze(0).to(device)
 
54
 
55
- with torch.no_grad():
56
- output = model(image)
57
- prediction = torch.argmax(output, 1).item()
58
- result = "Pneumonia" if prediction == 1 else "Normal"
59
 
60
- return {"result": result}
 
 
 
 
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import torch
7
+ from torchvision import transforms, models
8
  from PIL import Image
9
  import io
 
10
 
 
11
  app = FastAPI()
12
 
13
+ # CORS
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
 
26
  # Device
27
  device = torch.device("cpu")
28
 
29
+ # Model (assumes ResNet18 was used)
30
  model = models.resnet18(weights=None)
31
+ model.fc = torch.nn.Linear(model.fc.in_features, 2) # Assuming 2 classes: Normal & Pneumonia
32
  model.load_state_dict(torch.load("pneumonia_weights.pth", map_location=device))
33
+ model.to(device)
34
  model.eval()
35
 
36
+ # Transform
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
39
  transforms.ToTensor(),
40
  ])
41
 
42
+ # Routes
43
  @app.get("/", response_class=HTMLResponse)
44
  async def home(request: Request):
45
  return templates.TemplateResponse("index.html", {"request": request})
46
 
 
47
  @app.post("/predict")
48
  async def predict(file: UploadFile = File(...)):
49
+ try:
50
+ contents = await file.read()
51
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
52
+ image = transform(image).unsqueeze(0).to(device)
53
 
54
+ with torch.no_grad():
55
+ output = model(image)
56
+ predicted = torch.argmax(output, dim=1).item()
57
+ result = "Pneumonia" if predicted == 1 else "Normal"
58
 
59
+ return {"result": result}
60
+
61
+ except Exception as e:
62
+ return {"result": f"Error during prediction: {str(e)}"}