File size: 5,356 Bytes
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f2fd9
e568430
 
 
 
 
 
 
 
85f2fd9
e568430
 
 
 
85f2fd9
e568430
 
85f2fd9
e568430
 
 
 
85f2fd9
e568430
 
 
85f2fd9
43839ca
e568430
 
 
 
 
 
 
 
85f2fd9
e568430
 
 
 
 
 
 
 
 
 
b4f9ff5
 
 
e568430
85f2fd9
 
 
 
 
 
 
 
 
 
 
 
e568430
85f2fd9
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Unified multimodal processing service for text, audio, and image inputs."""

from functools import lru_cache
from typing import Any

import structlog
from gradio.data_classes import FileData

from src.services.audio_processing import AudioService, get_audio_service
from src.services.image_ocr import ImageOCRService, get_image_ocr_service
from src.utils.config import settings

logger = structlog.get_logger(__name__)


class MultimodalService:
    """Unified multimodal processing service.

    Combines plain text input with transcribed microphone audio and
    OCR-extracted image text into a single prompt string, delegating the
    heavy lifting to :class:`AudioService` and :class:`ImageOCRService`.
    """

    # Recognized file extensions (lower-case, dot included). Tuples so they
    # can be passed directly to str.endswith().
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".tif")
    _AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma")

    def __init__(
        self,
        audio_service: AudioService | None = None,
        ocr_service: ImageOCRService | None = None,
    ) -> None:
        """Initialize multimodal service.

        Args:
            audio_service: Audio service instance (default: get_audio_service())
            ocr_service: Image OCR service instance (default: get_image_ocr_service())
        """
        self.audio = audio_service or get_audio_service()
        self.ocr = ocr_service or get_image_ocr_service()

    async def process_multimodal_input(
        self,
        text: str,
        files: list[FileData] | None = None,
        audio_input: tuple[int, Any] | None = None,
        hf_token: str | None = None,
        prepend_multimodal: bool = True,
    ) -> str:
        """Process multimodal input (text + images + audio) and return combined text.

        Args:
            text: Text input string
            files: List of uploaded files (images, audio, etc.)
            audio_input: Audio input tuple (sample_rate, audio_array)
            hf_token: HuggingFace token for authenticated Gradio Spaces
            prepend_multimodal: If True, prepend audio/image text to original text; otherwise append

        Returns:
            Combined text from all inputs; empty string if nothing was usable.
        """
        multimodal_parts: list[str] = []

        # Microphone audio first, so transcriptions lead the multimodal block.
        if audio_input is not None and settings.enable_audio_input:
            try:
                transcribed = await self.audio.process_audio_input(audio_input, hf_token=hf_token)
                if transcribed:
                    multimodal_parts.append(transcribed)
            except Exception as e:
                # Best-effort: a failed transcription must not block the text path.
                logger.warning("audio_processing_failed", error=str(e))

        # NOTE(review): the whole file loop — including the "audio upload not
        # supported" warning — is gated on enable_image_input; confirm whether
        # audio uploads should be governed by their own setting.
        if files and settings.enable_image_input:
            multimodal_parts.extend(await self._extract_file_texts(files, hf_token))

        text_parts = [text.strip()] if text and text.strip() else []

        # Order the sections according to the caller's preference.
        if prepend_multimodal:
            combined_parts = multimodal_parts + text_parts
        else:
            combined_parts = text_parts + multimodal_parts

        # join() of an empty list is already "", so no explicit guard is needed.
        combined_text = "\n\n".join(combined_parts)

        logger.info(
            "multimodal_input_processed",
            text_length=len(combined_text),
            num_files=len(files) if files else 0,
            has_audio=audio_input is not None,
        )

        return combined_text

    async def _extract_file_texts(
        self, files: list[FileData], hf_token: str | None
    ) -> list[str]:
        """Run OCR on uploaded image files; warn about unsupported audio uploads.

        Args:
            files: Uploaded files (FileData objects or path-like values)
            hf_token: HuggingFace token forwarded to the OCR service

        Returns:
            Extracted text snippets, one per image that yielded any text.
        """
        parts: list[str] = []
        for file_data in files:
            file_path = file_data.path if isinstance(file_data, FileData) else str(file_data)

            if self._is_image_file(file_path):
                try:
                    extracted_text = await self.ocr.extract_text(file_path, hf_token=hf_token)
                    if extracted_text:
                        parts.append(extracted_text)
                except Exception as e:
                    # Skip unreadable images instead of failing the whole request.
                    logger.warning("image_ocr_failed", file_path=file_path, error=str(e))
            elif self._is_audio_file(file_path):
                # Transcribing uploaded audio files is not implemented yet.
                # (Fixed: removed a try/except that wrapped only this warning
                # call and therefore could never usefully fire.)
                logger.warning("audio_file_upload_not_supported", file_path=file_path)
        return parts

    def _is_image_file(self, file_path: str) -> bool:
        """Check if file is an image.

        Args:
            file_path: Path to file

        Returns:
            True if file is an image
        """
        # str.endswith accepts a tuple, replacing the manual any() scan.
        return file_path.lower().endswith(self._IMAGE_EXTENSIONS)

    def _is_audio_file(self, file_path: str) -> bool:
        """Check if file is an audio file.

        Args:
            file_path: Path to file

        Returns:
            True if file is an audio file
        """
        return file_path.lower().endswith(self._AUDIO_EXTENSIONS)


@lru_cache(maxsize=1)
def get_multimodal_service() -> MultimodalService:
    """Return the process-wide MultimodalService singleton.

    The first call constructs the service; ``lru_cache(maxsize=1)`` memoizes
    that instance so every later call reuses it.

    Returns:
        The shared MultimodalService instance.
    """
    service = MultimodalService()
    return service