Spaces:
Running
Running
T2LIPthedeveloper
committed on
Commit
·
ca4bd13
1
Parent(s):
6055356
Test deployment
Browse files- .gitignore +4 -0
- classification_model.pth +3 -0
- classification_model.py +70 -0
- main.py +31 -0
- ocr_model.py +36 -0
- requirements.txt +7 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
__pycache__
|
| 3 |
+
*.pyc
|
| 4 |
+
*.gradio
|
classification_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:74a6325f0e12a18554e204106551536a00fafe003294deb6836fc09082d2a8ee
|
| 3 |
+
size 32390942
|
classification_model.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import timm
|
| 8 |
+
|
| 9 |
+
# Pill class names in training-label order (ensure this matches your training
# setup). NOTE(review): index = position in this tuple — do not reorder.
_PILL_NAMES = (
    'acc', 'advil', 'akineton', 'algoflex', 'algopyrin', 'ambroxol',
    'apranax', 'aspirin', 'atoris', 'atorvastatin', 'betaloc',
    'bila', 'c', 'calci', 'cataflam', 'cetirizin', 'co',
    'cold', 'coldrex', 'concor', 'condrosulf', 'controloc',
    'covercard', 'coverex', 'diclopram', 'donalgin', 'dorithricin',
    'doxazosin', 'dulodet', 'dulsevia', 'enterol', 'escitil',
    'favipiravir', 'frontin', 'furon', 'ibumax', 'indastad',
    'jutavit', 'kalcium', 'kalium', 'ketodex', 'koleszterin',
    'l', 'lactamed', 'lactiv', 'laresin', 'letrox', 'lordestin',
    'magne', 'mebucain', 'merckformin', 'meridian', 'metothyrin',
    'mezym', 'milgamma', 'milurit', 'naprosyn', 'narva',
    'naturland', 'nebivolol', 'neo', 'no', 'noclaud',
    'nolpaza', 'nootropil', 'normodipine', 'novo', 'nurofen',
    'ocutein', 'olicard', 'panangin', 'pantoprazol', 'provera',
    'quamatel', 'reasec', 'revicet', 'rhinathiol', 'rubophen',
    'salazopyrin', 'sedatif', 'semicillin', 'sicor', 'sinupret',
    'sirdalud', 'strepfen', 'strepsils', 'syncumar', 'teva',
    'theospirex', 'tricovel', 'tritace', 'urotrin', 'urzinol',
    'valeriana', 'verospiron', 'vita', 'vitamin', 'voltaren',
    'xeter', 'zadex',
)

# Mapping from pill class name to the integer label used at training time.
PILL_CLASSES = {name: index for index, name in enumerate(_PILL_NAMES)}
|
| 32 |
+
|
| 33 |
+
# Run inference on CPU (deployment target; no GPU assumed).
device = torch.device("cpu")

# Instantiate the model architecture (same as training).
# pretrained=False: the ImageNet weights would only be downloaded to be
# completely overwritten by load_state_dict below, so skip the download.
model = timm.create_model("rexnet_150", pretrained=False, num_classes=len(PILL_CLASSES))
model.to(device)

# Load the fine-tuned weights. weights_only=True restricts unpickling to
# tensors — all a state dict contains — which avoids arbitrary-code
# execution from a tampered checkpoint (requires torch >= 1.13).
model_path = "classification_model.pth"
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

# Preprocessing pipeline; the Normalize constants are the standard
# ImageNet statistics — assumed to match the training transforms.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
| 51 |
+
|
| 52 |
+
def classify_medicine(image_bytes):
    """Classify a pill image supplied as raw bytes.

    Returns a dict with the predicted class index, the class name, and the
    softmax confidence of that prediction.
    """
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(batch)

    probs = F.softmax(logits[0], dim=0)
    best_index = torch.argmax(probs).item()
    best_confidence = probs[best_index].item()

    # Reverse lookup: label index -> class name.
    index_to_name = {idx: name for name, idx in PILL_CLASSES.items()}
    predicted_name = index_to_name.get(best_index, "Unknown")

    return {
        "class_index": best_index,
        "pill_class": predicted_name,
        "confidence": best_confidence,
    }

