stshanks committed on
Commit
1067d57
·
verified ·
1 Parent(s): d305c5e

Update app.py

Browse files

Update predict_text() with print debugging.

Files changed (1) hide show
  1. app.py +10 -40
app.py CHANGED
@@ -28,16 +28,17 @@ def decode_prediction(prediction):
28
  Expects prediction to be a numpy array of shape (1, 78).
29
  It returns the drug name corresponding to the highest probability.
30
  """
31
- # Get the index of the highest probability class
32
- predicted_index = np.argmax(prediction, axis=-1)[0]
33
- # Return the corresponding drug name
34
- return CLASS_NAMES[predicted_index]
 
35
 
36
  # Function to preprocess the uploaded image
37
  def preprocess_image(image):
38
- image = image.convert("RGB") # Convert to grayscale
39
- image = image.resize((64, 64)) # Resize to match model input
40
- image = np.array(image) / 255.0 # Normalize
41
  image = np.expand_dims(image, axis=0) # Add batch dimension
42
  return image
43
 
@@ -51,48 +52,17 @@ def predict_text(image):
51
  segment_width = image.shape[1] // num_chars # Split image into equal parts
52
 
53
  def predict_text(image):
54
- processed_image = preprocess_image(image) # Make sure the image is preprocessed to (64, 64, 3)
55
  prediction = model.predict(processed_image)
56
 
57
  print("Model output shape:", prediction.shape) # Should be (1, 78)
58
- print("Model output values:", prediction) # Check the raw probabilities
59
 
60
  # Decode the prediction to get the drug name
61
  predicted_text = decode_prediction(prediction)
62
  return predicted_text
63
 
64
 
65
- import numpy as np
66
- import string
67
-
68
- # Define the possible characters in prescription handwriting
69
- CHARACTER_SET = string.ascii_letters + string.digits + " .,-/()"
70
-
71
- def decode_prediction(prediction):
72
- # Ensure prediction is iterable
73
- if len(prediction.shape) == 2: # (1, num_classes), meaning single character classification
74
- indices = np.argmax(prediction, axis=-1) # Pick the most likely character
75
- text = CHARACTER_SET[indices[0]] # Convert to actual character
76
- return text
77
-
78
- elif len(prediction.shape) == 3: # (1, sequence_length, num_classes), meaning character sequence classification
79
- prediction = prediction[0] # Remove batch dimension
80
- indices = np.argmax(prediction, axis=-1) # Get character indices at each step
81
-
82
- # Convert indices to characters while removing duplicates
83
- decoded_text = []
84
- prev_char = None
85
- for i in indices:
86
- if i != prev_char and i < len(CHARACTER_SET): # Avoid duplicate characters
87
- decoded_text.append(CHARACTER_SET[i])
88
- prev_char = i # Update previous character
89
-
90
- return "".join(decoded_text)
91
-
92
- else:
93
- return "Error: Unexpected output shape!"
94
-
95
-
96
  # Gradio UI
97
  interface = gr.Interface(
98
  fn=predict_text,
 
28
  Expects prediction to be a numpy array of shape (1, 78).
29
  It returns the drug name corresponding to the highest probability.
30
  """
31
+ if prediction.shape != (1, 78):
32
+ return "Error: Unexpected model output shape"
33
+
34
+ predicted_index = np.argmax(prediction, axis=-1)[0] # Get the index of the highest probability
35
+ return CLASS_NAMES[predicted_index] # Return the corresponding drug name
36
 
37
  # Function to preprocess the uploaded image
38
  def preprocess_image(image):
39
+ image = image.convert("RGB") # Ensure 3 channels
40
+ image = image.resize((64, 64)) # Match model input size
41
+ image = np.array(image) / 255.0 # Normalize to [0,1]
42
  image = np.expand_dims(image, axis=0) # Add batch dimension
43
  return image
44
 
 
52
  segment_width = image.shape[1] // num_chars # Split image into equal parts
53
 
54
  def predict_text(image):
55
+ processed_image = preprocess_image(image) # Ensure input is (64, 64, 3)
56
  prediction = model.predict(processed_image)
57
 
58
  print("Model output shape:", prediction.shape) # Should be (1, 78)
59
+ print("Model output values:", prediction) # Print raw probabilities
60
 
61
  # Decode the prediction to get the drug name
62
  predicted_text = decode_prediction(prediction)
63
  return predicted_text
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Gradio UI
67
  interface = gr.Interface(
68
  fn=predict_text,