File size: 4,161 Bytes
90c3a00
bc738b0
9982566
0ae5be4
88de4d9
bc738b0
88de4d9
 
 
 
f96e87a
88de4d9
d305c5e
 
71a90e9
 
73f86ab
 
 
 
 
 
 
 
 
d305c5e
 
 
 
 
 
 
1067d57
 
 
 
0318ae9
1067d57
d305c5e
88de4d9
 
1067d57
 
 
88de4d9
 
 
 
 
2a5ea2c
a2dba21
2a5ea2c
 
a2dba21
2a5ea2c
 
d305c5e
1067d57
d305c5e
 
 
 
 
2e72fc2
88de4d9
 
57a2a31
7e25eae
57a2a31
 
74db6a6
 
9391917
57a2a31
 
74db6a6
 
ffedde7
 
 
 
57a2a31
ffedde7
 
 
 
 
 
 
 
 
1d7f6bf
ffedde7
 
 
 
 
 
 
 
 
 
 
 
88de4d9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import gradio as gr
import numpy as np
from PIL import Image

# Load the trained model once at import time. The path is relative to the
# current working directory, so the app must be started from the folder that
# contains the .keras file.
model = tf.keras.models.load_model("prescription_classification_model.keras")

# Define the 78 drug names in the exact order corresponding to your model's output classes.
# List position i is the label for model output index i (decode_prediction indexes this
# list with the arg-max of the output), so the order must match the one used in training.
# NOTE(review): a few neighbours are not in strict lexicographic order ("Fexofast" before
# "Fexo", "Napa Extend" before "Napa"); if training labels came from sorted folder names,
# confirm these entries are not swapped.
CLASS_NAMES = [
    "Ace", "Aceta", "Alatrol", "Amodis", "Atrizin", "Axodin", "Az","Azithrocin", "Azyth",
     "Bacaid", "Backtone", "Baclofen", "Baclon", "Bacmax", "Beklo", "Bicozin",
    "Canazole", "Candinil", "Cetisoft", "Conaz", "Dancel", "Denixil", "Diflu",
    "Dinafex", "Disopan", "Esonix", "Esoral", "Etizin", "Exium", "Fenadin",
    "Fexofast", "Fexo", "Filmet", "Fixal", "Flamyd", "Flexibac", "Flexilax",
    "Flugal", "Ketocon", "Ketoral", "Ketotab", "Ketozol", "Leptic", "Lucan-R",
    "Lumona", "M-Kast", "Maxima", "Maxpro", "Metro", "Metsina", "Monas",
    "Montair", "Montene", "Montex", "Napa Extend", "Napa", "Nexcap", "Nexum",
    "Nidazyl", "Nizoder", "Odmon", "Omastin", "Opton", "Progut", "Provair",
    "Renova", "Rhinil", "Ritch", "Rivotril", "Romycin", "Rozith", "Sergel",
    "Tamen", "Telfast", "Tridosil", "Trilock", "Vifas", "Zithrin"
]

def decode_prediction(prediction, class_names=None):
    """
    Map a single-sample model output to its drug name.

    Args:
        prediction: array-like of shape (1, len(class_names)) holding one score
            per class for one image; anything ``np.asarray`` accepts is allowed.
        class_names: labels indexed by output position. Defaults to the
            module-level ``CLASS_NAMES``; pass another sequence to reuse this
            function with a different label set.

    Returns:
        The label at the position of the highest score, or the string
        "Error: Unexpected model output shape" when the output does not have
        exactly one row with one score per label. Errors are returned in-band
        (not raised) because the result is shown directly in the UI textbox.
    """
    names = CLASS_NAMES if class_names is None else class_names
    scores = np.asarray(prediction)

    # Derive the expected width from the label list instead of hard-coding 78,
    # so adding or removing a class cannot silently desynchronise the two.
    # An empty label list is rejected too: argmax over an empty axis would raise.
    if len(names) == 0 or scores.shape != (1, len(names)):
        return "Error: Unexpected model output shape"

    predicted_index = int(np.argmax(scores, axis=-1)[0])  # index of the highest score
    print("Predicted Index:", predicted_index)
    return names[predicted_index]  # Return the corresponding drug name

# Turn an uploaded PIL image into the model's input batch
def preprocess_image(image):
    """
    Convert a PIL image into a single-item batch for the classifier.

    The image is forced to 3-channel RGB, resized to 64x64 pixels, scaled
    from 0-255 to the [0, 1] range, and given a leading batch axis, giving
    an array of shape (1, 64, 64, 3).
    """
    rgb = image.convert("RGB").resize((64, 64))
    pixels = np.array(rgb) / 255.0
    return pixels[np.newaxis, ...]

# Function to predict the drug name from a handwritten prescription image
def predict_text(image):
    """
    Classify an uploaded prescription image and return the predicted drug name.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is pressed before any image has been provided.

    Returns:
        The predicted drug name, the in-band error string produced by
        ``decode_prediction`` for an unexpected output shape, or a prompt
        asking for an image when none was given.
    """
    # Gradio passes None for an empty image component; without this guard the
    # preprocessing step would fail with AttributeError on None.convert(...).
    if image is None:
        return "Please upload or capture a prescription image first."

    processed_image = preprocess_image(image)  # (1, 64, 64, 3) batch
    prediction = model.predict(processed_image)

    # Decode the prediction to get the drug name
    return decode_prediction(prediction)


# Gradio UI
# Load custom CSS from an external file. Decode it as UTF-8 explicitly so the
# result does not depend on the platform's default locale encoding.
with open("style.css", "r", encoding="utf-8") as f:
    custom_css = f.read()

# Create a base theme (custom CSS is applied separately via Blocks(css=...))
theme_obj = gr.themes.Base(
    font=["Arial", "sans-serif"]
)

# Build Gradio interface with custom CSS applied via the Blocks css parameter
with gr.Blocks(theme=theme_obj, css=custom_css) as interface:

    # Header
    gr.Markdown(
        """
        <h1 style="text-align: center;">🏥 Prescription Recognition AI</h1>
        <p style="text-align: center; font-size: 18px;">Upload or take a picture of a handwritten prescription. The AI will identify the drug name.</p>
        <hr>
        """,
        elem_id="header"
    )

    with gr.Row():
        # Left column: image input (PIL image, or None when empty)
        with gr.Column():
            gr.Markdown("### 📸 Upload or Capture an Image")
            image_input = gr.Image(type="pil", label="Upload or Capture Prescription")

        # Right column: read-only prediction output
        with gr.Column():
            gr.Markdown("### 💊 Recognized Drug Name")
            output_text = gr.Textbox(label="Predicted Drug Name", interactive=False)

    # Process button
    with gr.Row():
        submit_button = gr.Button("🔍 Identify Prescription", variant="primary")

    # On button click, predict the drug name
    submit_button.click(fn=predict_text, inputs=image_input, outputs=output_text)


if __name__ == "__main__":
    interface.launch()