Yousuf-Islam commited on
Commit
e919ef3
·
verified ·
1 Parent(s): ca465d9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +23 -22
main.py CHANGED
@@ -4,57 +4,58 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image
5
  import io
6
  from torchvision import transforms
7
-
8
- # Import the loader from the file next to this one
9
  from model_loader import load_model
10
 
11
  app = FastAPI()
12
 
13
- # Enable CORS so React can talk to this
14
  app.add_middleware(
15
  CORSMiddleware,
16
- allow_origins=["*"],
17
  allow_credentials=True,
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
 
 
 
22
  # --- LOAD MODEL ---
23
- print("Loading model...")
24
  try:
25
- # This expects the .pth file to be in the root folder
26
- model_wrapper = load_model("InceptionViT_best_model.pth")
27
- model = model_wrapper.model
28
- print("Model loaded successfully!")
29
  except Exception as e:
30
- print(f"CRITICAL ERROR LOADING MODEL: {e}")
31
- # If this fails, check the troubleshooting note below
32
 
33
  # --- TRANSFORM ---
34
- # Ensure these match your training exactly
35
  transform = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
38
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
39
  ])
40
 
41
  @app.get("/")
42
  def home():
43
- return {"message": "API is running"}
44
 
45
  @app.post("/predict")
46
  async def predict(file: UploadFile = File(...)):
47
- # 1. Read Image
 
 
48
  image_data = await file.read()
49
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
50
 
51
- # 2. Preprocess
52
- tensor = transform(image).unsqueeze(0) # Add batch dimension
53
 
54
- # 3. Predict
55
  with torch.no_grad():
56
- outputs = model(tensor)
57
- _, predicted = torch.max(outputs, 1)
 
58
 
59
- # 4. Return
60
- return {"result": str(predicted.item())}
 
 
 
4
  from PIL import Image
5
  import io
6
  from torchvision import transforms
 
 
7
  from model_loader import load_model
8
 
9
  app = FastAPI()
10
 
 
11
  app.add_middleware(
12
  CORSMiddleware,
13
+ allow_origins=["*"],
14
  allow_credentials=True,
15
  allow_methods=["*"],
16
  allow_headers=["*"],
17
  )
18
 
19
+ model = None
20
+ device = torch.device("cpu")
21
+
22
  # --- LOAD MODEL ---
23
+ print("--- STARTING SERVER ---")
24
  try:
25
+ model = load_model("InceptionViT_best_model.pth")
26
+ print("✅ Model loaded successfully!")
 
 
27
  except Exception as e:
28
+ print(f"CRITICAL ERROR: {e}")
 
29
 
30
  # --- TRANSFORM ---
31
+ # Matches your training code exactly
32
  transform = transforms.Compose([
33
  transforms.Resize((224, 224)),
34
  transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
36
+ std=[0.229, 0.224, 0.225]),
37
  ])
38
 
39
  @app.get("/")
40
  def home():
41
+ return {"status": "Running"}
42
 
43
  @app.post("/predict")
44
  async def predict(file: UploadFile = File(...)):
45
+ if model is None:
46
+ return {"error": "Model not loaded"}
47
+
48
  image_data = await file.read()
49
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
50
 
51
+ tensor = transform(image).unsqueeze(0).to(device)
 
52
 
 
53
  with torch.no_grad():
54
+ logits = model(tensor)
55
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
56
+ confidence, predicted = torch.max(probabilities, 1)
57
 
58
+ return {
59
+ "prediction": str(predicted.item()),
60
+ "confidence": float(confidence.item())
61
+ }