fasthtmltestV2 / app.py
lyimo's picture
Update app.py
c6e1fa7 verified
from fasthtml.common import *
from fasthtml.core import serve
from fastai.vision.all import *
# Explicit imports stay after the star imports so these names always win.
import io
import base64
from pathlib import Path
from PIL import Image as PILImageLib
from starlette.datastructures import UploadFile
# Load the exported fastai learner once, at import time, so every request
# reuses the same model. The path is anchored to this file's directory rather
# than the current working directory, so the app starts no matter where it is
# launched from.
# NOTE: load_learner unpickles the file, so model.pkl must come from a
# trusted source.
learn = load_learner(Path(__file__).resolve().parent / 'model.pkl')
# Class names; the classify handler indexes these in parallel with the
# probabilities returned by learn.predict.
labels = learn.dls.vocab
# Create the app and its route decorator. The stylesheet below is passed as an
# extra header element and styles the prediction rows, the image preview, the
# loading spinner and the upload drop zone.
app, rt = fast_app(
    # `hdrs` is meant to be a sequence of header elements. The trailing comma
    # after Style(...) is what makes this a one-item tuple; without it the
    # parentheses are plain grouping and a bare element is passed instead.
    hdrs=(
        Style("""
.prediction {
padding: 10px;
margin: 6px 0;
border-radius: 6px;
display: flex;
justify-content: space-between;
}
.top-prediction {
background-color: rgba(72, 187, 120, 0.15);
font-weight: bold;
}
.preview-img {
max-width: 300px;
border-radius: 8px;
margin: 15px 0;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.loader {
border: 4px solid #f3f3f3;
border-top: 4px solid #3498db;
border-radius: 50%;
width: 30px;
height: 30px;
animation: spin 2s linear infinite;
margin: 20px auto;
display: none;
}
.htmx-request .loader {
display: block;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.drop-container {
border: 2px dashed #ccc;
border-radius: 8px;
padding: 20px;
text-align: center;
margin-bottom: 20px;
transition: all 0.3s;
}
.drop-container:hover {
border-color: #1095c1;
background-color: rgba(16, 149, 193, 0.05);
}
"""),
    )
)
@rt
def index():
    """Render the landing page: an upload form plus an empty result area.

    The form posts the chosen file to ``/classify`` via htmx as
    multipart form data, and the response is swapped into ``#result``.
    """
    upload_form = Form(
        Div(
            Input(type="file", name="file", accept="image/*", required=True, id="file-upload"),
            P("Drag and drop an image or click to select", cls="text-muted"),
            cls="drop-container",
        ),
        Button("Classify", type="submit"),
        # htmx adds the `htmx-request` class to the element that issues the
        # request (this form). The spinner therefore has to live inside the
        # form for the `.htmx-request .loader` CSS rule to reveal it while
        # the upload is in flight.
        Div(cls="loader"),
        hx_post="/classify",
        hx_target="#result",
        hx_encoding="multipart/form-data",
    )
    return Titled(
        "Image Classifier",
        Container(
            # NOTE(review): Titled appears to render its own heading from the
            # title, so this H1 may be a duplicate — confirm before removing.
            H1("Image Classifier"),
            P("Upload an image and our model will classify what's in it."),
            Div(
                upload_form,
                Div(id="result"),
            ),
        ),
    )
@rt("/classify")
async def post(file: UploadFile):
    """Classify an uploaded image and return an HTML card with the results.

    Args:
        file: The upload sent by the form's ``file`` input.

    Returns:
        A ``Card`` holding a preview of the resized image and the three most
        probable labels. If the upload cannot be decoded as an image (for
        example an empty or non-image file), a card with an error message is
        returned instead of letting the request fail.
    """
    contents = await file.read()

    # Decode the upload; untrusted input, so a bad file must not crash the
    # request. Pillow's "cannot identify image" error is an OSError subclass.
    try:
        img = PILImage.create(io.BytesIO(contents))
    except OSError:
        return Card(
            H3("Classification Failed"),
            P("Could not read that file as an image. Please upload a valid image."),
            Button("Try Again", hx_get="/", hx_target="body"),
        )
    img = img.resize((512, 512))

    # Only the per-class probabilities are needed; the decoded prediction and
    # its index are ignored.
    _, _, probs = learn.predict(img)

    # Pair every label with its probability and keep the three most likely.
    results = sorted(
        ((label, float(prob)) for label, prob in zip(labels, probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top_results = results[:3]

    # Encode the resized image as a base64 JPEG for an inline preview. The
    # PIL image is saved directly rather than round-tripping through an
    # array; convert("RGB") guards against modes JPEG cannot store.
    buffered = io.BytesIO()
    img.convert("RGB").save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

    return Card(
        H3("Classification Results"),
        Img(src=f"data:image/jpeg;base64,{img_str}", cls="preview-img"),
        Div(
            *[
                Div(
                    Span(label),
                    Span(f"{prob:.2%}"),
                    # The first (highest-probability) row gets the highlight class.
                    cls=f"prediction {'top-prediction' if i == 0 else ''}",
                )
                for i, (label, prob) in enumerate(top_results)
            ]
        ),
        Button("Classify Another", hx_get="/", hx_target="body"),
    )
if __name__ == "__main__":
    # Port used for the Hugging Face deployment of this app.
    hf_port = 7860
    serve(port=hf_port)