File size: 3,207 Bytes
ca4bd13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442ca22
ca4bd13
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import io
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import timm

# Mapping from pill class name to the integer label the model predicts.
# Each name's label is its position in the tuple below, so the order of this
# tuple must stay identical to the label order used during training.
PILL_CLASSES = {
    name: label
    for label, name in enumerate(
        (
            "acc", "advil", "akineton", "algoflex", "algopyrin",
            "ambroxol", "apranax", "aspirin", "atoris", "atorvastatin",
            "betaloc", "bila", "c", "calci", "cataflam",
            "cetirizin", "co", "cold", "coldrex", "concor",
            "condrosulf", "controloc", "covercard", "coverex", "diclopram",
            "donalgin", "dorithricin", "doxazosin", "dulodet", "dulsevia",
            "enterol", "escitil", "favipiravir", "frontin", "furon",
            "ibumax", "indastad", "jutavit", "kalcium", "kalium",
            "ketodex", "koleszterin", "l", "lactamed", "lactiv",
            "laresin", "letrox", "lordestin", "magne", "mebucain",
            "merckformin", "meridian", "metothyrin", "mezym", "milgamma",
            "milurit", "naprosyn", "narva", "naturland", "nebivolol",
            "neo", "no", "noclaud", "nolpaza", "nootropil",
            "normodipine", "novo", "nurofen", "ocutein", "olicard",
            "panangin", "pantoprazol", "provera", "quamatel", "reasec",
            "revicet", "rhinathiol", "rubophen", "salazopyrin", "sedatif",
            "semicillin", "sicor", "sinupret", "sirdalud", "strepfen",
            "strepsils", "syncumar", "teva", "theospirex", "tricovel",
            "tritace", "urotrin", "urzinol", "valeriana", "verospiron",
            "vita", "vitamin", "voltaren", "xeter", "zadex",
        )
    )
}

# All inference in this module runs on the CPU.
device = torch.device("cpu")

# Instantiate the same architecture used in training, with one output per
# entry in PILL_CLASSES. pretrained=False is deliberate: every parameter is
# overwritten by the strict load_state_dict call below, so fetching the
# ImageNet checkpoint would only add an import-time network download (and a
# failure when offline) without changing the final model.
model = timm.create_model("rexnet_150", pretrained=False, num_classes=len(PILL_CLASSES))
model.to(device)

# Path to the fine-tuned weights. It is relative, so it is resolved against
# the current working directory, not this file's directory.
model_path = "classification_model.pth"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only ship/load checkpoints you trust. On torch versions that support
# it, consider passing weights_only=True.
model.load_state_dict(torch.load(model_path, map_location=device))
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()

# Preprocessing applied to every input image before it reaches the model:
#   1. resize to exactly 224x224 (a (h, w) tuple does not preserve aspect ratio),
#   2. convert the PIL image to a float tensor,
#   3. normalize each RGB channel with the given mean/std (the commonly used
#      ImageNet statistics; presumably the same values used in training).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Reverse lookup (label index -> pill name), built once at import time rather
# than on every call.
_PILL_CLASSES_INVERTED = {index: name for name, index in PILL_CLASSES.items()}


def classify_medicine(image_bytes):
    """Classify a single pill image with the loaded model.

    Args:
        image_bytes: Raw encoded image data in any format PIL can open.

    Returns:
        A dict with:
          - ``"class_index"``: int, index of the highest-probability class;
          - ``"pill_class"``: str, the name mapped to that index via
            ``PILL_CLASSES``, or ``"Unknown"`` if the index has no entry;
          - ``"confidence"``: float, the softmax probability of that class.

    Raises:
        Whatever ``PIL.Image.open`` raises for data it cannot decode
        (e.g. ``PIL.UnidentifiedImageError``).
    """
    # Open inside a context manager so the decoded source image is closed
    # as soon as the independent RGB copy has been made.
    with Image.open(io.BytesIO(image_bytes)) as raw_image:
        image = raw_image.convert("RGB")

    # unsqueeze(0) adds a batch dimension of size 1 for the model.
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)

    # output[0] selects the single batch entry; softmax turns its logits
    # into probabilities over the classes.
    probabilities = F.softmax(output[0], dim=0)
    class_index = torch.argmax(probabilities).item()
    confidence = probabilities[class_index].item()

    pill_class = _PILL_CLASSES_INVERTED.get(class_index, "Unknown")

    return {"class_index": class_index, "pill_class": pill_class, "confidence": confidence}


# Alias exposing the classifier under the generic name ``export``.
export = classify_medicine