"""FastAPI app that runs YOLOS object detection on an image fetched from a URL.

GET "/" serves the submission form (``index.html``); POST "/" downloads the
image at the submitted URL, draws the detected bounding boxes on it and renders
``result.html`` with the annotated image embedded as base64-encoded PNG.
"""

import base64
import io

import httpx
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, ImageDraw
from starlette.requests import Request
from transformers import YolosForObjectDetection, YolosImageProcessor

# Create a FastAPI instance
app = FastAPI()

# Create a Jinja2Templates instance for handling HTML templates
templates = Jinja2Templates(directory="templates")

# Initialize YOLOS model and image processor (loaded once at import time)
yolos_model = YolosForObjectDetection.from_pretrained('hustvl/yolos-tiny')
yolos_image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")

# Minimum confidence score for a detection to be drawn
_SCORE_THRESHOLD = 0.9


# Define a route for the main HTML page
@app.get("/", response_class=HTMLResponse)
async def main(request: Request):
    """Render the landing page containing the URL submission form."""
    return templates.TemplateResponse("index.html", {"request": request})


async def _download_image(url: str) -> bytes:
    """Fetch *url* and return the raw response body.

    Raises:
        httpx.HTTPError: on transport failures or a non-success status code.
        httpx.InvalidURL: if *url* cannot be parsed.
    """
    # SECURITY NOTE(review): `url` is user-supplied, so this will request any
    # address reachable from the server (SSRF). Restrict to an allow-list or
    # reject private/link-local ranges before exposing this publicly.
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.content


def _annotate_detections(image: Image.Image) -> None:
    """Run YOLOS on *image* and draw labelled red boxes onto it in place."""
    # Preprocess the image using the YOLOS image processor
    inputs = yolos_image_processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time
    with torch.no_grad():
        outputs = yolos_model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width)
    target_sizes = torch.tensor([image.size[::-1]])
    results = yolos_image_processor.post_process_object_detection(
        outputs, threshold=_SCORE_THRESHOLD, target_sizes=target_sizes
    )[0]

    # One drawing context is enough for every box
    draw = ImageDraw.Draw(image)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        # Convert to plain floats so PIL never receives tensor objects
        corners = box.tolist()
        draw.rectangle(corners, outline="red", width=2)
        caption = f"{yolos_model.config.id2label[label.item()]}: {round(score.item(), 3)}"
        draw.text((corners[0], corners[1]), caption, fill="red")


# Define a route for handling object detection from a submitted form
@app.post("/", response_class=HTMLResponse)
async def post_detect_objects(request: Request, url: str = Form(...)):
    """Detect objects in the image at *url* and render the annotated result.

    Args:
        request: Incoming request, forwarded to the template context.
        url: Form field holding the address of the image to analyse.

    Returns:
        The rendered ``result.html`` template; its ``image`` context value is
        the annotated picture as a base64-encoded PNG string.

    Raises:
        HTTPException: 400 when the URL cannot be fetched or its content is
            not a decodable image; 500 for any other processing failure.
    """
    # Problems with the submitted URL or its content are the client's fault
    try:
        content = await _download_image(url)
        # Normalise to RGB so palette/RGBA/grayscale inputs are handled and
        # the red outline renders consistently
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (httpx.HTTPError, httpx.InvalidURL, OSError) as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    try:
        # NOTE(review): inference is CPU-bound and runs on the event loop;
        # consider offloading to a worker thread if request latency matters.
        _annotate_detections(image)

        # Save the modified image to a byte sequence
        image_byte_array = io.BytesIO()
        image.save(image_byte_array, format="PNG")

        # Render the result page with the PNG embedded as a base64 string
        encoded_image = base64.b64encode(image_byte_array.getvalue()).decode()
        return templates.TemplateResponse("result.html", {"request": request, "image": encoded_image})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e