T2LIPthedeveloper committed on
Commit
ca4bd13
·
1 Parent(s): 6055356

Test deployment

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .venv
2
+ __pycache__
3
+ *.pyc
4
+ *.gradio
classification_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74a6325f0e12a18554e204106551536a00fafe003294deb6836fc09082d2a8ee
3
+ size 32390942
classification_model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import timm

# Mapping from pill-class name to the integer label used at training time.
# Must match the label encoding the checkpoint was trained with.
PILL_CLASSES = {
    'acc': 0, 'advil': 1, 'akineton': 2, 'algoflex': 3, 'algopyrin': 4, 'ambroxol': 5,
    'apranax': 6, 'aspirin': 7, 'atoris': 8, 'atorvastatin': 9, 'betaloc': 10,
    'bila': 11, 'c': 12, 'calci': 13, 'cataflam': 14, 'cetirizin': 15, 'co': 16,
    'cold': 17, 'coldrex': 18, 'concor': 19, 'condrosulf': 20, 'controloc': 21,
    'covercard': 22, 'coverex': 23, 'diclopram': 24, 'donalgin': 25, 'dorithricin': 26,
    'doxazosin': 27, 'dulodet': 28, 'dulsevia': 29, 'enterol': 30, 'escitil': 31,
    'favipiravir': 32, 'frontin': 33, 'furon': 34, 'ibumax': 35, 'indastad': 36,
    'jutavit': 37, 'kalcium': 38, 'kalium': 39, 'ketodex': 40, 'koleszterin': 41,
    'l': 42, 'lactamed': 43, 'lactiv': 44, 'laresin': 45, 'letrox': 46, 'lordestin': 47,
    'magne': 48, 'mebucain': 49, 'merckformin': 50, 'meridian': 51, 'metothyrin': 52,
    'mezym': 53, 'milgamma': 54, 'milurit': 55, 'naprosyn': 56, 'narva': 57,
    'naturland': 58, 'nebivolol': 59, 'neo': 60, 'no': 61, 'noclaud': 62,
    'nolpaza': 63, 'nootropil': 64, 'normodipine': 65, 'novo': 66, 'nurofen': 67,
    'ocutein': 68, 'olicard': 69, 'panangin': 70, 'pantoprazol': 71, 'provera': 72,
    'quamatel': 73, 'reasec': 74, 'revicet': 75, 'rhinathiol': 76, 'rubophen': 77,
    'salazopyrin': 78, 'sedatif': 79, 'semicillin': 80, 'sicor': 81, 'sinupret': 82,
    'sirdalud': 83, 'strepfen': 84, 'strepsils': 85, 'syncumar': 86, 'teva': 87,
    'theospirex': 88, 'tricovel': 89, 'tritace': 90, 'urotrin': 91, 'urzinol': 92,
    'valeriana': 93, 'verospiron': 94, 'vita': 95, 'vitamin': 96, 'voltaren': 97,
    'xeter': 98, 'zadex': 99
}

# Inference runs on CPU (deployment target has no GPU).
device = torch.device("cpu")

# Recreate the training architecture. pretrained=False: the ImageNet weights
# would be fully overwritten by load_state_dict below, so downloading them
# is pure waste.
model = timm.create_model("rexnet_150", pretrained=False, num_classes=len(PILL_CLASSES))
model.to(device)

# Load the trained checkpoint. Anchor the path to this module's directory so
# it works regardless of the process's current working directory.
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "classification_model.pth")
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # disable dropout / freeze batch-norm statistics for inference

# Preprocessing: 224x224 resize + ImageNet mean/std normalization
# (must match the transforms used during training).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Inverse mapping (label index -> class name). Computed once at import time
# instead of being rebuilt on every prediction call, as it was before.
_PILL_CLASSES_INVERTED = {v: k for k, v in PILL_CLASSES.items()}


