File size: 1,160 Bytes
0ad27ae
 
b6faa5c
2449b1f
 
b6faa5c
 
 
 
6292093
b6faa5c
6292093
b6faa5c
3153924
81700f6
0ad27ae
 
 
 
 
 
 
 
81700f6
b6faa5c
3153924
81700f6
0ad27ae
 
 
 
 
 
 
 
 
61ff965
 
81700f6
 
 
 
61ff965
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""CLIP interface module"""

# libraries
from typing import Dict, List, Union

from PIL import Image

# modules
from src.core.logger import logger
from src.model.clip import ClipModel

# single ClipModel instance, constructed when this module is imported
# and called by clip_demo_fn for every request
MODEL = ClipModel()


def clean_text(text: List[str]) -> List[str]:
    """strip surrounding whitespace from every entry of the input text

    Args:
        text (List[str]): list of raw text entries, e.g. the pieces of a
            comma separated gradio input after splitting

    Returns:
        List[str]: new list with each entry stripped, same order and length
    """
    # entries that are empty after stripping are kept as empty strings
    return [entry.strip() for entry in text]


def clip_demo_fn(image: Image.Image, text: Union[str, List[str]]) -> Dict[str, float]:
    """demo function for gradio interface

    Args:
        image (Image.Image): PIL image, passed unchanged to the model
        text (Union[str, List[str]]): either one string of comma separated
            text or a list of text entries

    Returns:
        Dict[str, float]: result of calling MODEL(image, text); annotated as
            a dictionary of text classes and its associated probability
    """
    try:
        logger.info("demo function invoked")
        if isinstance(text, str):
            # a raw string is split on commas, then each piece is stripped
            text = clean_text(text.split(","))
        elif isinstance(text, list):
            # elif: a string input was already cleaned in the branch above,
            # so it must not be cleaned a second time here
            text = clean_text(text)
        logger.debug("clean text: %s", text)
        return MODEL(image, text)
    finally:
        # logged whether the model call returned or raised
        logger.info("demo function completed")