jatin1233232 commited on
Commit
0bf332e
·
verified ·
1 Parent(s): 4b576ac

Update app/fruit_model.py

Browse files
Files changed (1) hide show
  1. app/fruit_model.py +19 -13
app/fruit_model.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import tempfile
3
  import requests
 
4
  from PIL import Image
5
  import numpy as np
6
  import tensorflow as tf
@@ -9,28 +10,33 @@ from tensorflow.keras.models import load_model
9
  # Optional: Cache download
10
  os.environ["HF_HOME"] = "/tmp/cache"
11
 
12
- # Hugging Face model URL (.keras file)
13
  model_url = "https://huggingface.co/jatin1233232/Analiz_Fruit_Model/resolve/main/model.keras"
 
14
 
15
- # Download model to temp file
16
- def download_model(url):
17
  response = requests.get(url)
18
  if response.status_code == 200:
19
- tmp_model_path = os.path.join(tempfile.gettempdir(), "model.keras")
20
- with open(tmp_model_path, 'wb') as f:
21
  f.write(response.content)
22
- return tmp_model_path
23
  else:
24
- raise Exception("Failed to download model")
25
 
26
  # Load the Keras model
27
- model_path = download_model(model_url)
28
  model = load_model(model_path)
29
- print("✅ Model loaded successfully.")
30
 
31
- # Load class labels (optional — if not in model, you need to define them)
32
- # Example: labels = ['pizza', 'burger', 'idli', ...]
33
- labels = [f"Class {i}" for i in range(model.output_shape[-1])] # fallback dummy labels
 
 
 
 
34
 
35
  # Classification function
36
  def classify_fruit(image: Image.Image, target_size=(224, 224)):
@@ -45,4 +51,4 @@ def classify_fruit(image: Image.Image, target_size=(224, 224)):
45
  predicted_label = labels[predicted_index]
46
  confidence = float(np.max(predictions))
47
 
48
- return predicted_label, confidence
 
1
  import os
2
  import tempfile
3
  import requests
4
+ import json
5
  from PIL import Image
6
  import numpy as np
7
  import tensorflow as tf
 
10
  # Optional: Cache download
11
  os.environ["HF_HOME"] = "/tmp/cache"
12
 
13
+ # Hugging Face model and labels URLs
14
  model_url = "https://huggingface.co/jatin1233232/Analiz_Fruit_Model/resolve/main/model.keras"
15
+ labels_url = "https://huggingface.co/jatin1233232/Analiz_Fruit_Model/resolve/main/labels.json"
16
 
17
+ # Download file from Hugging Face
18
+ def download_file(url, filename):
19
  response = requests.get(url)
20
  if response.status_code == 200:
21
+ file_path = os.path.join(tempfile.gettempdir(), filename)
22
+ with open(file_path, 'wb') as f:
23
  f.write(response.content)
24
+ return file_path
25
  else:
26
+ raise Exception(f"Failed to download: {url}")
27
 
28
  # Load the Keras model
29
+ model_path = download_file(model_url, "model.keras")
30
  model = load_model(model_path)
31
+ print("✅ Fruit model loaded successfully.")
32
 
33
+ # Load labels from labels.json
34
+ labels_path = download_file(labels_url, "labels.json")
35
+ with open(labels_path, 'r') as f:
36
+ labels_dict = json.load(f)
37
+
38
+ # Convert to list of labels ordered by index
39
+ labels = [labels_dict[str(i)] for i in range(len(labels_dict))]
40
 
41
  # Classification function
42
  def classify_fruit(image: Image.Image, target_size=(224, 224)):
 
51
  predicted_label = labels[predicted_index]
52
  confidence = float(np.max(predictions))
53
 
54
+ return predicted_label, confidence