File size: 19,500 Bytes
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
ccafac2
628e97e
 
 
ccafac2
 
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
ccafac2
 
628e97e
ccafac2
628e97e
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
 
628e97e
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
 
 
 
 
 
 
 
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccafac2
628e97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
# summarizer_tool.py

# --- Standard Library Imports ---
import io
import json # Added for handling JSON output consistently
import logging
import os
import re
import subprocess # Safe (list-based, no shell) ffmpeg invocation
import tempfile

# --- Core ML Imports ---
import numpy as np
import soundfile as sf
import torch
from PIL import Image
from pydub import AudioSegment
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, \
    AutoModelForSpeechSeq2Seq, AutoModelForImageClassification, AutoModelForObjectDetection

# --- Langchain Imports ---
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline # <--- ADD THIS LINE
from langchain_community.vectorstores import FAISS

# --- Other Imports ---
from datasets import load_dataset, Audio # Added for dataset loading
from gtts import gTTS

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Global Cache for Pipelines ---
# Process-wide cache mapping a (task, model, kwargs) key to a loaded HF
# pipeline, so each model is downloaded/instantiated at most once per run.
_pipeline_cache = {}

def get_pipeline(task_name, model_name=None, **kwargs):
    """
    Retrieve a Hugging Face pipeline for `task_name`, caching it for efficiency.

    Args:
        task_name: Pipeline task identifier (e.g. "sentiment-analysis").
        model_name: Optional explicit model checkpoint; when None, transformers
            picks its default model for the task.
        **kwargs: Extra keyword arguments forwarded to `transformers.pipeline`.
            Values must be hashable so they can participate in the cache key.

    Returns:
        The cached (or freshly loaded) pipeline object.
    """
    # Use a plain hashable tuple as the key. The previous
    # f"...{hash(frozenset(kwargs.items()))}" key could silently return the
    # wrong pipeline on a hash collision; tuple keys are compared by equality.
    cache_key = (task_name, model_name, tuple(sorted(kwargs.items())))
    if cache_key not in _pipeline_cache:
        logging.info(f"Loading pipeline for task '{task_name}' with model '{model_name}'...")
        if model_name:
            _pipeline_cache[cache_key] = pipeline(task_name, model=model_name, **kwargs)
        else:
            # Let transformers resolve its default model for the task.
            _pipeline_cache[cache_key] = pipeline(task_name, **kwargs)
        logging.info(f"Pipeline '{task_name}' loaded.")
    return _pipeline_cache[cache_key]


