# Author: Umer2762 — commit 95335f9 "Update predict.py" (verified).
import gradio as gr
import numpy as np
from keras.api.models import load_model
import string
class TextTypeLangModel:
    """Wrap a Keras model that predicts text, text type and language from an image.

    The wrapped model's ``predict`` is unpacked into exactly three outputs:
    per-position character scores, type scores and language scores.
    """

    def __init__(self, model_path, csv_path=None):
        """Load the Keras model and build the character lookup tables.

        Args:
            model_path: Path passed to Keras ``load_model``.
            csv_path: Accepted for compatibility; not used by this class.
        """
        self.model = load_model(model_path)
        # Vocabulary: ASCII letters, digits, space, then Urdu letters and
        # punctuation. Some Urdu characters occur more than once (e.g. "ہ").
        # NOTE(review): presumably this string must match the one used at
        # training time because it fixes the index layout (and num_chars),
        # so it is intentionally not deduplicated.
        self.characters = (
            string.ascii_letters
            + string.digits
            + " "
            + "آبپچڈڑڤکگہہٹژزسٹطظعغفقکگلاںمںنۓہھوؤ"
            + "ےیئؤٹپجچحخدڈذرزسشصضطظعغفقکلمنوٕں"
            + "ۓۓہ۔،؛؟"
        )
        # Index 0 is reserved for the blank symbol, so real characters start at 1.
        self.num_chars = len(self.characters) + 1
        self.char_to_index = {c: i + 1 for i, c in enumerate(self.characters)}
        self.index_to_char = {i + 1: c for i, c in enumerate(self.characters)}

    def encode_text(self, text, max_len=10):
        """Encode *text* as a fixed-length list of character indices.

        The text is truncated to *max_len* characters and right-padded with
        spaces. Characters outside the vocabulary map to 0 (the blank index).
        """
        text = text[:max_len].ljust(max_len)
        return [self.char_to_index.get(c, 0) for c in text]

    def preprocess_image(self, image):
        """Convert a PIL image to a 128x128 RGB array scaled to [0, 1]."""
        image = image.convert("RGB")  # drop alpha / palette modes
        image = image.resize((128, 128))
        return np.array(image) / 255.0

    def predict(self, image):
        """Run the model on a single PIL image.

        Returns:
            ``(text, type_index, language_index)``. ``text`` is the greedily
            decoded string with surrounding whitespace stripped. The two
            indices are the argmax of the type and language outputs, as
            Python ints.
        """
        batch = np.expand_dims(self.preprocess_image(image), axis=0)
        pred_text, pred_type, pred_lang = self.model.predict(batch)
        text = self._decode_text(pred_text[0]).strip()
        return text, int(np.argmax(pred_type)), int(np.argmax(pred_lang))

    def _decode_text(self, char_scores):
        """Greedily decode per-position character scores into a string.

        Takes the argmax at each position. Index 0 (blank) and any index not
        in the vocabulary decode to the empty string. The number of positions
        comes from *char_scores* itself rather than a hard-coded length.
        """
        indices = np.argmax(char_scores, axis=-1)
        return "".join(self.index_to_char.get(int(i), "") for i in indices)
# Human-readable label for each type index returned by TextTypeLangModel.predict.
_TYPE_LABELS = {
    0: "Medicine",
    1: "Dosage",
    2: "Diagnostic",
    3: "Symptoms",
    4: "Personal Info",
    5: "Numeric Data",
    6: "Text",
}


def get_type_string(int_type):
    """Translate a predicted type index into its label.

    Any index missing from the table yields "Unknown".
    """
    return _TYPE_LABELS.get(int_type, "Unknown")
def predict_text_type_lang(image):
model = TextTypeLangModel("./model/text_type_lang_model.h5")
predicted_text, predicted_type, predicted_language = model.predict(image)
predicted_type_str = get_type_string(predicted_type)
predicted_language_str = "English" if predicted_language == 0 else "Urdu"
return predicted_text, predicted_type_str, predicted_language_str
# Gradio UI: one PIL image in, three text boxes out
# (extracted text, text type, language).
_demo_inputs = gr.Image(type="pil")
_demo_outputs = ["text"] * 3
iface = gr.Interface(
    fn=predict_text_type_lang,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Text Type & Language Prediction",
    description="Upload an image to predict the extracted text, type, and language.",
)
# Start the web server at import time; debug=True is passed straight to launch().
iface.launch(debug=True)