# NOTE: Hugging Face web-listing residue ("history blame / raw / file size"
# metadata) removed from the top of this file so it parses as Python.
# summarizer_tool.py
# --- Core Imports (standard library) ---
import io
import json
import logging
import os
import re
import subprocess
import tempfile

# --- Third-party core ---
import numpy as np
import soundfile as sf
import torch
from PIL import Image
from pydub import AudioSegment

# --- Hugging Face ---
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, \
    AutoModelForSpeechSeq2Seq, AutoModelForImageClassification, AutoModelForObjectDetection
from datasets import load_dataset, Audio

# --- Langchain Imports ---
# Ensure these are correct based on Langchain's modularization
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter  # still in core langchain
from langchain.chains import RetrievalQA

# --- Other Imports ---
from gtts import gTTS
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Global Cache for Pipelines ---
# This prevents reloading the same model multiple times
_pipeline_cache = {}
def get_pipeline(task_name, model_name=None, **kwargs):
"""
Retrieves a Hugging Face pipeline, caching it for efficiency.
"""
# Create a unique key for the cache based on task, model, and kwargs
cache_key = f"{task_name}-{model_name}-{hash(frozenset(kwargs.items()))}"
if cache_key not in _pipeline_cache:
logging.info(f"Loading pipeline for task '{task_name}' with model '{model_name}'...")
if model_name:
_pipeline_cache[cache_key] = pipeline(task_name, model=model_name, **kwargs)
else:
_pipeline_cache[cache_key] = pipeline(task_name, **kwargs) # Uses default model for task
logging.info(f"Pipeline '{task_name}' loaded.")
return _pipeline_cache[cache_key]
# --- Main Dispatcher Class ---
class AllInOneDispatcher:
    """Dispatches text, image, audio, video, PDF, and HF-dataset inputs to
    the appropriate Hugging Face pipeline (cached via get_pipeline()).

    Every handled request is appended to `self.memory` as a
    {"task", "input", "output"} dict.
    """

    def __init__(self):
        logging.info("Initializing AllInOneDispatcher...")
        self.memory = []  # Interaction history (optional).
        # Default model per task; each is loaded lazily on first use.
        self.default_models = {
            "sentiment-analysis": "distilbert-base-uncased-finetuned-sst-2-english",
            "summarization": "sshleifer/distilbart-cnn-12-6",
            "text-generation": "gpt2",
            "translation_en_to_fr": "Helsinki-NLP/opus-mt-en-fr",
            "image-classification": "google/vit-base-patch16-224",
            "object-detection": "facebook/detr-resnet-50",
            "automatic-speech-recognition": "openai/whisper-tiny.en",  # English-only ASR
            # Add other models/tasks as needed
        }
        logging.info("AllInOneDispatcher initialized.")

    def _get_task_pipeline(self, task: str, model_name: str = None):
        """Resolve the model for `task` (explicit arg wins) and return a cached pipeline."""
        final_model_name = model_name if model_name else self.default_models.get(task)
        if not final_model_name:
            raise ValueError(f"No default model specified for task '{task}'. Please provide `model_name` or add to default_models.")
        return get_pipeline(task, model_name=final_model_name)

    def _is_file(self, path):
        """Return True when `path` is a path-like value naming an existing regular file.

        The isinstance guard keeps non-path inputs (None, PIL images, ...)
        from crashing os.path.exists with a TypeError.
        """
        if not isinstance(path, (str, os.PathLike)):
            return False
        return os.path.exists(path) and os.path.isfile(path)

    def handle_text(self, text: str, task: str = "sentiment-analysis", **kwargs):
        """Run an NLP pipeline over raw text and record the result in memory."""
        if not isinstance(text, str):
            raise TypeError("Text input must be a string.")
        logging.info(f"Handling text for task: {task}")
        pipeline_obj = self._get_task_pipeline(task)
        result = pipeline_obj(text, **kwargs)
        self.memory.append({"task": task, "input": text, "output": result})
        return result

    def handle_image(self, path, task: str = "image-classification", **kwargs):
        """Run a computer-vision pipeline over an image.

        Args:
            path: Filesystem path to an image, or an in-memory PIL.Image.Image
                (e.g. a frame extracted from a video by handle_video).
            task: Vision task name.
            **kwargs: Forwarded to the pipeline call.
        """
        if isinstance(path, Image.Image):
            # Accept in-memory frames directly; previously handle_video passed
            # a PIL image where a file path was required, so it always failed.
            image = path
            memory_input = "<in-memory PIL image>"
        else:
            if not self._is_file(path):
                raise FileNotFoundError(f"Image file not found: {path}")
            try:
                image = Image.open(path)
            except Exception as e:
                raise ValueError(f"Could not open image file: {e}")
            memory_input = path
        logging.info(f"Handling image for task: {task}")
        pipeline_obj = self._get_task_pipeline(task)
        result = pipeline_obj(image, **kwargs)
        self.memory.append({"task": task, "input": memory_input, "output": result})
        return result

    def handle_audio(self, path: str, task: str = "automatic-speech-recognition", **kwargs):
        """Run an audio pipeline (default: Whisper ASR) over an audio file.

        The file is converted to mono 16 kHz float32 PCM — the format Whisper
        models expect. Requires system-wide ffmpeg for non-WAV inputs.
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"Audio file not found: {path}")
        logging.info(f"Handling audio for task: {task}")
        try:
            audio = AudioSegment.from_file(path)
            audio = audio.set_channels(1).set_frame_rate(16000)  # mono, 16 kHz
            buffer = io.BytesIO()
            audio.export(buffer, format="wav")  # WAV in memory
            buffer.seek(0)
            array, sampling_rate = sf.read(buffer)
            if array.dtype != np.float32:
                array = array.astype(np.float32)
        except Exception as e:
            logging.error(f"Error preparing audio file for processing: {e}")
            raise ValueError(f"Could not prepare audio file: {e}. Ensure ffmpeg is installed system-wide.")
        pipeline_obj = self._get_task_pipeline(task)
        # HF audio pipelines take a numpy array or a {"raw", "sampling_rate"}
        # dict; the original passed a Python list plus an unsupported
        # `sampling_rate` keyword argument.
        result = pipeline_obj({"raw": array, "sampling_rate": sampling_rate}, **kwargs)
        self.memory.append({"task": task, "input": path, "output": result})
        return result

    def handle_video(self, path: str):
        """Limited video analysis: classify the first frame and transcribe the audio track.

        Requires OpenCV (cv2) and system-wide ffmpeg.
        """
        if not self._is_file(path):
            raise FileNotFoundError(f"Video file not found: {path}")
        logging.info(f"Handling video: {path}")
        try:
            import cv2
        except ImportError:
            raise ImportError("OpenCV (cv2) not installed. Install with `pip install opencv-python` for video processing.")
        frames = []
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {path}")
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV yields BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                if len(frames) >= 5:
                    break  # only the first 5 frames, for efficiency
        finally:
            cap.release()
        # Extract the audio track into a temp WAV file.
        audio_temp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
        try:
            # subprocess.run with an argument list avoids the shell-injection /
            # quoting problems of the previous os.system(f"ffmpeg ...") call.
            subprocess.run(
                ["ffmpeg", "-i", path, "-q:a", "0", "-map", "a", audio_temp_path, "-y"],
                check=True, capture_output=True,
            )
            if not os.path.exists(audio_temp_path) or os.path.getsize(audio_temp_path) == 0:
                raise RuntimeError("FFmpeg failed to extract audio or extracted empty audio.")
        except Exception as e:
            logging.error(f"FFmpeg audio extraction failed: {e}")
            if os.path.exists(audio_temp_path):
                os.remove(audio_temp_path)  # don't leak the (empty) temp file
            audio_temp_path = None  # signal extraction failure
        image_result = None
        audio_result = None
        if frames:
            try:
                # handle_image accepts in-memory PIL images.
                image_result = self.handle_image(frames[0], task="image-classification")
            except Exception as e:
                logging.warning(f"Failed to process video frame for image classification: {e}")
        if audio_temp_path:
            try:
                audio_result = self.handle_audio(audio_temp_path, task="automatic-speech-recognition")
            except Exception as e:
                logging.warning(f"Failed to process extracted audio from video: {e}")
            finally:
                if os.path.exists(audio_temp_path):
                    os.remove(audio_temp_path)  # clean up temp audio file
        result = {"image_analysis": image_result, "audio_analysis": audio_result}
        self.memory.append({"task": "video_analysis", "input": path, "output": result})
        return result

    def handle_pdf(self, path: str):
        """Summarize a PDF via a small RAG chain (FAISS retrieval + gpt2)."""
        if not self._is_file(path):
            raise FileNotFoundError(f"PDF file not found: {path}")
        logging.info(f"Handling PDF: {path}")
        try:
            loader = PyPDFLoader(path)
            docs = loader.load()
            splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            split_docs = splitter.split_documents(docs)
            embeddings = HuggingFaceEmbeddings()
            vectorstore = FAISS.from_documents(split_docs, embeddings)
            # RetrievalQA needs a LangChain LLM object; the original passed
            # the raw transformers pipeline directly, which RetrievalQA
            # cannot invoke. Wrap it in HuggingFacePipeline.
            hf_pipe = self._get_task_pipeline("text-generation", model_name="gpt2")  # small model for RAG LLM
            qa_llm = HuggingFacePipeline(pipeline=hf_pipe)
            qa_chain = RetrievalQA.from_chain_type(llm=qa_llm, retriever=vectorstore.as_retriever())
            result = qa_chain.run("Summarize this document")
            self.memory.append({"task": "pdf_summarization", "input": path, "output": result})
            return result
        except Exception as e:
            logging.error(f"Error processing PDF: {e}")
            raise ValueError(f"Could not process PDF: {e}. Ensure PDF is valid and Langchain dependencies are met.")

    def handle_tts(self, text: str, lang: str = 'en'):
        """Convert text to speech with gTTS; return the path of the saved MP3.

        The caller is responsible for deleting the returned temp file.
        """
        if not isinstance(text, str):
            raise TypeError("Text input for TTS must be a string.")
        logging.info(f"Handling TTS for text: '{text[:50]}...'")
        tts = gTTS(text=text, lang=lang)
        temp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
        tts.save(temp_path)
        self.memory.append({"task": "tts", "input": text, "output": temp_path})
        return temp_path

    def process_dataset_from_hub(self, dataset_name: str, subset_name: str, split: str, column_to_process: str, task: str, num_samples: int = 5):
        """Stream a Hub dataset and run `task` over `column_to_process`.

        Audio columns (decoded by `datasets` into {"array", "sampling_rate"}
        dicts) and image columns (PIL images) are spilled to temp files so the
        path-based handlers can consume them.

        Returns:
            A list of per-sample result dicts, or a one-element error list
            when the dataset itself cannot be loaded.
        """
        logging.info(f"Attempting to load dataset '{dataset_name}' (subset: {subset_name}, split: {split})...")
        try:
            # streaming=True avoids downloading entire (possibly huge)
            # datasets; trust_remote_code is required by some loaders.
            # Guard against subset_name=None (the original called .strip()
            # unconditionally and crashed).
            if subset_name and subset_name.strip():
                dataset = load_dataset(dataset_name, subset_name, split=split, streaming=True, trust_remote_code=True)
            else:
                dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
            logging.info(f"Dataset '{dataset_name}' loaded. Processing {num_samples} samples from column '{column_to_process}' for task '{task}'.")
            processed_results = []
            for i, example in enumerate(dataset):
                if i >= num_samples:
                    break  # processed the desired number of samples
                if column_to_process not in example:
                    processed_results.append({
                        "sample_index": i,
                        "status": "skipped",
                        "reason": f"Column '{column_to_process}' not found in this sample."
                    })
                    continue
                input_data_for_processing = example[column_to_process]
                temp_file_to_clean = None  # temp file to delete after this sample
                if isinstance(input_data_for_processing, str):
                    pass  # raw text or an existing path — usable as-is
                elif isinstance(input_data_for_processing, dict) and 'array' in input_data_for_processing and 'sampling_rate' in input_data_for_processing:
                    # Decoded audio column -> temp WAV for handle_audio.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
                        sf.write(tmp_audio.name, input_data_for_processing['array'], input_data_for_processing['sampling_rate'])
                        input_data_for_processing = tmp_audio.name
                        temp_file_to_clean = tmp_audio.name
                elif isinstance(input_data_for_processing, Image.Image):
                    # Decoded image column -> temp PNG for handle_image.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_image:
                        input_data_for_processing.save(tmp_image.name)
                        input_data_for_processing = tmp_image.name
                        temp_file_to_clean = tmp_image.name
                else:
                    processed_results.append({
                        "sample_index": i,
                        "status": "error",
                        "reason": f"Unsupported data type in column '{column_to_process}': {type(input_data_for_processing)}"
                    })
                    continue  # skip to next sample
                try:
                    single_result = self.process(input_data_for_processing, task=task)
                    processed_results.append({
                        "sample_index": i,
                        "original_content": example[column_to_process] if isinstance(example[column_to_process], str) else f"<{type(example[column_to_process]).__name__} object>",
                        "processed_result": single_result
                    })
                except Exception as e:
                    logging.error(f"Error processing sample {i} from dataset: {e}")
                    processed_results.append({
                        "sample_index": i,
                        "original_content": example[column_to_process] if isinstance(example[column_to_process], str) else f"<{type(example[column_to_process]).__name__} object>",
                        "status": "error",
                        "reason": str(e)
                    })
                finally:
                    if temp_file_to_clean and os.path.exists(temp_file_to_clean):
                        os.remove(temp_file_to_clean)  # clean up temporary file
            return processed_results
        except Exception as e:
            logging.error(f"Failed to load or iterate dataset: {e}")
            return [{"error": f"Failed to load or process dataset: {e}"}]

    def process(self, input_data, task=None, **kwargs):
        """Main entry point: dispatch `input_data` to the right handler.

        Args:
            input_data: Raw text (str) or a file path (str) to an
                image/audio/video/PDF file.
            task: Optional explicit task name; defaults by input type
                (text -> "sentiment-analysis", image ->
                "image-classification", ...). Use task="tts" with text input
                to synthesize speech.
            **kwargs: Forwarded to the specific handler or pipeline.

        Returns:
            The model result, or a file path for TTS.

        Raises:
            TypeError: if input_data is not a string.
            ValueError: for unsupported file extensions.
        """
        if not isinstance(input_data, str):
            raise TypeError("Input data must be a string (raw text or file path).")
        if self._is_file(input_data):
            # os.path.splitext is robust to dots elsewhere in the path;
            # the original split('.')[-1] broke on paths like "a.b/file".
            file_extension = os.path.splitext(input_data)[1].lstrip('.').lower()
            if file_extension in ('jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff'):
                return self.handle_image(input_data, task=task or "image-classification", **kwargs)
            elif file_extension in ('mp3', 'wav', 'ogg', 'flac', 'm4a'):
                return self.handle_audio(input_data, task=task or "automatic-speech-recognition", **kwargs)
            elif file_extension in ('mp4', 'mov', 'avi', 'mkv'):
                # Video has its own combined frame + audio handler.
                return self.handle_video(input_data)
            elif file_extension == 'pdf':
                return self.handle_pdf(input_data)
            else:
                raise ValueError(f"Unsupported file type: .{file_extension}. Or specify task for this file.")
        else:
            # Not a file path — treat as raw text.
            if task == "tts":
                return self.handle_tts(input_data, **kwargs)
            return self.handle_text(input_data, task=task or "sentiment-analysis", **kwargs)
# --- Example Usage (for local testing only - will be skipped when imported by app.py) ---
if __name__ == "__main__":
    logging.info("Running local example usage of AllInOneDispatcher.")
    dispatcher = AllInOneDispatcher()

    # --- Text Examples: sentiment, summarization, generation, TTS ---
    # NOTE(review): these calls download models on first run — they require
    # network access and may take a while.
    print("\n--- Text Examples ---")
    text_input = "The new movie was absolutely fantastic! I highly recommend it."
    print(f"Input: '{text_input}'")
    print(f"Sentiment: {dispatcher.process(text_input, task='sentiment-analysis')}")
    text_to_summarize = ("Artificial intelligence (AI) is intelligence—perceiving, synthesizing, and inferring information—demonstrated by machines, as opposed to intelligence displayed by animals or humans. Example tasks in which AI is used include speech recognition, computer vision, translation, and others. AI applications include advanced web search engines, recommendation systems, understanding human speech, self-driving cars, and competing at the highest level in strategic game systems.")
    print(f"\nInput: '{text_to_summarize}'")
    print(f"Summary: {dispatcher.process(text_to_summarize, task='summarization', max_length=50, min_length=10)}")
    text_to_generate = "In a galaxy far, far away, a brave knight"
    print(f"\nInput: '{text_to_generate}'")
    generated_text = dispatcher.process(text_to_generate, task='text-generation', max_new_tokens=50, num_return_sequences=1)
    print(f"Generated Text: {generated_text[0]['generated_text']}")
    tts_text = "Hello, this is a test from the AI assistant."
    tts_path = dispatcher.process(tts_text, task="tts", lang="en")
    print(f"TTS audio saved to: {tts_path}")
    if os.path.exists(tts_path):
        os.remove(tts_path)  # Clean up generated audio

    # --- Image Example: create a solid-blue dummy image, classify, clean up ---
    dummy_image_path = "dummy_image_for_test.png"
    if not os.path.exists(dummy_image_path):
        try:
            Image.new('RGB', (100, 100), color='blue').save(dummy_image_path)
            print(f"\nCreated dummy image: {dummy_image_path}")
        except Exception as e:
            print(f"\nCould not create dummy image: {e}. Skipping image example.")
            dummy_image_path = None  # disables the example below
    if dummy_image_path and os.path.exists(dummy_image_path):
        print(f"\nImage Input: {dummy_image_path}")
        try:
            print(f"Image Classification: {dispatcher.process(dummy_image_path, task='image-classification')}")
        except Exception as e:
            print(f"Error during image classification: {e}")
        finally:
            os.remove(dummy_image_path)  # always delete the dummy image

    # --- Audio Example: synthesize a 1 s 440 Hz sine tone, transcribe, clean up ---
    # (requires pydub and system-wide ffmpeg)
    dummy_audio_path = "dummy_audio_for_test.wav"
    if not os.path.exists(dummy_audio_path):
        try:
            from pydub.generators import Sine
            sine_wave = Sine(440).to_audio_segment(duration=1000)
            sine_wave.export(dummy_audio_path, format="wav")
            print(f"\nCreated dummy audio: {dummy_audio_path}")
        except ImportError:
            print("\npydub not installed. Skipping dummy audio creation.")
            dummy_audio_path = None
        except Exception as e:
            print(f"\nCould not create dummy audio: {e}. Skipping audio example.")
            dummy_audio_path = None
    if dummy_audio_path and os.path.exists(dummy_audio_path):
        print(f"\nAudio Input: {dummy_audio_path}")
        try:
            transcription = dispatcher.process(dummy_audio_path, task='automatic-speech-recognition')
            print(f"Audio Transcription: {transcription['text']}")
        except Exception as e:
            print(f"Error during audio transcription: {e}")
        finally:
            os.remove(dummy_audio_path)  # always delete the dummy audio

    # --- PDF Example (disabled): needs a real PDF placed alongside this file ---
    # Note: Creating a dummy PDF programmatically is complex.
    # dummy_pdf_path = "dummy.pdf"
    # if os.path.exists(dummy_pdf_path):
    #     print(f"\nPDF Input: {dummy_pdf_path}")
    #     try:
    #         print(f"PDF Summary: {dispatcher.process(dummy_pdf_path, task='pdf')}")
    #     except Exception as e:
    #         print(f"Error during PDF processing: {e}")
    # else:
    #     print(f"\nSkipping PDF example: '{dummy_pdf_path}' not found. Please create one for testing.")

    # --- Dataset Example: stream 2 GLUE/SST-2 samples through sentiment analysis ---
    # (requires internet access)
    print("\n--- Dataset Example (will process a few samples) ---")
    try:
        dataset_results = dispatcher.process_dataset_from_hub(
            dataset_name="glue",
            subset_name="sst2",
            split="train",
            column_to_process="sentence",
            task="sentiment-analysis",
            num_samples=2
        )
        print(f"Dataset Processing Results: {json.dumps(dataset_results, indent=2)}")
    except Exception as e:
        print(f"Error during dataset processing example: {e}")
    logging.info("Local example usage complete.")