# --- Main Dispatcher Class ---
class AllInOneDispatcher:
    """Routes raw text or file paths to task-appropriate Hugging Face pipelines.

    Supported inputs: plain text (sentiment / summarization / generation /
    translation / TTS), images, audio, video (first frames + audio track),
    PDFs (RAG-based summarization), and Hugging Face Hub datasets. Every
    dispatch appends a {"task", "input", "output"} record to ``self.memory``.
    """

    def __init__(self):
        logging.info("Initializing AllInOneDispatcher...")
        # Chronological log of every dispatch: {"task", "input", "output"}.
        self.memory = []

        # Fallback model checkpoint per task, used when the caller does not
        # name a model explicitly.
        self.default_models = {
            "sentiment-analysis": "distilbert-base-uncased-finetuned-sst-2-english",
            "summarization": "sshleifer/distilbart-cnn-12-6",
            "text-generation": "gpt2", # Keep gpt2 for general text generation
            "translation_en_to_fr": "Helsinki-NLP/opus-mt-en-fr",
            "image-classification": "google/vit-base-patch16-224",
            "object-detection": "facebook/detr-resnet-50",
            "automatic-speech-recognition": "openai/whisper-tiny.en",
            "rag-llm": "gpt2" # New default for the RAG LLM
        }
        logging.info("AllInOneDispatcher initialized.")

    def _get_task_pipeline(self, task: str, model_name: str = None):
        """Return a cached pipeline for `task`, resolving the model from defaults.

        Raises:
            ValueError: if `model_name` is not given and no default exists.
        """
        final_model_name = model_name if model_name else self.default_models.get(task)
        if not final_model_name:
            raise ValueError(f"No default model specified for task '{task}'. Please provide `model_name` or add to default_models.")
        return get_pipeline(task, model_name=final_model_name)

    def _is_file(self, path) -> bool:
        """True when `path` names an existing regular file."""
        return os.path.exists(path) and os.path.isfile(path)

    def handle_text(self, text: str, task: str = "sentiment-analysis", **kwargs):
        """Run a text pipeline (sentiment, summarization, generation, ...) on `text`.

        Raises:
            TypeError: if `text` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Text input must be a string.")
        logging.info(f"Handling text for task: {task}")
        pipeline_obj = self._get_task_pipeline(task)
        result = pipeline_obj(text, **kwargs)
        self.memory.append({"task": task, "input": text, "output": result})
        return result

    def handle_image(self, path: str, task: str = "image-classification", **kwargs):
        """Open the image file at `path` and run an image pipeline on it.

        Raises:
            FileNotFoundError: if `path` is not an existing file.
            ValueError: if PIL cannot decode the file.
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"Image file not found: {path}")
        logging.info(f"Handling image for task: {task}")
        try:
            image = Image.open(path)
        except Exception as e:
            raise ValueError(f"Could not open image file: {e}")
        pipeline_obj = self._get_task_pipeline(task)
        result = pipeline_obj(image, **kwargs)
        self.memory.append({"task": task, "input": path, "output": result})
        return result

    def handle_audio(self, path: str, task: str = "automatic-speech-recognition", **kwargs):
        """Decode `path` to mono 16 kHz float32 samples and run an audio pipeline.

        Raises:
            FileNotFoundError: if `path` is not an existing file.
            ValueError: if the audio cannot be decoded (e.g. ffmpeg missing).
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"Audio file not found: {path}")
        logging.info(f"Handling audio for task: {task}")
        try:
            # Normalize to mono / 16 kHz, the rate Whisper-style models expect.
            audio = AudioSegment.from_file(path)
            audio = audio.set_channels(1).set_frame_rate(16000)

            buffer = io.BytesIO()
            audio.export(buffer, format="wav")
            buffer.seek(0)

            array, sampling_rate = sf.read(buffer)
            if array.dtype != np.float32:
                array = array.astype(np.float32)

        except Exception as e:
            logging.error(f"Error preparing audio file for processing: {e}")
            raise ValueError(f"Could not prepare audio file: {e}. Ensure ffmpeg is installed system-wide.")

        pipeline_obj = self._get_task_pipeline(task)
        # FIX: pass the numpy array in the dict format the ASR pipeline
        # documents. The previous `array.tolist()` copied the whole signal
        # into a Python list, and `sampling_rate=` is not a supported call
        # kwarg of the pipeline itself.
        result = pipeline_obj({"array": array, "sampling_rate": sampling_rate}, **kwargs)
        self.memory.append({"task": task, "input": path, "output": result})
        return result

    def handle_video(self, path: str):
        """Analyze a video: classify the first frame and transcribe the audio track.

        Returns:
            dict with "image_analysis" and "audio_analysis" keys; either value
            may be None if that stage failed (failures are logged, not raised).
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"Video file not found: {path}")
        logging.info(f"Handling video: {path}")

        try:
            import cv2
        except ImportError:
            raise ImportError("OpenCV (cv2) not installed. Install with `pip install opencv-python` for video processing.")

        frames = []
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {path}")

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            if len(frames) >= 5: break  # a few frames are enough for a snapshot
        cap.release()

        audio_temp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
        try:
            # FIX: argument-list subprocess call instead of os.system with an
            # interpolated shell string — immune to shell metacharacters in
            # `path`, and check=True surfaces ffmpeg failures as exceptions.
            subprocess.run(
                ["ffmpeg", "-i", path, "-q:a", "0", "-map", "a", audio_temp_path, "-y"],
                check=True, capture_output=True,
            )
            if not os.path.exists(audio_temp_path) or os.path.getsize(audio_temp_path) == 0:
                raise RuntimeError("FFmpeg failed to extract audio or extracted empty audio.")
        except Exception as e:
            logging.error(f"FFmpeg audio extraction failed: {e}")
            # FIX: remove the (possibly empty) temp file instead of leaking it.
            if os.path.exists(audio_temp_path):
                os.remove(audio_temp_path)
            audio_temp_path = None

        image_result = None
        audio_result = None

        if frames:
            try:
                # BUGFIX: handle_image() requires a file *path* and checks
                # os.path.isfile, so calling it with a PIL Image always raised
                # FileNotFoundError. Run the pipeline on the in-memory frame.
                image_pipeline = self._get_task_pipeline("image-classification")
                image_result = image_pipeline(frames[0])
            except Exception as e:
                logging.warning(f"Failed to process video frame for image classification: {e}")

        if audio_temp_path:
            try:
                audio_result = self.handle_audio(audio_temp_path, task="automatic-speech-recognition")
            except Exception as e:
                logging.warning(f"Failed to process extracted audio from video: {e}")
            finally:
                if os.path.exists(audio_temp_path):
                    os.remove(audio_temp_path)

        result = {"image_analysis": image_result, "audio_analysis": audio_result}
        self.memory.append({"task": "video_analysis", "input": path, "output": result})
        return result

    def handle_pdf(self, path: str):
        """Processes PDF file for summarization using RAG.

        Splits the PDF into chunks, indexes them in a FAISS vector store, and
        asks a text-generation LLM (wrapped for Langchain) to summarize.

        Raises:
            FileNotFoundError: if `path` is not an existing file.
            ValueError: if loading/indexing/summarization fails.
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"PDF file not found: {path}")
        logging.info(f"Handling PDF: {path}")

        try:
            loader = PyPDFLoader(path)
            docs = loader.load()
            splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            split_docs = splitter.split_documents(docs)
            embeddings = HuggingFaceEmbeddings()
            vectorstore = FAISS.from_documents(split_docs, embeddings)

            # RetrievalQA needs a Langchain LLM, not a raw HF pipeline, so
            # wrap the text-generation pipeline with HuggingFacePipeline.
            text_gen_pipeline = self._get_task_pipeline("text-generation", model_name=self.default_models["rag-llm"])
            qa_llm = HuggingFacePipeline(pipeline=text_gen_pipeline)

            qa_chain = RetrievalQA.from_chain_type(llm=qa_llm, retriever=vectorstore.as_retriever())
            result = qa_chain.run("Summarize this document")
            self.memory.append({"task": "pdf_summarization", "input": path, "output": result})
            return result
        except Exception as e:
            logging.error(f"Error processing PDF: {e}")
            raise ValueError(f"Could not process PDF: {e}. Ensure PDF is valid and Langchain dependencies are met.")

    def handle_tts(self, text: str, lang: str = 'en'):
        """Synthesize `text` to speech with gTTS and return the MP3 temp-file path.

        The caller is responsible for deleting the returned file.

        Raises:
            TypeError: if `text` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Text input for TTS must be a string.")
        logging.info(f"Handling TTS for text: '{text[:50]}...'")
        tts = gTTS(text=text, lang=lang)
        temp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
        tts.save(temp_path)
        self.memory.append({"task": "tts", "input": text, "output": temp_path})
        return temp_path

    def process_dataset_from_hub(self, dataset_name: str, subset_name: str, split: str, column_to_process: str, task: str, num_samples: int = 5):
        """Stream `num_samples` rows of a HF Hub dataset and run `task` on one column.

        Text columns are processed directly; audio dicts and PIL images are
        written to temp files first (and cleaned up afterwards). Per-sample
        failures are recorded in the result list rather than raised.

        Returns:
            A list of per-sample result dicts, or a one-element error list if
            the dataset itself could not be loaded.
        """
        logging.info(f"Attempting to load dataset '{dataset_name}' (subset: {subset_name}, split: {split})...")

        try:
            # FIX: guard against subset_name=None (previously .strip() raised).
            if subset_name and subset_name.strip():
                dataset = load_dataset(dataset_name, subset_name, split=split, streaming=True, trust_remote_code=True)
            else:
                dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)

            logging.info(f"Dataset '{dataset_name}' loaded. Processing {num_samples} samples from column '{column_to_process}' for task '{task}'.")

            processed_results = []
            for i, example in enumerate(dataset):
                if i >= num_samples:
                    break

                if column_to_process not in example:
                    processed_results.append({
                        "sample_index": i,
                        "status": "skipped",
                        "reason": f"Column '{column_to_process}' not found in this sample."
                    })
                    continue

                input_data_for_processing = example[column_to_process]
                temp_file_to_clean = None

                if isinstance(input_data_for_processing, str):
                    # Raw text (or a path) can be dispatched as-is.
                    pass
                elif isinstance(input_data_for_processing, dict) and 'array' in input_data_for_processing and 'sampling_rate' in input_data_for_processing:
                    # HF datasets Audio feature: materialize to a temp WAV.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                        sf.write(tmp_audio.name, input_data_for_processing['array'], input_data_for_processing['sampling_rate'])
                        input_data_for_processing = tmp_audio.name
                        temp_file_to_clean = tmp_audio.name
                elif isinstance(input_data_for_processing, Image.Image):
                    # PIL image: materialize to a temp PNG.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_image:
                        input_data_for_processing.save(tmp_image.name)
                        input_data_for_processing = tmp_image.name
                        temp_file_to_clean = tmp_image.name
                else:
                    processed_results.append({
                        "sample_index": i,
                        "status": "error",
                        "reason": f"Unsupported data type in column '{column_to_process}': {type(input_data_for_processing)}"
                    })
                    continue

                try:
                    single_result = self.process(input_data_for_processing, task=task)
                    processed_results.append({
                        "sample_index": i,
                        "original_content": example[column_to_process] if isinstance(example[column_to_process], str) else f"<{type(example[column_to_process]).__name__} object>",
                        "processed_result": single_result
                    })
                except Exception as e:
                    logging.error(f"Error processing sample {i} from dataset: {e}")
                    processed_results.append({
                        "sample_index": i,
                        "original_content": example[column_to_process] if isinstance(example[column_to_process], str) else f"<{type(example[column_to_process]).__name__} object>",
                        "status": "error",
                        "reason": str(e)
                    })
                finally:
                    # Always remove the per-sample temp file, success or not.
                    if temp_file_to_clean and os.path.exists(temp_file_to_clean):
                        os.remove(temp_file_to_clean)

            return processed_results

        except Exception as e:
            logging.error(f"Failed to load or iterate dataset: {e}")
            return [{"error": f"Failed to load or process dataset: {e}"}]


    def process(self, input_data, task=None, **kwargs):
        """Dispatch `input_data` (raw text or a file path) to the right handler.

        File extensions select the modality; non-file strings are treated as
        text (or TTS when task == "tts"). Task defaults are applied per
        modality when `task` is None.

        Raises:
            TypeError: if `input_data` is not a string.
            ValueError: for unsupported file extensions.
        """
        if not isinstance(input_data, str):
            raise TypeError("Input data must be a string (raw text or file path).")

        if self._is_file(input_data):
            file_extension = input_data.split('.')[-1].lower()

            if file_extension in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff']:
                if not task: task = "image-classification"
                return self.handle_image(input_data, task=task, **kwargs)
            elif file_extension in ['mp3', 'wav', 'ogg', 'flac', 'm4a']:
                if not task: task = "automatic-speech-recognition"
                return self.handle_audio(input_data, task=task, **kwargs)
            elif file_extension in ['mp4', 'mov', 'avi', 'mkv']:
                return self.handle_video(input_data)
            elif file_extension == 'pdf':
                return self.handle_pdf(input_data)
            else:
                raise ValueError(f"Unsupported file type: .{file_extension}. Or specify task for this file.")
        else:
            if task == "tts":
                return self.handle_tts(input_data, **kwargs)
            if not task: task = "sentiment-analysis"
            return self.handle_text(input_data, task=task, **kwargs)

