# YOLOS-tiny-Docker / app-local.py
# crisrm128's picture
# Uploaded new version
# 3ed724c
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import HTMLResponse
from transformers import YolosForObjectDetection, YolosImageProcessor
from PIL import Image, ImageDraw
import torch
import io
import base64
from starlette.requests import Request
from fastapi.templating import Jinja2Templates
import httpx
# Checkpoint name shared by the detector weights and their preprocessor.
_YOLOS_CHECKPOINT = "hustvl/yolos-tiny"

# FastAPI application object used by the route decorators in this module.
app = FastAPI()

# Jinja2 template loader rooted at the local "templates" directory.
templates = Jinja2Templates(directory="templates")

# Load the YOLOS detector and its image processor once, at import time.
yolos_model = YolosForObjectDetection.from_pretrained(_YOLOS_CHECKPOINT)
yolos_image_processor = YolosImageProcessor.from_pretrained(_YOLOS_CHECKPOINT)
# Route for the main HTML page (GET /).
@app.get("/", response_class=HTMLResponse)
async def main(request: Request):
    """Render the ``index.html`` template with the request as its context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# Define a route for handling object detection from a submitted form
@app.post("/", response_class=HTMLResponse)
async def post_detect_objects(request: Request, url: str = Form(...)):
    """Detect objects in the image at ``url`` and render the annotated result.

    The image is downloaded, run through the YOLOS model, annotated with a
    red rectangle and a ``label: score`` caption for every detection kept by
    the 0.9 score threshold, and embedded as base64-encoded PNG data in the
    ``result.html`` template.

    Args:
        request: Incoming request, forwarded to the template context.
        url: Form field holding the address of the image to analyse.

    Returns:
        The rendered ``result.html`` template response.

    Raises:
        HTTPException: With status 500 if downloading, decoding, inference
            or rendering fails for any reason.
    """
    try:
        # NOTE(review): ``url`` is user-supplied and fetched server-side, so
        # this endpoint can be pointed at internal addresses (SSRF). Restrict
        # the allowed hosts before exposing the service to untrusted users.
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
        response.raise_for_status()  # Treat HTTP error statuses as failures

        # Normalise to RGB so palette/grayscale/CMYK downloads are handled the
        # same way by the preprocessor and the red annotations render in colour.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        # Preprocess the image using the YOLOS image processor
        inputs = yolos_image_processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = yolos_model(**inputs)

        # PIL reports (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = yolos_image_processor.post_process_object_detection(
            outputs, threshold=0.9, target_sizes=target_sizes
        )[0]

        # One drawing context is enough for all detections.
        image_draw = ImageDraw.Draw(image)
        id2label = yolos_model.config.id2label
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            # Convert the tensor to plain floats before handing it to PIL.
            corners = box.tolist()
            caption = f"{id2label[label.item()]}: {round(score.item(), 3)}"
            image_draw.rectangle(corners, outline="red", width=2)
            image_draw.text((corners[0], corners[1]), caption, fill="red")

        # Serialise the annotated image to PNG bytes in memory
        image_byte_array = io.BytesIO()
        image.save(image_byte_array, format="PNG")

        # Embed the PNG as base64 text in the HTML result page
        encoded_image = base64.b64encode(image_byte_array.getvalue()).decode()
        return templates.TemplateResponse("result.html", {"request": request, "image": encoded_image})
    except Exception as e:
        # Top-level boundary: surface any failure as a 500, keeping the cause
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e