File size: 2,014 Bytes
32a0eda
 
 
 
 
6f59750
32a0eda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f59750
32a0eda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f59750
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Created By: ishwor subedi
Date: 2024-08-13
"""
import base64

from PIL import Image
from fastapi import HTTPException, UploadFile, File
from fastapi.routing import APIRouter
from six import BytesIO

from src.models.models import ImageCaptionRequest, ImageGenerationRequest, LanguageTranslationRequest
from src.pipeline.image_processing_pipeline import ImageProcessingPipeline

# Module-level singletons, created once at import time.
# `image_processing_pipeline` is shared by both route handlers in this module;
# `image_processing_router` is the router those handlers are registered on.
# NOTE(review): ImageProcessingPipeline() runs at import; if it loads model
# weights, importing this module will be slow and allocate resources — confirm.
image_processing_pipeline = ImageProcessingPipeline()
image_processing_router = APIRouter()


@image_processing_router.post("/generate_image")
async def generate_image(request: ImageGenerationRequest):
    """Generate an image from a text prompt and return it as a base64-encoded PNG.

    Args:
        request: Generation parameters; its fields are forwarded positionally to
            ``ImageProcessingPipeline.generate_image``.

    Returns:
        dict: ``{'image': <base64-encoded PNG as an ASCII str>, 'status_code': 200}``.

    Raises:
        HTTPException: 500 carrying the error message if generation or PNG
            encoding fails.
    """
    # NOTE(review): the pipeline call below is synchronous and is not awaited,
    # so it runs directly on the event loop and blocks other requests while
    # an image is generated.
    try:
        image = image_processing_pipeline.generate_image(
            request.prompt,
            request.negative_prompt,
            request.style,
            request.use_negative_prompt,
            request.num_inference_steps,
            request.num_images_per_prompt,
            request.seed,
            request.width,
            request.height,
            request.guidance_scale,
            request.randomize_seed
        )
        # Assumes the pipeline returns a single PIL image (``.save`` is called on
        # it) — TODO confirm how num_images_per_prompt > 1 is handled.
        png_buffer = BytesIO()
        image.save(png_buffer, format='PNG')
        # Raw PNG bytes cannot go into a JSON response: FastAPI's encoder decodes
        # ``bytes`` as UTF-8, which fails on the PNG signature byte 0x89.
        # Base64 yields a plain ASCII string that serializes safely.
        encoded_image = base64.b64encode(png_buffer.getvalue()).decode('ascii')
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {'image': encoded_image, 'status_code': 200}


@image_processing_router.post("/generate_caption")
async def generate_caption(request: ImageCaptionRequest, image: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        request: Captioning parameters forwarded positionally to
            ``ImageProcessingPipeline.generate_caption``.
        image: The uploaded file. It is decoded with Pillow and converted to RGB
            before captioning.

    Returns:
        dict: ``{'caption': <pipeline result>, 'status_code': 200}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 carrying the error message if captioning fails.
    """
    # NOTE(review): FastAPI documents that JSON body models cannot be combined
    # with File uploads in a single multipart request; confirm how clients
    # actually send ``request`` to this endpoint.
    raw_bytes = await image.read()
    try:
        # convert() forces a full decode, so truncated files are caught here too.
        pil_image = Image.open(BytesIO(raw_bytes)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # An unreadable upload is a client error. Report 400 instead of letting
        # the decode error surface as an unhandled 500.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    try:
        caption = image_processing_pipeline.generate_caption(
            pil_image,
            request.prompt,
            request.temperature,
            request.length_penalty,
            request.repetition_penalty,
            request.max_length,
            request.min_length,
            request.top_p
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {'caption': caption, 'status_code': 200}