Create summarizer_tool.py

summarizer_tool.py (ADDED, +463 lines)
# summarizer_tool.py

# --- Core Imports ---
import os
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, \
    AutoModelForSpeechSeq2Seq, AutoModelForImageClassification, AutoModelForObjectDetection
from PIL import Image
from pydub import AudioSegment
import soundfile as sf
import numpy as np
import io
import logging
import re
import subprocess  # Safer ffmpeg invocation than os.system
import tempfile
import json  # Added for handling JSON output consistently

# --- Langchain Imports ---
# Ensure these are correct based on Langchain's modularization
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline  # Wraps a transformers pipeline as a LangChain LLM
from langchain.text_splitter import RecursiveCharacterTextSplitter  # This one is still in langchain
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA

# --- Other Imports ---
from gtts import gTTS
from datasets import load_dataset, Audio  # Added for dataset loading

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Global Cache for Pipelines ---
# This prevents reloading the same model multiple times
_pipeline_cache = {}

def get_pipeline(task_name, model_name=None, **kwargs):
    """
    Retrieves a Hugging Face pipeline, caching it for efficiency.
    """
    # Create a unique key for the cache based on task, model, and kwargs.
    # Note: kwargs values must be hashable for this key to be computable.
    cache_key = f"{task_name}-{model_name}-{hash(frozenset(kwargs.items()))}"
    if cache_key not in _pipeline_cache:
        logging.info(f"Loading pipeline for task '{task_name}' with model '{model_name}'...")
        if model_name:
            _pipeline_cache[cache_key] = pipeline(task_name, model=model_name, **kwargs)
        else:
            _pipeline_cache[cache_key] = pipeline(task_name, **kwargs)  # Uses default model for task
        logging.info(f"Pipeline '{task_name}' loaded.")
    return _pipeline_cache[cache_key]

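# Usage sketch (a minimal illustration, assuming the model can be downloaded):
#
#     clf = get_pipeline("sentiment-analysis",
#                        model_name="distilbert-base-uncased-finetuned-sst-2-english")
#     clf("I love this!")  # -> [{'label': 'POSITIVE', 'score': 0.99...}]
#     # A second call with identical arguments returns the same cached object.
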
# --- Main Dispatcher Class ---
class AllInOneDispatcher:
    def __init__(self):
        logging.info("Initializing AllInOneDispatcher...")
        self.memory = []  # For storing interaction history (optional)

        # Define default models for various tasks.
        # These will be loaded on demand via get_pipeline.
        self.default_models = {
            "sentiment-analysis": "distilbert-base-uncased-finetuned-sst-2-english",
            "summarization": "sshleifer/distilbart-cnn-12-6",
            "text-generation": "gpt2",
            "translation_en_to_fr": "Helsinki-NLP/opus-mt-en-fr",
            "image-classification": "google/vit-base-patch16-224",
            "object-detection": "facebook/detr-resnet-50",
            "automatic-speech-recognition": "openai/whisper-tiny.en",  # For English ASR
            # Add other models/tasks as needed
        }
        logging.info("AllInOneDispatcher initialized.")

    def _get_task_pipeline(self, task: str, model_name: str = None):
        """Helper to get a cached pipeline for a specific task."""
        final_model_name = model_name if model_name else self.default_models.get(task)
        if not final_model_name:
            raise ValueError(f"No default model specified for task '{task}'. Please provide `model_name` or add to default_models.")
        return get_pipeline(task, model_name=final_model_name)

    def _is_file(self, path):
        """Checks if the given path exists and is a file."""
        return os.path.exists(path) and os.path.isfile(path)

    def handle_text(self, text: str, task: str = "sentiment-analysis", **kwargs):
        """Processes text input for a given NLP task."""
        if not isinstance(text, str):
            raise TypeError("Text input must be a string.")
        logging.info(f"Handling text for task: {task}")
        pipeline_obj = self._get_task_pipeline(task)
        result = pipeline_obj(text, **kwargs)
        self.memory.append({"task": task, "input": text, "output": result})
        return result

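    # Usage sketch (a minimal illustration; models are the defaults declared above):
    #
    #     dispatcher.handle_text("Hello world", task="translation_en_to_fr")
    #     dispatcher.handle_text("A long article...", task="summarization", max_length=50)
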
    def handle_image(self, image_input, task: str = "image-classification", **kwargs):
        """Processes an image (file path or PIL Image) for a given computer vision task."""
        if isinstance(image_input, Image.Image):
            # Already a PIL Image (e.g. a frame extracted from a video)
            image = image_input
        else:
            if not self._is_file(image_input):
                raise FileNotFoundError(f"Image file not found: {image_input}")
            try:
                image = Image.open(image_input)
            except Exception as e:
                raise ValueError(f"Could not open image file: {e}")
        logging.info(f"Handling image for task: {task}")
        pipeline_obj = self._get_task_pipeline(task)
        result = pipeline_obj(image, **kwargs)
        self.memory.append({"task": task, "input": str(image_input), "output": result})
        return result

    def handle_audio(self, path: str, task: str = "automatic-speech-recognition", **kwargs):
        """Processes audio file input for a given audio task."""
        if not self._is_file(path):
            raise FileNotFoundError(f"Audio file not found: {path}")
        logging.info(f"Handling audio for task: {task}")

        # Whisper models expect audio in a specific format (16kHz, mono, float32)
        try:
            audio = AudioSegment.from_file(path)
            audio = audio.set_channels(1).set_frame_rate(16000)  # Convert to mono, 16kHz

            buffer = io.BytesIO()
            audio.export(buffer, format="wav")  # Export to WAV in memory
            buffer.seek(0)  # Rewind buffer

            array, sampling_rate = sf.read(buffer)  # Read with soundfile
            if array.dtype != np.float32:
                array = array.astype(np.float32)  # Ensure float32

        except Exception as e:
            logging.error(f"Error preparing audio file for processing: {e}")
            raise ValueError(f"Could not prepare audio file: {e}. Ensure ffmpeg is installed system-wide.")

        pipeline_obj = self._get_task_pipeline(task)
        # ASR pipelines accept a dict of the form {"raw": np.ndarray, "sampling_rate": int};
        # passing sampling_rate as a call kwarg is not supported.
        result = pipeline_obj({"raw": array, "sampling_rate": sampling_rate}, **kwargs)
        self.memory.append({"task": task, "input": path, "output": result})
        return result

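    # Usage sketch (hypothetical file path; requires system ffmpeg for non-WAV input):
    #
    #     dispatcher.handle_audio("meeting.wav")
    #     # -> {'text': '...transcribed speech...'}
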
    def handle_video(self, path: str):
        """
        Processes video file input. This is a limited implementation:
        extracts the first few frames for image analysis and the audio track for ASR.
        Requires OpenCV (cv2) and system-wide ffmpeg.
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"Video file not found: {path}")
        logging.info(f"Handling video: {path}")

        try:
            import cv2
        except ImportError:
            raise ImportError("OpenCV (cv2) not installed. Install with `pip install opencv-python` for video processing.")

        frames = []
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {path}")

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))  # Convert BGR to RGB for PIL
            if len(frames) >= 5:
                break  # Process only first 5 frames for efficiency
        cap.release()

        # Extract audio from video. subprocess.run avoids the shell-quoting
        # pitfalls of os.system; Hugging Face Spaces typically has ffmpeg in PATH.
        audio_temp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
        try:
            subprocess.run(
                ["ffmpeg", "-y", "-i", path, "-q:a", "0", "-map", "a", audio_temp_path],
                check=True, capture_output=True)
            if not os.path.exists(audio_temp_path) or os.path.getsize(audio_temp_path) == 0:
                raise RuntimeError("FFmpeg failed to extract audio or extracted empty audio.")
        except Exception as e:
            logging.error(f"FFmpeg audio extraction failed: {e}")
            if os.path.exists(audio_temp_path):
                os.remove(audio_temp_path)  # Clean up the empty temp file
            audio_temp_path = None  # Indicate failure

        image_result = None
        audio_result = None

        if frames:
            try:
                # Process the first frame for image classification
                # (handle_image accepts PIL Images directly)
                image_result = self.handle_image(frames[0], task="image-classification")
            except Exception as e:
                logging.warning(f"Failed to process video frame for image classification: {e}")

        if audio_temp_path:
            try:
                # Process the extracted audio for ASR
                audio_result = self.handle_audio(audio_temp_path, task="automatic-speech-recognition")
            except Exception as e:
                logging.warning(f"Failed to process extracted audio from video: {e}")
            finally:
                if os.path.exists(audio_temp_path):
                    os.remove(audio_temp_path)  # Clean up temp audio file

        result = {"image_analysis": image_result, "audio_analysis": audio_result}
        self.memory.append({"task": "video_analysis", "input": path, "output": result})
        return result

    def handle_pdf(self, path: str):
        """Processes a PDF file for summarization using RAG."""
        if not self._is_file(path):
            raise FileNotFoundError(f"PDF file not found: {path}")
        logging.info(f"Handling PDF: {path}")

        # RAG components
        try:
            loader = PyPDFLoader(path)
            docs = loader.load()
            splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            split_docs = splitter.split_documents(docs)
            embeddings = HuggingFaceEmbeddings()
            vectorstore = FAISS.from_documents(split_docs, embeddings)
            # RetrievalQA expects a LangChain LLM, so wrap the raw transformers
            # pipeline with HuggingFacePipeline (using a small model as the RAG LLM).
            qa_llm = HuggingFacePipeline(
                pipeline=self._get_task_pipeline("text-generation", model_name="gpt2"))
            qa_chain = RetrievalQA.from_chain_type(llm=qa_llm, retriever=vectorstore.as_retriever())
            result = qa_chain.invoke({"query": "Summarize this document"})["result"]
            self.memory.append({"task": "pdf_summarization", "input": path, "output": result})
            return result
        except Exception as e:
            logging.error(f"Error processing PDF: {e}")
            raise ValueError(f"Could not process PDF: {e}. Ensure the PDF is valid and Langchain dependencies are met.")

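    # Sketch: the same retriever could serve ad-hoc questions rather than the
    # fixed summary prompt (hypothetical usage, reusing the chain built above):
    #
    #     qa_chain.invoke({"query": "What are the key findings?"})["result"]
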
    def handle_tts(self, text: str, lang: str = 'en'):
        """Converts text to speech and returns the path to the generated audio file."""
        if not isinstance(text, str):
            raise TypeError("Text input for TTS must be a string.")
        logging.info(f"Handling TTS for text: '{text[:50]}...'")
        tts = gTTS(text=text, lang=lang)
        temp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
        tts.save(temp_path)
        self.memory.append({"task": "tts", "input": text, "output": temp_path})
        return temp_path

    def process_dataset_from_hub(self, dataset_name: str, subset_name: str, split: str, column_to_process: str, task: str, num_samples: int = 5):
        """
        Loads a dataset from the Hugging Face Hub, processes a specified column
        for a given task, and returns results for a limited number of samples.
        """
        logging.info(f"Attempting to load dataset '{dataset_name}' (subset: {subset_name}, split: {split})...")

        try:
            # Load dataset. Using streaming=True for potentially very large datasets
            # and then taking a few examples. trust_remote_code is important for some datasets.
            if subset_name and subset_name.strip():
                dataset = load_dataset(dataset_name, subset_name, split=split, streaming=True, trust_remote_code=True)
            else:
                dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)

            logging.info(f"Dataset '{dataset_name}' loaded. Processing {num_samples} samples from column '{column_to_process}' for task '{task}'.")

            processed_results = []
            for i, example in enumerate(dataset):
                if i >= num_samples:
                    break  # Stop after processing desired number of samples

                if column_to_process not in example:
                    processed_results.append({
                        "sample_index": i,
                        "status": "skipped",
                        "reason": f"Column '{column_to_process}' not found in this sample."
                    })
                    continue

                input_data_for_processing = example[column_to_process]
                temp_file_to_clean = None  # To track temporary files for cleanup

                # Determine the actual data type and prepare it for self.process.
                # Hugging Face datasets often load audio/images as specific objects/dicts.
                if isinstance(input_data_for_processing, str):
                    # Already a string: assume raw text or a path
                    pass
                elif isinstance(input_data_for_processing, dict) and 'array' in input_data_for_processing and 'sampling_rate' in input_data_for_processing:
                    # An audio object from the datasets library:
                    # save to a temporary WAV file for self.handle_audio
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                        sf.write(tmp_audio.name, input_data_for_processing['array'], input_data_for_processing['sampling_rate'])
                        input_data_for_processing = tmp_audio.name
                        temp_file_to_clean = tmp_audio.name
                elif isinstance(input_data_for_processing, Image.Image):
                    # A PIL Image object from the datasets library:
                    # save to a temporary PNG file for self.handle_image
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_image:
                        input_data_for_processing.save(tmp_image.name)
                        input_data_for_processing = tmp_image.name
                        temp_file_to_clean = tmp_image.name
                else:
                    processed_results.append({
                        "sample_index": i,
                        "status": "error",
                        "reason": f"Unsupported data type in column '{column_to_process}': {type(input_data_for_processing)}"
                    })
                    continue  # Skip to next sample

                try:
                    # Call the general process method of the dispatcher
                    single_result = self.process(input_data_for_processing, task=task)
                    processed_results.append({
                        "sample_index": i,
                        "original_content": example[column_to_process] if isinstance(example[column_to_process], str) else f"<{type(example[column_to_process]).__name__} object>",
                        "processed_result": single_result
                    })
                except Exception as e:
                    logging.error(f"Error processing sample {i} from dataset: {e}")
                    processed_results.append({
                        "sample_index": i,
                        "original_content": example[column_to_process] if isinstance(example[column_to_process], str) else f"<{type(example[column_to_process]).__name__} object>",
                        "status": "error",
                        "reason": str(e)
                    })
                finally:
                    if temp_file_to_clean and os.path.exists(temp_file_to_clean):
                        os.remove(temp_file_to_clean)  # Clean up temporary file

            return processed_results

        except Exception as e:
            logging.error(f"Failed to load or iterate dataset: {e}")
            return [{"error": f"Failed to load or process dataset: {e}"}]

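    # Sketch: a non-text column is decoded to a temp file first, e.g. an audio
    # column transcribed with ASR (hypothetical dataset and column names):
    #
    #     dispatcher.process_dataset_from_hub(
    #         dataset_name="some_org/some_speech_dataset", subset_name="",
    #         split="test", column_to_process="audio",
    #         task="automatic-speech-recognition", num_samples=2)
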
    def process(self, input_data, task=None, **kwargs):
        """
        Main entry point for the AI tool. Determines the input type and
        dispatches to the appropriate processing function.

        Args:
            input_data: Raw text (str) or a file path (str) for image/audio/video/pdf.
            task (str, optional): The specific AI task to perform. Inferred from the
                file extension for file inputs; defaults to "sentiment-analysis"
                for raw text.
            **kwargs: Additional arguments passed to the specific handler or pipeline.

        Returns:
            The result from the AI model, or a file path for TTS.
        """
        if not isinstance(input_data, str):
            raise TypeError("Input data must be a string (raw text or file path).")

        if self._is_file(input_data):
            file_extension = os.path.splitext(input_data)[1].lstrip('.').lower()

            if file_extension in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff']:
                if not task: task = "image-classification"
                return self.handle_image(input_data, task=task, **kwargs)
            elif file_extension in ['mp3', 'wav', 'ogg', 'flac', 'm4a']:
                if not task: task = "automatic-speech-recognition"
                return self.handle_audio(input_data, task=task, **kwargs)
            elif file_extension in ['mp4', 'mov', 'avi', 'mkv']:
                # Video processing is a separate, more complex handler
                return self.handle_video(input_data)
            elif file_extension == 'pdf':
                return self.handle_pdf(input_data)
            else:
                raise ValueError(f"Unsupported file type: .{file_extension}")
        else:
            # Not a file path, so treat the string as raw text
            if task == "tts":
                return self.handle_tts(input_data, **kwargs)
            if not task: task = "sentiment-analysis"  # Default text task
            return self.handle_text(input_data, task=task, **kwargs)

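# Dispatch sketch (all file paths hypothetical):
#
#     dispatcher.process("Great product!")             # text -> sentiment analysis
#     dispatcher.process("photo.png")                  # image classification
#     dispatcher.process("clip.mp4")                   # frame + audio analysis
#     dispatcher.process("Hello there", task="tts")    # returns a path to an MP3
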
# --- Example Usage (for local testing only - will be skipped when imported by app.py) ---
if __name__ == "__main__":
    logging.info("Running local example usage of AllInOneDispatcher.")
    dispatcher = AllInOneDispatcher()

    # Text Examples
    print("\n--- Text Examples ---")
    text_input = "The new movie was absolutely fantastic! I highly recommend it."
    print(f"Input: '{text_input}'")
    print(f"Sentiment: {dispatcher.process(text_input, task='sentiment-analysis')}")

    text_to_summarize = ("Artificial intelligence (AI) is intelligence—perceiving, synthesizing, and inferring information—demonstrated by machines, as opposed to intelligence displayed by animals or humans. Example tasks in which AI is used include speech recognition, computer vision, translation, and others. AI applications include advanced web search engines, recommendation systems, understanding human speech, self-driving cars, and competing at the highest level in strategic game systems.")
    print(f"\nInput: '{text_to_summarize}'")
    print(f"Summary: {dispatcher.process(text_to_summarize, task='summarization', max_length=50, min_length=10)}")

    text_to_generate = "In a galaxy far, far away, a brave knight"
    print(f"\nInput: '{text_to_generate}'")
    generated_text = dispatcher.process(text_to_generate, task='text-generation', max_new_tokens=50, num_return_sequences=1)
    print(f"Generated Text: {generated_text[0]['generated_text']}")

    tts_text = "Hello, this is a test from the AI assistant."
    tts_path = dispatcher.process(tts_text, task="tts", lang="en")
    print(f"TTS audio saved to: {tts_path}")
    if os.path.exists(tts_path):
        os.remove(tts_path)  # Clean up generated audio

    # Image Examples (requires dummy image or real path)
    dummy_image_path = "dummy_image_for_test.png"
    if not os.path.exists(dummy_image_path):
        try:
            Image.new('RGB', (100, 100), color='blue').save(dummy_image_path)
            print(f"\nCreated dummy image: {dummy_image_path}")
        except Exception as e:
            print(f"\nCould not create dummy image: {e}. Skipping image example.")
            dummy_image_path = None

    if dummy_image_path and os.path.exists(dummy_image_path):
        print(f"\nImage Input: {dummy_image_path}")
        try:
            print(f"Image Classification: {dispatcher.process(dummy_image_path, task='image-classification')}")
        except Exception as e:
            print(f"Error during image classification: {e}")
        finally:
            os.remove(dummy_image_path)

    # Audio Examples (requires dummy audio or real path, and ffmpeg)
    dummy_audio_path = "dummy_audio_for_test.wav"
    if not os.path.exists(dummy_audio_path):
        try:
            from pydub.generators import Sine
            sine_wave = Sine(440).to_audio_segment(duration=1000)
            sine_wave.export(dummy_audio_path, format="wav")
            print(f"\nCreated dummy audio: {dummy_audio_path}")
        except ImportError:
            print("\npydub not installed. Skipping dummy audio creation.")
            dummy_audio_path = None
        except Exception as e:
            print(f"\nCould not create dummy audio: {e}. Skipping audio example.")
            dummy_audio_path = None

    if dummy_audio_path and os.path.exists(dummy_audio_path):
        print(f"\nAudio Input: {dummy_audio_path}")
        try:
            transcription = dispatcher.process(dummy_audio_path, task='automatic-speech-recognition')
            print(f"Audio Transcription: {transcription['text']}")
        except Exception as e:
            print(f"Error during audio transcription: {e}")
        finally:
            os.remove(dummy_audio_path)

    # PDF Example (requires a dummy PDF or real path)
    # Note: Creating a dummy PDF programmatically is complex.
    # For testing, place a small PDF file in the same directory.
    # dummy_pdf_path = "dummy.pdf"
    # if os.path.exists(dummy_pdf_path):
    #     print(f"\nPDF Input: {dummy_pdf_path}")
    #     try:
    #         print(f"PDF Summary: {dispatcher.process(dummy_pdf_path)}")
    #     except Exception as e:
    #         print(f"Error during PDF processing: {e}")
    # else:
    #     print(f"\nSkipping PDF example: '{dummy_pdf_path}' not found. Please create one for testing.")

    # Dataset Example (requires internet access)
    print("\n--- Dataset Example (will process a few samples) ---")
    try:
        dataset_results = dispatcher.process_dataset_from_hub(
            dataset_name="glue",
            subset_name="sst2",
            split="train",
            column_to_process="sentence",
            task="sentiment-analysis",
            num_samples=2
        )
        print(f"Dataset Processing Results: {json.dumps(dataset_results, indent=2)}")
    except Exception as e:
        print(f"Error during dataset processing example: {e}")

    logging.info("Local example usage complete.")