Aff77 commited on
Commit
297f772
·
verified ·
1 Parent(s): d036b64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -55,23 +55,28 @@ model = load_model()
55
  # --------------------------
56
  # Prediction Logic
57
  # --------------------------
 
58
  def decode_predictions(preds):
59
- """Convert model output to text using CTC decoding"""
60
  preds = preds.permute(1, 0, 2) # [B, W, C]
61
- _, pred_indices = preds.max(2)
 
62
 
63
  texts = []
64
  for pred in pred_indices:
65
- # CTC decoding: merge repeated and remove blank
66
  decoded = []
67
  prev_char = None
68
  for idx in pred:
69
- char = idx_to_char.get(idx.item(), '')
70
- if char != prev_char and char != '' and idx.item() != (VOCAB_SIZE - 1):
71
- decoded.append(char)
72
- prev_char = char
 
 
73
  texts.append(''.join(decoded))
74
- return texts[0] if len(texts) == 1 else texts
 
75
 
76
  def preprocess_image(image):
77
  """Convert input to model-compatible format"""
@@ -85,22 +90,50 @@ def preprocess_image(image):
85
 
86
  def predict(image):
87
  try:
88
- # Handle Gradio input types
 
 
 
89
  if isinstance(image, dict):
 
90
  image = image['image'] if 'image' in image else image['data']
 
 
91
  if not isinstance(image, Image.Image):
92
- image = Image.fromarray(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # Process and predict
95
- image_tensor = preprocess_image(image)
96
  with torch.no_grad():
97
  outputs = model(image_tensor)
 
98
  prediction = decode_predictions(outputs)
 
99
 
100
  return prediction
101
 
102
  except Exception as e:
103
- return f"Error: {str(e)}"
 
 
104
 
105
  # --------------------------
106
  # Gradio Interface
 
55
  # --------------------------
56
  # Prediction Logic
57
  # --------------------------
58
+
59
  def decode_predictions(preds):
60
+ """More robust CTC decoding"""
61
  preds = preds.permute(1, 0, 2) # [B, W, C]
62
+ preds = torch.softmax(preds, dim=2)
63
+ pred_indices = torch.argmax(preds, dim=2)
64
 
65
  texts = []
66
  for pred in pred_indices:
67
+ # Merge repeated and remove blank (VOCAB_SIZE-1)
68
  decoded = []
69
  prev_char = None
70
  for idx in pred:
71
+ char_idx = idx.item()
72
+ if char_idx < len(idx_to_char) and char_idx != (VOCAB_SIZE - 1):
73
+ char = idx_to_char[char_idx]
74
+ if char != prev_char:
75
+ decoded.append(char)
76
+ prev_char = char
77
  texts.append(''.join(decoded))
78
+
79
+ return texts[0] if len(texts) == 1 else texts
80
 
81
  def preprocess_image(image):
82
  """Convert input to model-compatible format"""
 
90
 
91
  def predict(image):
92
  try:
93
+ print("\n=== New Prediction ===") # Debug separator
94
+
95
+ # 1. Log input type
96
+ print(f"Input type: {type(image)}")
97
  if isinstance(image, dict):
98
+ print(f"Dict keys: {image.keys()}")
99
  image = image['image'] if 'image' in image else image['data']
100
+
101
+ # 2. Convert to PIL Image
102
  if not isinstance(image, Image.Image):
103
+ print("Converting to PIL Image...")
104
+ try:
105
+ image = Image.fromarray(image)
106
+ except Exception as conv_err:
107
+ print(f"Conversion error: {conv_err}")
108
+ return f"Image conversion failed: {conv_err}"
109
+
110
+ # 3. Verify image
111
+ print(f"Image mode: {image.mode}, size: {image.size}")
112
+ if image.mode != 'L':
113
+ print("Converting to grayscale...")
114
+ image = image.convert('L')
115
+
116
+ # 4. Preprocess
117
+ try:
118
+ image_tensor = preprocess_image(image)
119
+ print(f"Tensor shape: {image_tensor.shape}")
120
+ except Exception as preprocess_err:
121
+ print(f"Preprocessing error: {preprocess_err}")
122
+ return f"Preprocessing failed: {preprocess_err}"
123
 
124
+ # 5. Predict
 
125
  with torch.no_grad():
126
  outputs = model(image_tensor)
127
+ print(f"Raw model output shape: {outputs.shape}")
128
  prediction = decode_predictions(outputs)
129
+ print(f"Final prediction: {prediction}")
130
 
131
  return prediction
132
 
133
  except Exception as e:
134
+ error_msg = f"Full error: {str(e)}"
135
+ print(error_msg)
136
+ return error_msg
137
 
138
  # --------------------------
139
  # Gradio Interface