def classify_medicine(image_bytes):
    """Classify a pill image using the loaded model.

    Args:
        image_bytes: Raw encoded image bytes (any format PIL can decode).

    Returns:
        dict with keys:
            class_index: int label index predicted by the model.
            pill_class: human-readable class name ("Unknown" if unmapped).
            confidence: softmax probability of the predicted class.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)

    # output[0]: logits for the single image in the batch.
    probabilities = F.softmax(output[0], dim=0)
    class_index = torch.argmax(probabilities).item()
    confidence = probabilities[class_index].item()

    pill_class = _PILL_CLASSES_INVERTED.get(class_index, "Unknown")

    return {"class_index": class_index, "pill_class": pill_class, "confidence": confidence}


# Kept for backward compatibility with callers importing `export`.
export = classify_medicine
main.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from classification_model import classify_medicine
3
+ from ocr_model import perform_ocr
4
+
def classify_image(image):
    """Read the image file at *image* and return a formatted classification result."""
    with open(image, "rb") as handle:
        raw = handle.read()
    prediction = classify_medicine(raw)
    label = prediction["pill_class"]
    score = prediction["confidence"]
    return f"Class: {label}, Confidence: {score:.2f}"
10
+
def ocr_image(image):
    """Read the image file at *image* and return the text extracted by OCR."""
    with open(image, "rb") as handle:
        raw = handle.read()
    return perform_ocr(raw)
15
+
# Two-tab Gradio UI: one tab classifies the pill image, the other extracts text.
with gr.Blocks() as app:
    gr.Markdown("## Medicine Classification and OCR App")

    with gr.Tab("Classify Medicine"):
        cls_input = gr.Image(type="filepath")
        cls_button = gr.Button("Classify")
        cls_result = gr.Textbox()
        cls_button.click(classify_image, inputs=cls_input, outputs=cls_result)

    with gr.Tab("OCR Extraction"):
        ocr_input = gr.Image(type="filepath")
        ocr_run = gr.Button("Extract Text")
        ocr_result = gr.Textbox()
        ocr_run.click(ocr_image, inputs=ocr_input, outputs=ocr_result)

app.launch()
ocr_model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io
import requests
from PIL import Image

def perform_ocr(image_bytes):
    """Extract text from image bytes via the OCR.space API.

    Args:
        image_bytes: Raw encoded image bytes (any format PIL can decode).

    Returns:
        The extracted text, with blank lines between non-empty lines.

    Raises:
        ValueError: If the bytes are empty or not a decodable image, or if
            the OCR service reports an error or returns no parse results.
        requests.HTTPError: If the API responds with an HTTP error status.
    """
    if not image_bytes:
        raise ValueError("Empty image bytes provided")
    # Validate the payload is actually a decodable image before uploading.
    try:
        Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise ValueError(f"Invalid image bytes provided: {e}") from e

    # OCR.space API endpoint and payload (using the free 'helloworld' key)
    api_url = "https://api.ocr.space/parse/image"
    payload = {
        'apikey': 'helloworld',  # Free API key with usage limits
        'language': 'eng'
    }
    files = {
        'file': ('image.jpg', image_bytes)
    }

    response = requests.post(api_url, data=payload, files=files)
    response.raise_for_status()  # surface HTTP-level failures explicitly
    result = response.json()

    if result.get("IsErroredOnProcessing"):
        error = result.get("ErrorMessage") or "Unknown error"
        # The API may return ErrorMessage as a list of strings.
        if isinstance(error, list):
            error = "; ".join(str(m) for m in error)
        raise ValueError(f"OCR processing error: {error}")

    # Previously indexed ParsedResults unchecked, crashing with TypeError
    # when the key was absent; fail with a clear error instead.
    parsed_results = result.get("ParsedResults") or []
    if not parsed_results:
        raise ValueError("OCR service returned no parse results")

    parsed_text = parsed_results[0].get("ParsedText", "")
    paragraphs = parsed_text.split('\n')
    formatted_text = "\n\n".join(p.strip() for p in paragraphs if p.strip())
    return formatted_text

# Kept for backward compatibility with callers importing `export`.
export = perform_ocr
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers # For Hugging Face models
3
+ torch
4
+ torchvision
5
+ pillow # For image processing
6
+ python-dotenv # For loading environment variables
7
+ timm