# Alias kept so callers importing `export` keep working.
export = classify_medicine
|
main.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from classification_model import classify_medicine
|
| 3 |
+
from ocr_model import perform_ocr
|
| 4 |
+
|
| 5 |
+
def classify_image(image):
    """Gradio handler: classify the pill in the image at the given file path."""
    with open(image, "rb") as handle:
        raw_bytes = handle.read()
    prediction = classify_medicine(raw_bytes)
    return f"Class: {prediction['pill_class']}, Confidence: {prediction['confidence']:.2f}"
|
| 10 |
+
|
| 11 |
+
def ocr_image(image):
    """Gradio handler: extract text from the image at the given file path."""
    with open(image, "rb") as handle:
        raw_bytes = handle.read()
    return perform_ocr(raw_bytes)
|
| 15 |
+
|
| 16 |
+
# Two-tab Gradio UI: one tab for pill classification, one for OCR.
with gr.Blocks() as app:
    gr.Markdown("## Medicine Classification and OCR App")

    with gr.Tab("Classify Medicine"):
        classify_input = gr.Image(type="filepath")
        run_classify = gr.Button("Classify")
        classify_result = gr.Textbox()
        run_classify.click(classify_image, inputs=classify_input, outputs=classify_result)

    with gr.Tab("OCR Extraction"):
        ocr_input = gr.Image(type="filepath")
        run_ocr = gr.Button("Extract Text")
        ocr_result = gr.Textbox()
        run_ocr.click(ocr_image, inputs=ocr_input, outputs=ocr_result)

# Start the web server (blocks until shut down).
app.launch()
|
ocr_model.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import requests
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
def perform_ocr(image_bytes):
    """Extract text from an image (raw bytes) via the OCR.space web API.

    Returns the recognized text with a blank line between non-empty lines.

    Raises:
        ValueError: if the bytes are empty or not a decodable image, or if
            the OCR service reports an error or returns no results.
    """
    if not image_bytes:
        raise ValueError("Empty image bytes provided")
    # Validate that the payload really is a decodable image before spending
    # a rate-limited API call on it.
    try:
        Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise ValueError(f"Invalid image bytes provided: {e}")

    # OCR.space API endpoint and payload (using the free 'helloworld' key)
    api_url = "https://api.ocr.space/parse/image"
    payload = {
        'apikey': 'helloworld',  # Free API key with usage limits
        'language': 'eng'
    }
    files = {
        'file': ('image.jpg', image_bytes)
    }

    # Timeout so a stalled API call cannot hang the UI thread indefinitely.
    response = requests.post(api_url, data=payload, files=files, timeout=30)
    response.raise_for_status()  # surface HTTP-level failures explicitly
    result = response.json()

    if result.get("IsErroredOnProcessing"):
        error = result.get("ErrorMessage") or "Unknown error"
        raise ValueError(f"OCR processing error: {error}")

    # Guard against a missing/empty result list (previously a TypeError
    # when "ParsedResults" was absent or None).
    parsed_results = result.get("ParsedResults") or []
    if not parsed_results:
        raise ValueError("OCR processing error: no parsed results returned")
    parsed_text = parsed_results[0].get("ParsedText", "")

    # Re-flow: one blank line between non-empty recognized lines.
    paragraphs = parsed_text.split('\n')
    formatted_text = "\n\n".join(p.strip() for p in paragraphs if p.strip())
    return formatted_text

# Alias kept so callers importing `export` keep working.
export = perform_ocr
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
transformers # For Hugging Face models
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
pillow # For image processing
|
| 6 |
+
python-dotenv # For loading environment variables
|
| 7 |
+
timm
requests # Used by ocr_model.py for the OCR.space API
|