| from transformers import pipeline |
| import logging |
| from fastapi import Request, HTTPException |
| import base64 |
|
|
|
|
class TextToImageTaskService:
    """Service that extracts text from an uploaded image via a Hugging Face pipeline.

    NOTE(review): the class is named ``TextToImageTaskService`` but the
    pipeline task it runs is ``image-to-text`` — confirm the intended name.
    """

    # Injected logger used for model-load and inference failures.
    __logger: logging.Logger

    def __init__(self, logger: logging.Logger):
        self.__logger = logger

    async def get_encoded_image(
        self,
        request: Request
    ) -> str:
        """Read the request's image payload and return it base64-encoded (UTF-8 str).

        Accepts either a ``multipart/form-data`` body with an ``image`` file
        field, or a raw body whose content type starts with ``image/``.

        Raises:
            HTTPException: 400 when the multipart form lacks a usable
                ``image`` file field, or when the content type is unsupported.
        """
        content_type = request.headers.get("content-type", "")

        if content_type.startswith("multipart/form-data"):
            form = await request.form()
            image = form.get("image")
            # form.get() returns a plain str for text fields; only an upload
            # object exposes an awaitable read(). Guard so a text field named
            # "image" yields a clean 400 instead of an AttributeError (500).
            if image is not None and hasattr(image, "read"):
                image_bytes = await image.read()
                return base64.b64encode(image_bytes).decode("utf-8")
            raise HTTPException(
                status_code=400,
                detail="Missing 'image' file field in multipart form data"
            )

        if content_type.startswith("image/"):
            image_bytes = await request.body()
            return base64.b64encode(image_bytes).decode("utf-8")

        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def extract(
        self,
        request: Request,
        model_name: str
    ):
        """Run the ``image-to-text`` pipeline for *model_name* on the request image.

        Returns:
            The raw pipeline output (list of generated-text dicts from
            ``transformers``).

        Raises:
            HTTPException: 400 for a bad request body, 404 when the model
                cannot be loaded, 500 when inference fails.
        """
        encoded_image = await self.get_encoded_image(request)

        # NOTE(review): the pipeline is rebuilt on every request, re-loading
        # model weights each time — consider caching per model_name.
        try:
            pipe = pipeline("image-to-text", model=model_name, use_fast=True)
        except Exception as e:
            # Lazy %-args: message is only formatted when the record is emitted.
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e

        try:
            result = pipe(encoded_image)
        except Exception as e:
            self.__logger.error("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e

        return result