# --- Example Usage (for local testing only - will be skipped when imported by app.py) ---
# Smoke-test demo: exercises each modality handler in turn. Requires network
# access (model downloads, gTTS, datasets) and ffmpeg for the audio path.
if __name__ == "__main__":
    logging.info("Running local example usage of AllInOneDispatcher.")
    dispatcher = AllInOneDispatcher()

    # Text Examples
    print("\n--- Text Examples ---")
    text_input = "The new movie was absolutely fantastic! I highly recommend it."
    print(f"Input: '{text_input}'")
    print(f"Sentiment: {dispatcher.process(text_input, task='sentiment-analysis')}")

    text_to_summarize = ("Artificial intelligence (AI) is intelligence—perceiving, synthesizing, and inferring information—demonstrated by machines, as opposed to intelligence displayed by animals or humans. Example tasks in which AI is used include speech recognition, computer vision, translation, and others. AI applications include advanced web search engines, recommendation systems, understanding human speech, self-driving cars, and competing at the highest level in strategic game systems.")
    print(f"\nInput: '{text_to_summarize}'")
    print(f"Summary: {dispatcher.process(text_to_summarize, task='summarization', max_length=50, min_length=10)}")

    text_to_generate = "In a galaxy far, far away, a brave knight"
    print(f"\nInput: '{text_to_generate}'")
    generated_text = dispatcher.process(text_to_generate, task='text-generation', max_new_tokens=50, num_return_sequences=1)
    print(f"Generated Text: {generated_text[0]['generated_text']}")

    tts_text = "Hello, this is a test from the AI assistant."
    tts_path = dispatcher.process(tts_text, task="tts", lang="en")
    print(f"TTS audio saved to: {tts_path}")
    # handle_tts leaves the MP3 on disk; the demo deletes it when done.
    if os.path.exists(tts_path):
        os.remove(tts_path)

    # Image Examples (requires dummy image or real path)
    dummy_image_path = "dummy_image_for_test.png"
    if not os.path.exists(dummy_image_path):
        try:
            # 100x100 solid-blue PNG is enough to exercise the pipeline.
            Image.new('RGB', (100, 100), color='blue').save(dummy_image_path)
            print(f"\nCreated dummy image: {dummy_image_path}")
        except Exception as e:
            print(f"\nCould not create dummy image: {e}. Skipping image example.")
            dummy_image_path = None

    if dummy_image_path and os.path.exists(dummy_image_path):
        print(f"\nImage Input: {dummy_image_path}")
        try:
            print(f"Image Classification: {dispatcher.process(dummy_image_path, task='image-classification')}")
        except Exception as e:
            print(f"Error during image classification: {e}")
        finally:
            # Always clean up the generated test image.
            os.remove(dummy_image_path)

    # Audio Examples (requires dummy audio or real path, and ffmpeg)
    dummy_audio_path = "dummy_audio_for_test.wav"
    if not os.path.exists(dummy_audio_path):
        try:
            # 1-second 440 Hz sine tone as a minimal valid audio input.
            from pydub.generators import Sine
            sine_wave = Sine(440).to_audio_segment(duration=1000)
            sine_wave.export(dummy_audio_path, format="wav")
            print(f"\nCreated dummy audio: {dummy_audio_path}")
        except ImportError:
            print("\npydub not installed. Skipping dummy audio creation.")
            dummy_audio_path = None
        except Exception as e:
            print(f"\nCould not create dummy audio: {e}. Skipping audio example.")
            dummy_audio_path = None

    if dummy_audio_path and os.path.exists(dummy_audio_path):
        print(f"\nAudio Input: {dummy_audio_path}")
        try:
            transcription = dispatcher.process(dummy_audio_path, task='automatic-speech-recognition')
            print(f"Audio Transcription: {transcription['text']}")
        except Exception as e:
            print(f"Error during audio transcription: {e}")
        finally:
            # Always clean up the generated test audio.
            os.remove(dummy_audio_path)

    # PDF Example (requires a dummy PDF or real path)
    # For testing, you'd need to place a small PDF file in the same directory.
    # dummy_pdf_path = "dummy.pdf"
    # if os.path.exists(dummy_pdf_path):
    #     print(f"\nPDF Input: {dummy_pdf_path}")
    #     try:
    #         print(f"PDF Summary: {dispatcher.process(dummy_pdf_path, task='pdf')}")
    #     except Exception as e:
    #         print(f"Error during PDF processing: {e}")
    # else:
    #     print(f"\nSkipping PDF example: '{dummy_pdf_path}' not found. Please create one for testing.")

    # Dataset Example (requires internet access)
    print("\n--- Dataset Example (will process a few samples) ---")
    try:
        # Streams two sentences from GLUE/SST-2 and runs sentiment analysis.
        dataset_results = dispatcher.process_dataset_from_hub(
            dataset_name="glue",
            subset_name="sst2",
            split="train",
            column_to_process="sentence",
            task="sentiment-analysis",
            num_samples=2
        )
        print(f"Dataset Processing Results: {json.dumps(dataset_results, indent=2)}")
    except Exception as e:
        print(f"Error during dataset processing example: {e}")

    logging.info("Local example usage complete.")