# scripts / app.py
# Committed by: stshanks
# Last commit: f96e87a "Update app.py" (verified)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import gradio as gr
import numpy as np
from PIL import Image
# Load the trained Keras model. The file is resolved next to this script rather
# than in the current working directory, so the app also starts correctly when
# it is launched from a different directory.
model = tf.keras.models.load_model(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "prescription_classification_model.keras",
    )
)
# The 78 drug names, one per model output class: position i in this list must
# correspond to output unit i of the trained model.
# NOTE(review): the list is almost alphabetical ("Ace" before "Aceta", "Az"
# before "Azithrocin"), but "Fexofast" precedes "Fexo" and "Napa Extend"
# precedes "Napa". If the training labels were assigned by sorting the class
# names, those two pairs may be swapped -- confirm against the training label
# mapping before relying on those predictions.
CLASS_NAMES = [
    "Ace", "Aceta", "Alatrol", "Amodis", "Atrizin", "Axodin", "Az", "Azithrocin", "Azyth",
    "Bacaid", "Backtone", "Baclofen", "Baclon", "Bacmax", "Beklo", "Bicozin",
    "Canazole", "Candinil", "Cetisoft", "Conaz", "Dancel", "Denixil", "Diflu",
    "Dinafex", "Disopan", "Esonix", "Esoral", "Etizin", "Exium", "Fenadin",
    "Fexofast", "Fexo", "Filmet", "Fixal", "Flamyd", "Flexibac", "Flexilax",
    "Flugal", "Ketocon", "Ketoral", "Ketotab", "Ketozol", "Leptic", "Lucan-R",
    "Lumona", "M-Kast", "Maxima", "Maxpro", "Metro", "Metsina", "Monas",
    "Montair", "Montene", "Montex", "Napa Extend", "Napa", "Nexcap", "Nexum",
    "Nidazyl", "Nizoder", "Odmon", "Omastin", "Opton", "Progut", "Provair",
    "Renova", "Rhinil", "Ritch", "Rivotril", "Romycin", "Rozith", "Sergel",
    "Tamen", "Telfast", "Tridosil", "Trilock", "Vifas", "Zithrin",
]


def decode_prediction(prediction):
    """Map a single-image model output to its drug name.

    Args:
        prediction: Per-class scores as a numpy array (or anything
            ``np.asarray`` accepts) of shape ``(1, len(CLASS_NAMES))`` --
            one row for the single image that was fed to the model.

    Returns:
        The entry of ``CLASS_NAMES`` with the highest score (the first one on
        ties), or the string ``"Error: Unexpected model output shape"`` when
        the input does not have the expected shape.
    """
    scores = np.asarray(prediction)
    # Derive the expected width from CLASS_NAMES rather than hard-coding it,
    # so the check stays in sync with the label list.
    if scores.shape != (1, len(CLASS_NAMES)):
        return "Error: Unexpected model output shape"
    predicted_index = int(np.argmax(scores[0]))  # index of the highest score
    print("Predicted Index:", predicted_index)
    return CLASS_NAMES[predicted_index]
# Function to preprocess the uploaded image
def preprocess_image(image):
    """Turn a PIL image into a one-image batch for the classifier.

    The image is forced to RGB (3 channels), resized to 64x64 pixels to match
    the model input, scaled from 0-255 to 0.0-1.0, and given a leading batch
    axis, so the result has shape ``(1, 64, 64, 3)``.
    """
    rgb_small = image.convert("RGB").resize((64, 64))
    pixels = np.array(rgb_small) / 255.0
    # Prepend the batch dimension the model expects.
    return pixels[np.newaxis, ...]
# Function to predict the drug name from a handwritten prescription image
def predict_text(image):
    """Identify the drug name written on a prescription image.

    Used as the click handler of the "Identify Prescription" button: it
    receives the value of the ``gr.Image(type="pil")`` input and returns the
    text shown in the output textbox.

    Args:
        image: The uploaded/captured PIL image, or ``None`` when the button is
            clicked while the image input is empty.

    Returns:
        The predicted drug name, or an ``"Error: ..."`` message string.
    """
    # An empty image input yields None; without this guard the call would fail
    # inside preprocess_image on None.convert(...).
    if image is None:
        return "Error: No image provided"
    processed_image = preprocess_image(image)  # shape (1, 64, 64, 3)
    prediction = model.predict(processed_image)
    # Decode the prediction to get the drug name
    return decode_prediction(prediction)
# Gradio UI
# Load custom CSS from style.css located next to this script (not in the
# current working directory), decoded explicitly as UTF-8 so the result does
# not depend on the platform's locale encoding.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "style.css"),
    "r",
    encoding="utf-8",
) as f:
    custom_css = f.read()

# Create a base theme (the custom CSS is passed to Blocks separately below)
theme_obj = gr.themes.Base(
    font=["Arial", "sans-serif"]
)

# Build Gradio interface with custom CSS applied via the Blocks css parameter
with gr.Blocks(theme=theme_obj, css=custom_css) as interface:
    # Header
    gr.Markdown(
        """
<h1 style="text-align: center;">🏥 Prescription Recognition AI</h1>
<p style="text-align: center; font-size: 18px;">Upload or take a picture of a handwritten prescription. The AI will identify the drug name.</p>
<hr>
""",
        elem_id="header"
    )

    with gr.Row():
        # Left: image input
        with gr.Column():
            gr.Markdown("### 📸 Upload or Capture an Image")
            image_input = gr.Image(type="pil", label="Upload or Capture Prescription")
        # Right: prediction output (read-only)
        with gr.Column():
            gr.Markdown("### 💊 Recognized Drug Name")
            output_text = gr.Textbox(label="Predicted Drug Name", interactive=False)

    # Process button
    with gr.Row():
        submit_button = gr.Button("🔍 Identify Prescription", variant="primary")

    # On button click, predict the drug name. The listener is registered inside
    # the Blocks context, where Gradio event handlers must be defined.
    submit_button.click(fn=predict_text, inputs=image_input, outputs=output_text)


if __name__ == "__main__":
    interface.launch()