from fasthtml.common import *
from fasthtml.core import serve
from fastai.vision.all import *
import io
import base64
from PIL import Image as PILImageLib
from starlette.datastructures import UploadFile

# Load the model
learn = load_learner('model.pkl')
labels = learn.dls.vocab

# Create the app and inject custom CSS for the result rows, preview image,
# loading spinner, and upload box
app, rt = fast_app(
    hdrs=(
        Style("""
        .prediction {
            padding: 10px;
            margin: 6px 0;
            border-radius: 6px;
            display: flex;
            justify-content: space-between;
        }
        .top-prediction {
            background-color: rgba(72, 187, 120, 0.15);
            font-weight: bold;
        }
        .preview-img {
            max-width: 300px;
            border-radius: 8px;
            margin: 15px 0;
            box-shadow: 0 4px 8px rgba(0,0,0,0.1);
        }
        .loader {
            border: 4px solid #f3f3f3;
            border-top: 4px solid #3498db;
            border-radius: 50%;
            width: 30px;
            height: 30px;
            animation: spin 2s linear infinite;
            margin: 20px auto;
            display: none;
        }
        .htmx-request .loader {
            display: block;
        }
        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
        .drop-container {
            border: 2px dashed #ccc;
            border-radius: 8px;
            padding: 20px;
            text-align: center;
            margin-bottom: 20px;
            transition: all 0.3s;
        }
        .drop-container:hover {
            border-color: #1095c1;
            background-color: rgba(16, 149, 193, 0.05);
        }
        """)
    )
)

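@rt
def index():
    # Titled() already renders the title as an H1 inside <main class="container">,
    # so no extra H1 or Container wrapper is needed here.
    return Titled(
        "Image Classifier",
        P("Upload an image and our model will classify what's in it."),
        Div(
            Form(
                Div(
                    Input(type="file", name="file", accept="image/*", required=True, id="file-upload"),
                    P("Drag and drop an image or click to select", cls="text-muted"),
                    cls="drop-container"
                ),
                Button("Classify", type="submit"),
                # The spinner lives inside the form: htmx adds the `htmx-request`
                # class to the requesting element (the form), which is what the
                # `.htmx-request .loader` CSS rule needs in order to match.
                Div(cls="loader"),
                hx_post="/classify",
                hx_target="#result",
                hx_encoding="multipart/form-data"
            ),
            Div(id="result")
        )
    )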
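
# Sketch of a hypothetical helper (not called by the routes in this file):
# the data-URI step the /classify handler performs to inline the preview image
# in its HTML response instead of serving it from a separate URL. The `mime`
# default is an assumption and must match how the bytes were actually encoded.
def to_data_uri(data: bytes, mime: str = "image/jpeg") -> str:
    """Encode raw bytes as a `data:` URI usable as an <img> src."""
    encoded = base64.b64encode(data).decode("ascii")
    return f"data:{mime};base64,{encoded}"

# Example: to_data_uri(b"hi") -> "data:image/jpeg;base64,aGk="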

@rt("/classify")
async def post(file: UploadFile):
    # Read the file
    contents = await file.read()
    
    # Decode the bytes into a fastai PILImage (opened in RGB mode) and resize it
    img = PILImage.create(io.BytesIO(contents))
    img = img.resize((512, 512))
    
    # Make prediction
    pred, pred_idx, probs = learn.predict(img)
    
    # Pair each label with its probability and keep the three most likely
    results = [(labels[i], float(probs[i])) for i in range(len(labels))]
    results.sort(key=lambda x: x[1], reverse=True)
    top_results = results[:3]
    
    # Encode the image (already RGB, so JPEG-safe) as base64 for inline display.
    # PIL images have no .numpy() method, and no array round-trip is needed.
    buffered = io.BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    
    # Return an HTML fragment; htmx swaps it into the #result div
    return Card(
        H3("Classification Results"),
        Img(src=f"data:image/jpeg;base64,{img_str}", cls="preview-img"),
        Div(
            *[Div(
                Span(label),
                Span(f"{prob:.2%}"), 
                cls=f"prediction {'top-prediction' if i == 0 else ''}"
            ) for i, (label, prob) in enumerate(top_results)]
        ),
        Button("Classify Another", hx_get="/", hx_target="body")
    )
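
# Sketch of a hypothetical helper (not called by the routes in this file):
# the inline top-k selection from `post`, factored into a pure function so it
# can be unit-tested without loading `model.pkl`. Assumes `probs` is a sequence
# of numbers aligned index-for-index with `labels`. A 1-D tensor such as the one
# returned by `learn.predict` should also work, since float() accepts its elements.
def top_k_predictions(labels, probs, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    pairs = [(label, float(p)) for label, p in zip(labels, probs)]
    # sorted() stays stable with reverse=True, so tied labels keep vocab order
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)[:k]

# Example: top_k_predictions(["cat", "dog", "bird"], [0.2, 0.7, 0.1], k=2)
#          -> [("dog", 0.7), ("cat", 0.2)]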

if __name__ == "__main__":
    # Start the server on port 7860, the default port for Hugging Face Spaces
    serve(port=7860)