ihtesham0345 commited on
Commit
31082a9
·
0 Parent(s):

Initial commit with LFS

Browse files
Files changed (8) hide show
  1. .dockerignore +6 -0
  2. .gitattributes +35 -0
  3. .gitignore +6 -0
  4. Dockerfile +22 -0
  5. README.md +10 -0
  6. app.py +138 -0
  7. fruit_classifier_model.h5 +3 -0
  8. requirements.txt +6 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.ipynb
3
+ .ipynb_checkpoints
4
+ .git
5
+ .env
6
+ .DS_Store
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.ipynb
3
+ .ipynb_checkpoints/
4
+ .git/
5
+ .env
6
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ # Set up a new user named "user" with user ID 1000
4
+ RUN useradd -m -u 1000 user
5
+
6
+ # Switch to the "user" user
7
+ USER user
8
+
9
+ # Set home to the user's home directory
10
+ ENV PATH="/home/user/.local/bin:$PATH"
11
+
12
+ WORKDIR /app
13
+
14
+ # Copy the current directory contents into the container at /app setting the owner to the user
15
+ COPY --chown=user requirements.txt requirements.txt
16
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
17
+
18
+ COPY --chown=user . /app
19
+
20
+ EXPOSE 7860
21
+
22
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Fruits API
3
+ emoji: 🌖
4
+ colorFrom: pink
5
+ colorTo: pink
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import base64 # Not needed
2
+ from fastapi import FastAPI, File, UploadFile, HTTPException
3
+ from fastapi.responses import StreamingResponse
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ import io
8
+ from typing import List
9
+
10
+ app = FastAPI(title="Fruit Classifier API")
11
+
12
+ # Load the model
13
+ # Assuming the model is in the same directory
14
+ MODEL_PATH = "fruit_classifier_model.h5"
15
+ try:
16
+ model = tf.keras.models.load_model(MODEL_PATH)
17
+ print(f"Model loaded successfully from {MODEL_PATH}")
18
+ except Exception as e:
19
+ print(f"Error loading model: {e}")
20
+ # We allow the app to start even if model fails, but predict will fail
21
+ model = None
22
+
23
+ # Class names extracted from the training notebook
24
+ CLASS_NAMES = [
25
+ 'Apple', 'Apricots', 'Avocado', 'Banana', 'Blackberries', 'Blueberry',
26
+ 'Cantaloupe', 'Cherry', 'Coconut', 'Dates', 'Dragon fruit', 'Fig',
27
+ 'Grapes', 'Guava', 'Jackfruit', 'Kiwi', 'Lemons', 'Lychee', 'Mango',
28
+ 'Olive', 'Orange', 'Papaya', 'Pear', 'Persimmon', 'Pineapple', 'Plum',
29
+ 'Pomegranate', 'Rambutan', 'Raspberry', 'Salak', 'Sapodilla', 'Soursop',
30
+ 'Starfruit', 'Strawberry', 'Watermelon'
31
+ ]
32
+
33
+ def preprocess_image(image: Image.Image) -> np.ndarray:
34
+ """
35
+ Preprocess the image to match the model's expected input.
36
+ EfficientNet usually expects (224, 224, 3) and values in [0, 255]
37
+ if using the internal preprocessing layer, or pre-scaled if not.
38
+ The notebook showed:
39
+ tf.keras.utils.image_dataset_from_directory(..., image_size=(224, 224), ...)
40
+ and the model used Rescaling/Normalization layers inside it (efficientnetb0 usually has it or we saw Rescaling layer in summary).
41
+ The provided summary showed:
42
+ rescaling (Rescaling) ...
43
+ normalization (Normalization) ...
44
+ So we just need to resize to (224, 224) and provide inputs as they are (0-255 usually for uint8, but converting to float32 is safer).
45
+ """
46
+ if image.mode != "RGB":
47
+ image = image.convert("RGB")
48
+
49
+ image_resized = image.resize((224, 224))
50
+ image_array = np.array(image_resized)
51
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
52
+ return image_array
53
+
54
+ @app.get("/")
55
+ def read_root():
56
+ return {"message": "Welcome to the Fruit Classifier API"}
57
+
58
+ @app.post("/predict")
59
+ async def predict(file: UploadFile = File(...)):
60
+ if model is None:
61
+ raise HTTPException(status_code=500, detail="Model not loaded")
62
+
63
+ try:
64
+ contents = await file.read()
65
+ image = Image.open(io.BytesIO(contents))
66
+ processed_image = preprocess_image(image)
67
+
68
+ predictions = model.predict(processed_image)
69
+ predicted_class_index = np.argmax(predictions[0])
70
+ confidence = float(predictions[0][predicted_class_index])
71
+ predicted_class = CLASS_NAMES[predicted_class_index]
72
+
73
+ return {
74
+ "prediction": predicted_class,
75
+ "confidence": confidence,
76
+ "filename": file.filename
77
+ }
78
+ except Exception as e:
79
+ raise HTTPException(status_code=500, detail=str(e))
80
+
81
+ @app.post("/predict_image")
82
+ async def predict_image(file: UploadFile = File(...)):
83
+ if model is None:
84
+ raise HTTPException(status_code=500, detail="Model not loaded")
85
+
86
+ try:
87
+ contents = await file.read()
88
+ original_image = Image.open(io.BytesIO(contents))
89
+
90
+ # Ensure RGB
91
+ if original_image.mode != "RGB":
92
+ original_image = original_image.convert("RGB")
93
+
94
+ processed_image = preprocess_image(original_image)
95
+
96
+ predictions = model.predict(processed_image)
97
+ predicted_class_index = np.argmax(predictions[0])
98
+ confidence = float(predictions[0][predicted_class_index])
99
+ predicted_class = CLASS_NAMES[predicted_class_index]
100
+
101
+ # Draw on the original image
102
+ draw = ImageDraw.Draw(original_image)
103
+
104
+ # Try to load a nice font, otherwise default
105
+ try:
106
+ # Try loading a system font (Windows usually has arial)
107
+ font = ImageFont.truetype("arial.ttf", size=int(original_image.height / 20))
108
+ except IOError:
109
+ font = ImageFont.load_default()
110
+
111
+ text = f"{predicted_class} ({confidence:.2f})"
112
+
113
+ # Calculate text position (top-left or centered-top)
114
+ text_position = (10, 10)
115
+
116
+ # Draw text with outline for better visibility
117
+ x, y = text_position
118
+ outline_color = "black"
119
+ text_color = "red"
120
+
121
+ draw.text((x-1, y-1), text, font=font, fill=outline_color)
122
+ draw.text((x+1, y-1), text, font=font, fill=outline_color)
123
+ draw.text((x-1, y+1), text, font=font, fill=outline_color)
124
+ draw.text((x+1, y+1), text, font=font, fill=outline_color)
125
+ draw.text(text_position, text, font=font, fill=text_color)
126
+
127
+ # Save to bytes
128
+ img_byte_arr = io.BytesIO()
129
+ original_image.save(img_byte_arr, format='JPEG')
130
+ img_byte_arr.seek(0)
131
+
132
+ return StreamingResponse(img_byte_arr, media_type="image/jpeg")
133
+
134
+ except Exception as e:
135
+ raise HTTPException(status_code=500, detail=str(e))
136
+
137
+ if __name__ == "__main__":
138
+ uvicorn.run(app, host="0.0.0.0", port=7860)
fruit_classifier_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5ac2a906652b32c4f0852ce291d45ec14188ae57ce3205926bb797a8f7f03a3
3
+ size 17467592
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ pillow
5
+ numpy
6
+ tensorflow