from fasthtml.common import *
from fasthtml.core import serve
from fastai.vision.all import *
import io
import base64
from PIL import Image as PILImageLib
from starlette.datastructures import UploadFile

# Load the exported learner once at import time; every request reuses it.
learn = load_learner('model.pkl')
labels = learn.dls.vocab

# Number of ranked predictions shown in the results card.
TOP_K = 3
# Size the uploaded image is resized to before prediction and preview.
PREDICT_SIZE = (512, 512)

# Create the app with nice styling.
# NOTE: the trailing comma after Style(...) matters -- without it the
# parentheses are just grouping and `hdrs` would not be a tuple of headers.
app, rt = fast_app(
    hdrs=(
        Style("""
            .prediction {
                padding: 10px;
                margin: 6px 0;
                border-radius: 6px;
                display: flex;
                justify-content: space-between;
            }
            .top-prediction {
                background-color: rgba(72, 187, 120, 0.15);
                font-weight: bold;
            }
            .preview-img {
                max-width: 300px;
                border-radius: 8px;
                margin: 15px 0;
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
            }
            .loader {
                border: 4px solid #f3f3f3;
                border-top: 4px solid #3498db;
                border-radius: 50%;
                width: 30px;
                height: 30px;
                animation: spin 2s linear infinite;
                margin: 20px auto;
                display: none;
            }
            /* With hx-indicator, htmx puts `htmx-request` on the loader itself;
               the descendant form is kept for requests issued from a parent. */
            .loader.htmx-request,
            .htmx-request .loader {
                display: block;
            }
            @keyframes spin {
                0% { transform: rotate(0deg); }
                100% { transform: rotate(360deg); }
            }
            .drop-container {
                border: 2px dashed #ccc;
                border-radius: 8px;
                padding: 20px;
                text-align: center;
                margin-bottom: 20px;
                transition: all 0.3s;
            }
            .drop-container:hover {
                border-color: #1095c1;
                background-color: rgba(16, 149, 193, 0.05);
            }
        """),
    )
)


@rt
def index():
    """Render the upload page.

    The form posts the selected file to ``/classify`` via HTMX and swaps the
    returned card into ``#result``. ``Titled`` already emits the page heading
    and a container ``Main``, so no extra ``H1``/``Container`` is added here.
    """
    return Titled(
        "Image Classifier",
        P("Upload an image and our model will classify what's in it."),
        Div(
            Form(
                Div(
                    Input(type="file", name="file", accept="image/*", required=True, id="file-upload"),
                    P("Drag and drop an image or click to select", cls="text-muted"),
                    cls="drop-container",
                ),
                Button("Classify", type="submit"),
                hx_post="/classify",
                hx_target="#result",
                hx_encoding="multipart/form-data",
                # The loader is a sibling of the form, so point htmx at it
                # explicitly; otherwise it never receives `htmx-request`.
                hx_indicator="#loader",
            ),
            Div(cls="loader", id="loader"),
            Div(id="result"),
        ),
    )


@rt("/classify")
async def post(file: UploadFile):
    """Classify an uploaded image and return an HTML results card.

    Args:
        file: The multipart upload from the form's ``file`` input.

    Returns:
        A ``Card`` with an inline JPEG preview of the resized image and the
        ``TOP_K`` labels with the highest predicted probability, or an error
        card when the upload cannot be decoded as an image.
    """
    contents = await file.read()

    try:
        img = PILImage.create(io.BytesIO(contents))
    except (OSError, PILImageLib.DecompressionBombError):
        # Non-image, empty, truncated, or oversized upload (PIL's
        # UnidentifiedImageError is an OSError): report it instead of a 500.
        return Card(
            H3("Could not classify image"),
            P("The uploaded file could not be read as an image. Please try another file."),
            Button("Classify Another", hx_get="/", hx_target="body"),
        )
    img = img.resize(PREDICT_SIZE)

    # NOTE(review): learn.predict is a synchronous call inside this async
    # handler, so it blocks the event loop while inference runs.
    _, _, probs = learn.predict(img)

    # Pair every label with its probability and keep the highest TOP_K
    # (sorted() is stable, so ties keep vocab order).
    ranked = sorted(zip(labels, probs.tolist()), key=lambda lp: lp[1], reverse=True)
    top_results = ranked[:TOP_K]

    # Encode the resized image for inline display. It is already a PIL image,
    # so save it directly; convert to RGB because JPEG cannot store alpha or
    # palette modes.
    buffered = io.BytesIO()
    img.convert("RGB").save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

    return Card(
        H3("Classification Results"),
        Img(src=f"data:image/jpeg;base64,{img_str}", cls="preview-img"),
        Div(
            *[
                Div(
                    Span(label),
                    Span(f"{prob:.2%}"),
                    cls="prediction top-prediction" if rank == 0 else "prediction",
                )
                for rank, (label, prob) in enumerate(top_results)
            ]
        ),
        Button("Classify Another", hx_get="/", hx_target="body"),
    )


if __name__ == "__main__":
    # Start the server on port 7860 (the original author's Hugging Face port).
    serve(port=7860)