SID933 commited on
Commit
2a595c6
·
1 Parent(s): dfc01ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -8,21 +8,26 @@ import tensorflow as tf
8
 
9
  from typing import Annotated
10
  from fastapi import FastAPI, File, UploadFile
11
- # uvicorn main:app --reload
12
 
13
 
14
  def load_model() -> tf.keras.Model:
15
- path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
16
- 'model-resnet_custom_v3.h5')
17
- model = tf.keras.models.load_model(path)
18
- return model
 
 
19
 
20
 
21
  def load_labels() -> list[str]:
22
- path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
23
- 'tags.txt')
 
 
 
24
  with open(path) as f:
25
  labels = [line.strip() for line in f.readlines()]
 
26
  return labels
27
 
28
 
@@ -31,30 +36,36 @@ labels = load_labels()
31
  app = FastAPI()
32
 
33
 
34
- def predict(image: PIL.Image.Image,
35
- score_threshold: float) -> dict[str, float]:
 
 
36
  _, height, width, _ = model.input_shape
37
  image = np.asarray(image)
38
- image = tf.image.resize(image,
39
- size=(height, width),
40
- method=tf.image.ResizeMethod.AREA,
41
- preserve_aspect_ratio=True)
 
 
42
  image = image.numpy()
43
  image = dd.image.transform_and_pad_image(image, width, height)
44
  image = image / 255.
45
  probs = model.predict(image[None, ...])[0]
46
  probs = probs.astype(float)
47
  res = dict()
 
48
  for prob, label in zip(probs.tolist(), labels):
49
  if prob < score_threshold:
50
  continue
51
  res[label] = prob
 
52
  return res
53
 
54
 
55
  @app.get("/")
56
  async def root():
57
- return {"message": "Hello World"}
58
 
59
 
60
  @app.post("/upload/")
 
8
 
9
  from typing import Annotated
10
  from fastapi import FastAPI, File, UploadFile
 
11
 
12
 
13
  def load_model() -> tf.keras.Model:
14
+ path = huggingface_hub.hf_hub_download(
15
+ 'public-data/DeepDanbooru',
16
+ 'model-resnet_custom_v3.h5'
17
+ )
18
+
19
+ return tf.keras.models.load_model(path)
20
 
21
 
22
  def load_labels() -> list[str]:
23
+ path = huggingface_hub.hf_hub_download(
24
+ 'public-data/DeepDanbooru',
25
+ 'tags.txt'
26
+ )
27
+
28
  with open(path) as f:
29
  labels = [line.strip() for line in f.readlines()]
30
+
31
  return labels
32
 
33
 
 
36
  app = FastAPI()
37
 
38
 
39
+ def predict(
40
+ image: PIL.Image.Image,
41
+ score_threshold: float
42
+ ) -> dict[str, float]:
43
  _, height, width, _ = model.input_shape
44
  image = np.asarray(image)
45
+ image = tf.image.resize(
46
+ image,
47
+ size=(height, width),
48
+ method=tf.image.ResizeMethod.AREA,
49
+ preserve_aspect_ratio=True
50
+ )
51
  image = image.numpy()
52
  image = dd.image.transform_and_pad_image(image, width, height)
53
  image = image / 255.
54
  probs = model.predict(image[None, ...])[0]
55
  probs = probs.astype(float)
56
  res = dict()
57
+
58
  for prob, label in zip(probs.tolist(), labels):
59
  if prob < score_threshold:
60
  continue
61
  res[label] = prob
62
+
63
  return res
64
 
65
 
66
  @app.get("/")
67
  async def root():
68
+ return {"message": "Application Has Been Running!!"}
69
 
70
 
71
  @app.post("/upload/")