saniaE commited on
Commit
2d7fbe6
·
1 Parent(s): 38bcc67

updated api

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -3,8 +3,21 @@ import io
3
  import numpy as np
4
  import tensorflow as tf
5
  from fastapi import FastAPI, File, UploadFile
 
6
  from PIL import Image
7
- from huggingface_hub import hf_hub_download, login
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Import Mask RCNN modules
10
  from mrcnn.config import Config
@@ -26,17 +39,15 @@ class PredictionConfig(Config):
26
  NUM_CLASSES = 1 + 1
27
  DETECTION_MIN_CONFIDENCE = 0.9
28
 
29
- # Configuration constants
30
- REPO_ID = "SaniaE/MRCNN_Petrol_Pump_Segmentation"
31
- FILENAME = "mask_rcnn_petrol station_0080.h5"
32
  config = PredictionConfig()
33
  graph = tf.get_default_graph()
34
  model_eval = None
35
 
 
36
  @app.on_event("startup")
37
  def load_model():
38
  global model_eval
39
- token = os.getenv("HF_Token")
40
 
41
  if token:
42
  login(token=token)
@@ -47,10 +58,12 @@ def load_model():
47
 
48
  weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
49
 
50
- model_eval = modellib.MaskRCNN(mode="inference", model_dir=".", config=config)
51
- model_eval.load_weights(weights_path, by_name=True)
52
-
53
- model_eval.keras_model._make_predict_function()
 
 
54
  print("Model loaded successfully.")
55
 
56
 
 
3
  import numpy as np
4
  import tensorflow as tf
5
  from fastapi import FastAPI, File, UploadFile
6
+ from fastapi.middleware.cors import CORSMiddleware
7
  from PIL import Image
8
+ from huggingface_hub import hf_hub_download, snapshot_download, login
9
+
10
+ # Source repo configs
11
+ REPO_ID = "SaniaE/MRCNN_Petrol_Pump_Segmentation"
12
+ FILENAME = "mask_rcnn_petrol station_0080.h5"
13
+ token = os.getenv("HF_Token")
14
+
15
+ snapshot_download(
16
+ repo_id=REPO_ID,
17
+ allow_patterns=["mrcnn/*"],
18
+ local_dir=".",
19
+ token=token
20
+ )
21
 
22
  # Import Mask RCNN modules
23
  from mrcnn.config import Config
 
39
  NUM_CLASSES = 1 + 1
40
  DETECTION_MIN_CONFIDENCE = 0.9
41
 
42
+ # Configuration constants
 
 
43
  config = PredictionConfig()
44
  graph = tf.get_default_graph()
45
  model_eval = None
46
 
47
+
48
  @app.on_event("startup")
49
  def load_model():
50
  global model_eval
 
51
 
52
  if token:
53
  login(token=token)
 
58
 
59
  weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
60
 
61
+ # Initialize model within the same graph context
62
+ with graph.as_default():
63
+ model_eval = modellib.MaskRCNN(mode="inference", model_dir=".", config=config)
64
+ model_eval.load_weights(weights_path, by_name=True)
65
+ model_eval.keras_model._make_predict_function()
66
+
67
  print("Model loaded successfully.")
68
 
69