Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import spacy | |
| import tempfile | |
| import glob | |
| import yt_dlp | |
| import shutil | |
| import cv2 | |
| import librosa | |
| import wikipedia | |
| from typing import TypedDict, List, Optional, Dict, Any | |
| from langchain.docstore.document import Document | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langgraph.graph import START, END, StateGraph | |
| from langchain_core.messages import AnyMessage, HumanMessage, AIMessage # If you are using it | |
| from langchain_community.retrievers import BM25Retriever # If you are using it | |
| from langgraph.prebuilt import ToolNode, tools_condition # If you are using it | |
| from langchain.vectorstores import FAISS | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.schema import Document | |
| from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline | |
| from io import BytesIO | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration | |
| import os | |
| import re | |
| from PIL import Image # This is correctly imported, but was being used incorrectly | |
| import numpy as np | |
| from collections import Counter | |
| import torch | |
| from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline | |
| from typing import TypedDict, List, Optional, Dict, Any, Literal, Tuple | |
| from langgraph.graph import StateGraph, START, END | |
| from langchain.docstore.document import Document | |
# spaCy pipeline used by extract_entities(). When the small English model is
# not installed, spacy.load raises OSError; fall back to None so that
# extract_entities() uses its regex keyword fallback instead of the whole
# module failing to import.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("WARNING: spaCy model 'en_core_web_sm' not available; using regex keyword fallback.")
    nlp = None
# Extension lookup tables used by get_file_type(). Entries are lowercase and
# carry the leading dot, matching how get_file_type() builds its lookup key.
PICTURE_EXTENSIONS = {
    '.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp',
}

AUDIO_EXTENSIONS = {
    '.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma',
}

CODE_EXTENSIONS = {
    '.py', '.js', '.java', '.cpp', '.c', '.cs',
    '.rb', '.go', '.php', '.html', '.css', '.ts',
}

# NOTE: this table also lists generic text formats such as '.txt' and '.xml',
# so get_file_type() reports those as 'spreadsheet'.
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb',
    '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.fods',
    '.csv', '.tsv', '.txt', '.xml',
    '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate',
    '.123', '.wk1', '.wk2', '.wks', '.wku', '.wr1',
    '.gnumeric', '.gnm',
    '.pmvx', '.pmdx', '.pmv', '.uos',
}
def get_file_type(filename: str) -> str:
    """Classify *filename* by its (case-insensitive) extension.

    Only the final path component is inspected, so a dot in a directory name
    (e.g. ``/tmp/run.v2/README``) is not mistaken for an extension.

    Args:
        filename: A file name or path.

    Returns:
        ``'picture'``, ``'audio'``, ``'code'`` or ``'spreadsheet'`` when the
        extension is in the matching ``*_EXTENSIONS`` table, ``'unknown'`` for
        any other extension, and ``''`` when *filename* is empty/None or its
        base name contains no dot.
    """
    if not filename:
        return ''
    base_name = os.path.basename(filename)
    if '.' not in base_name:
        return ''
    dot_ext = '.' + base_name.lower().rsplit('.', 1)[-1]
    if dot_ext in PICTURE_EXTENSIONS:
        return 'picture'
    if dot_ext in AUDIO_EXTENSIONS:
        return 'audio'
    if dot_ext in CODE_EXTENSIONS:
        return 'code'
    if dot_ext in SPREADSHEET_EXTENSIONS:
        return 'spreadsheet'
    return 'unknown'
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """Write *file_bytes* to ``/tmp/<file_name>`` and return the full path.

    ``/tmp`` is used directly because it is always writable in Hugging Face
    Spaces. Only the base name of *file_name* is used, so names such as
    ``"../../etc/x"`` or ``"/abs/path.bin"`` cannot write outside ``/tmp``.
    The file persists until manually deleted or the OS cleans ``/tmp``.

    Args:
        file_bytes: Raw content to write.
        file_name: Desired file name; any directory part is discarded.

    Returns:
        Absolute path of the written file.

    Raises:
        ValueError: If *file_name* has no usable base name.
    """
    temp_dir = "/tmp"
    os.makedirs(temp_dir, exist_ok=True)
    safe_name = os.path.basename(file_name or "")
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid file name: {file_name!r}")
    file_path = os.path.join(temp_dir, safe_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
# 1. Define the State type
class State(TypedDict, total=False):
    """Shared LangGraph state passed between the agent's nodes.

    ``total=False`` makes every key optional, which is why nodes read many
    keys with ``state.get(...)``.
    """

    # --- request data (not set by any node in the visible code) ---
    task_id: str
    question: str
    input_file: Optional[bytes]
    file_type: Optional[str]  # category string such as 'picture' / 'audio'
    file_path: Optional[str]  # path of the uploaded file read by image/audio nodes

    # --- routing ---
    next: Optional[str]  # node name chosen by node_decide via router()
    youtube_url: Optional[str]  # set by router() when the question links a video

    # --- working data and output ---
    context: List[Document]  # LangChain Documents collected by the nodes
    frame_answers: Optional[list]  # per-frame answers written by node_video
    answer: Optional[str]  # final "FINAL ANSWER: ..." string
# --- LLM pipeline for general questions ---
# return_full_text=False makes generated_text contain only the model's
# continuation. By default the text-generation pipeline prepends the prompt,
# and every prompt built in this file itself contains "FINAL ANSWER:", which
# ensure_final_answer_format() would then mistake for the model's answer.
llm_pipe = pipeline(
    "text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    device_map="auto",
    torch_dtype="auto",
    max_new_tokens=256,
    return_full_text=False,
    trust_remote_code=True,
)
# Initialize RAG components: a tokenizer, a wiki_dpr retriever and the RAG
# model whose question encoder embeds queries in retrieve().
_RAG_CHECKPOINT = "facebook/rag-token-base"

tokenizer = RagTokenizer.from_pretrained(_RAG_CHECKPOINT, trust_remote_code=True)

# NOTE(review): use_dummy_dataset=False fetches the full wiki_dpr passages and
# "exact" index at startup -- likely a very large download; confirm intended.
_retriever_options = dict(
    dataset="wiki_dpr",         # explicit dataset name
    dataset_revision="main",    # fixed dataset revision
    index_name="exact",         # or "legacy" for the legacy FAISS index
    use_dummy_dataset=False,
    trust_remote_code=True,     # allow remote code for dataset loading
)
retriever = RagRetriever.from_pretrained(_RAG_CHECKPOINT, **_retriever_options)

# NOTE(review): a RAG-*token* checkpoint is loaded into the RAG-*sequence*
# class; in the visible code only rag_model.question_encoder is used.
rag_model = RagSequenceForGeneration.from_pretrained(
    _RAG_CHECKPOINT,
    retriever=retriever,
    trust_remote_code=True,
)
# Speech-to-text pipeline.
# The pipeline's `device` argument takes a device index, a device string such
# as "cpu"/"cuda:0", or a torch.device; "auto" is only meaningful for
# `device_map` and fails when converted with torch.device("auto").
# Use the first GPU when one is available, otherwise the CPU (-1).
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1,
)
# --- BLIP VQA setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
vqa_model_name = "Salesforce/blip-vqa-base"
processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)


def _load_vqa_model(target_device):
    """Load a fresh BLIP VQA model and move it to *target_device*."""
    return BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(target_device)


# Prefer the GPU; on CUDA OOM load a fresh copy onto the CPU instead and make
# `device` reflect that, since the helpers send their inputs to `device`.
try:
    model_vqa = _load_vqa_model(device)
except torch.cuda.OutOfMemoryError:
    print("WARNING: Loading model to CPU due to insufficient GPU memory.")
    device = "cpu"
    model_vqa = _load_vqa_model(device)
# --- Helper functions ---
def ensure_final_answer_format(answer_text: str) -> str:
    """Normalise *answer_text* to a single ``"FINAL ANSWER: <answer>"`` string.

    When the text contains ``FINAL ANSWER:`` markers, everything after the
    LAST marker is kept. Generated text may echo a prompt that itself contains
    the marker (e.g. ``FINAL ANSWER: [your answer]``), and the model's own
    answer always follows it. Without a marker the whole text is wrapped.
    The function is idempotent.

    Args:
        answer_text: Raw model output or message; ``None`` is treated as "".

    Returns:
        The stripped answer prefixed with ``"FINAL ANSWER: "``.
    """
    text = answer_text or ""
    # rpartition yields ('', '', text) when the marker is absent, so `tail`
    # is the full text in that case and the text after the last marker otherwise.
    _, _, tail = text.rpartition("FINAL ANSWER:")
    return f"FINAL ANSWER: {tail.strip()}"
def extract_entities(text: str) -> List[str]:
    """Extract key entities/keywords from *text*.

    With the module-level spaCy pipeline ``nlp`` available (truthy), named
    entities are returned; if there are none, proper nouns and nouns are
    returned instead. Otherwise a regex fallback returns lowercase word
    tokens (punctuation removed) longer than two characters that are not in
    a small stopword set.
    """
    if nlp:
        doc = nlp(text)
        entities = [ent.text for ent in doc.ents]
        if entities:
            return entities
        return [token.text for token in doc if token.pos_ in ("PROPN", "NOUN")]

    stopwords = {"what", "who", "when", "where", "why", "how", "is", "are",
                 "the", "a", "an", "of", "in", "on", "at"}
    # \w+ tokens (keeping inner apostrophes) so trailing punctuation such as
    # "paris?" does not survive into the keyword.
    words = re.findall(r"\w+(?:'\w+)*", text.lower())
    return [word for word in words if word not in stopwords and len(word) > 2]
def answer_question_on_frame(image_path, question):
    """Answer *question* about one video-frame image using BLIP VQA.

    Returns:
        The decoded BLIP answer, or the sentinel string
        ``"Error processing this frame"`` if anything fails (the error is
        printed).
    """
    try:
        # The context manager closes the file handle; convert() returns an
        # independent image, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
        out = model_vqa.generate(**inputs)
        return processor_vqa.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error processing frame {image_path}: {str(e)}")
        return "Error processing this frame"
def answer_video_question(frames_dir, question):
    """Answer *question* about a video by majority vote over its frames.

    Every ``.jpg``/``.jpeg``/``.png`` file in *frames_dir* is passed, in
    ascending frame-number order, to ``answer_question_on_frame``. Frames that
    fail are skipped, whether they raise or return that helper's
    ``"Error processing this frame"`` sentinel, so failures cannot out-vote
    real answers.

    Returns:
        dict with ``most_common_answer`` (str, or an explanatory message),
        ``all_answers`` (successful per-frame answers in frame order) and
        ``answer_counts`` (``Counter`` of those answers).
    """
    valid_exts = ('.jpg', '.jpeg', '.png')
    frame_error_sentinel = "Error processing this frame"

    def empty_result(message):
        return {"most_common_answer": message, "all_answers": [], "answer_counts": Counter()}

    # isdir (not exists): a regular file here would make os.listdir raise.
    if not os.path.isdir(frames_dir):
        return empty_result("No frames found to analyze.")

    def get_frame_number(path):
        # First digit run in the file name, e.g. frame_000120.jpg -> 120.
        match = re.search(r'(\d+)', os.path.basename(path))
        return int(match.group(1)) if match else 0

    frame_files = sorted(
        (os.path.join(frames_dir, name) for name in os.listdir(frames_dir)
         if name.lower().endswith(valid_exts)),
        key=get_frame_number,
    )
    if not frame_files:
        return empty_result("No valid image frames found.")

    answers = []
    for frame_path in frame_files:
        try:
            ans = answer_question_on_frame(frame_path, question)
        except Exception as e:
            print(f"Error processing frame {frame_path}: {str(e)}")
            continue
        print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
        if ans != frame_error_sentinel:
            answers.append(ans)

    if not answers:
        return empty_result("Could not analyze any frames successfully.")

    counted = Counter(answers)
    return {
        "most_common_answer": counted.most_common(1)[0][0],
        "all_answers": answers,
        "answer_counts": counted,
    }
def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'):
    """Download the video at *url* with yt-dlp to ``output_dir/output_filename``.

    *output_dir* is created if needed, and every entry matched by ``*`` in it
    is deleted first; deletion failures are printed, not raised. yt-dlp is
    asked for MP4 video plus M4A audio merged to MP4. It falls back to the
    best single MP4, then the best format overall, and an FFmpeg
    video-conversion post-processor targets MP4.

    Returns:
        The output path, or ``None`` if yt-dlp raised (the error is printed).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Remove leftovers from earlier downloads.
    for stale in glob.glob(os.path.join(output_dir, '*')):
        try:
            os.remove(stale)
        except Exception as e:
            print(f"Error deleting {stale}: {str(e)}")

    target = os.path.join(output_dir, output_filename)
    options = {
        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
        'outtmpl': target,
        'quiet': True,
        'merge_output_format': 'mp4',
        # 'preferedformat' (sic) is the spelling yt-dlp expects for this key.
        'postprocessors': [{'key': 'FFmpegVideoConvertor', 'preferedformat': 'mp4'}],
    }
    try:
        with yt_dlp.YoutubeDL(options) as ydl:
            ydl.download([url])
    except Exception as e:
        print(f"Error downloading YouTube video: {str(e)}")
        return None
    return target
def _reset_output_dir(output_dir):
    """Empty *output_dir* (files, links, sub-directories) or create it if missing."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        return
    for entry in os.listdir(output_dir):
        entry_path = os.path.join(output_dir, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as e:
            print(f'Failed to delete {entry_path}. Reason: {e}')


def extract_frames(video_path, output_dir, frame_interval_seconds=10):
    """Save one JPEG roughly every *frame_interval_seconds* of *video_path*.

    *output_dir* is emptied first (or created). Frames are written as
    ``frame_<index:06d>.jpg``, where index is the 0-based source frame number.
    If the container reports no usable FPS, 30 FPS is assumed so that the
    sampling interval is never zero.

    Returns:
        True if at least one frame was written; False if the video could not
        be opened, nothing was saved, or an exception occurred (printed).
    """
    _reset_output_dir(output_dir)
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return False
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Catches 0, negative and NaN FPS, which previously produced a zero
        # interval and a ZeroDivisionError in the modulo below.
        if not (fps > 0):
            print(f"WARNING: invalid FPS {fps!r}; assuming 30.")
            fps = 30.0
        frame_interval = max(1, int(fps * frame_interval_seconds))

        count = 0
        saved = 0
        # read() is grab() + retrieve(); retrieving only the frames we keep
        # avoids returning pixel data for every skipped frame.
        while cap.grab():
            if count % frame_interval == 0:
                ok, frame = cap.retrieve()
                frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
                if ok and cv2.imwrite(frame_filename, frame):
                    saved += 1
            count += 1
        print(f"Extracted {saved} frames.")
        return saved > 0
    except Exception as e:
        print(f"Exception during frame extraction: {e}")
        return False
    finally:
        # Release the capture on every path, including exceptions.
        if cap is not None:
            cap.release()
def image_qa(image_path: str, question: str) -> str:
    """Answer *question* about the image at *image_path* using BLIP VQA.

    Returns:
        The decoded BLIP answer, or ``"Error processing image: <error>"`` if
        anything fails (the error is also printed).
    """
    try:
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
        out = model_vqa.generate(**inputs)
        return processor_vqa.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error in image_qa: {str(e)}")
        return f"Error processing image: {str(e)}"
# --- Node functions ---
def router(state: Dict[str, Any]) -> str:
    """Decide which graph node should handle *state*.

    Checked in priority order:

    1. a Wikipedia/encyclopedia reference in the question -> ``"retrieve"``
    2. a YouTube link -> ``"video"``; the link is also stored in
       ``state['youtube_url']``, with trailing sentence punctuation removed
       and ``https://`` added when the scheme was omitted
    3. ``state['file_type'] == 'picture'`` -> ``"image"``
    4. ``state['file_type'] == 'audio'`` -> ``"audio"``
    5. otherwise -> ``"llm"``
    """
    question = state.get('question') or ''

    # "wiki", encyclopedia domains, and the spellings encyclopedia /
    # encyclopaedia / encyclopædia.
    wiki_pattern = r"(wikipedia\.org|wiki|britannica\.com|encyclop(?:a?e|æ)dia)"
    has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None

    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
    yt_match = re.search(yt_pattern, question, re.IGNORECASE)
    has_youtube = yt_match is not None

    has_image = state.get('file_type') == 'picture'
    has_audio = state.get('file_type') == 'audio'

    print(f"Has Wikipedia reference: {has_wiki}")
    print(f"Has YouTube link: {has_youtube}")
    print(f"Has picture file: {has_image}")
    print(f"Has audio file: {has_audio}")

    if has_wiki:
        return "retrieve"
    if has_youtube:
        # Store the YouTube link itself (not merely the first URL in the
        # question), without punctuation glued to its end.
        url = yt_match.group(0).rstrip('.,;:!?)]}>"\'')
        if not url.lower().startswith(('http://', 'https://')):
            url = 'https://' + url
        state['youtube_url'] = url
        return "video"
    if has_image:
        return "image"
    if has_audio:
        return "audio"
    return "llm"
def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
    """Graph entry node: make sure ``context`` exists, then record the route.

    The node name returned by ``router(state)`` is stored under
    ``state["next"]``; router may also add ``youtube_url``. The same,
    mutated dict is returned.
    """
    print("Running node_decide")
    state.setdefault('context', [])
    next_node = router(state)
    state["next"] = next_node
    print(f"Routing to: {next_node}")
    return state
def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer an image question with BLIP VQA on ``state['file_path']``.

    Sets ``state['answer']`` (always in FINAL ANSWER format) and appends an
    ``image_analysis`` Document to ``state['context']``. A missing
    ``file_path`` key or file yields "Image file not found." rather than an
    exception message.
    """
    print("Running node_image")
    image_path = state.get('file_path')
    if not image_path or not os.path.exists(image_path):
        state['answer'] = ensure_final_answer_format("Image file not found.")
        return state
    try:
        answer = image_qa(image_path, state.get('question', ''))
        state['answer'] = ensure_final_answer_format(answer)
        # Keep the raw analysis for traceability.
        state.setdefault('context', []).append(Document(
            page_content=f"Image analysis result: {answer}",
            metadata={"source": "image_analysis", "file_path": image_path},
        ))
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        state['answer'] = ensure_final_answer_format(error_msg)
    return state
def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer a question about the YouTube video linked in the question.

    Steps:

    - Download ``state['youtube_url']`` (set by ``router``).
    - Extract one frame every 10 seconds into ``/tmp/frames``.
    - Ask BLIP the question on each frame; the most common per-frame answer
      becomes ``state['answer']``.

    Per-frame answers are stored in ``state['frame_answers']`` and mirrored as
    ``video_analysis`` Documents in ``state['context']``. Each failed step sets
    an explanatory FINAL ANSWER and returns early.
    """
    print("Running node_video")
    youtube_url = state.get('youtube_url')
    if not youtube_url:
        state['answer'] = ensure_final_answer_format("No YouTube URL found in the question.")
        return state

    # Remove the link, with or without an http(s) scheme, so BLIP only sees
    # the natural-language part of the question.
    question_text = re.sub(
        r'(https?://\S+|(?:www\.)?(?:youtube\.com|youtu\.be)/\S+)',
        '',
        state.get('question') or '',
        flags=re.IGNORECASE,
    ).strip()
    if not question_text.endswith('?'):
        question_text += '?'

    video_file = download_youtube_video(youtube_url)
    if not video_file or not os.path.exists(video_file):
        state['answer'] = ensure_final_answer_format("Failed to download the video.")
        return state

    frames_dir = "/tmp/frames"
    os.makedirs(frames_dir, exist_ok=True)
    if not extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10):
        state['answer'] = ensure_final_answer_format("Failed to extract frames from the video.")
        return state

    result = answer_video_question(frames_dir, question_text)
    state['frame_answers'] = result['all_answers']
    # frame_number is the position in all_answers, not the source frame index.
    state.setdefault('context', []).extend(
        Document(
            page_content=f"Frame {i}: {ans}",
            metadata={"frame_number": i, "source": "video_analysis"},
        )
        for i, ans in enumerate(result['all_answers'])
    )
    state['answer'] = ensure_final_answer_format(result['most_common_answer'])
    print(f"Video answer: {state['answer']}")
    return state
def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer a question about the audio file at ``state['file_path']``.

    The audio is loaded at 16 kHz, transcribed with ``asr_pipe`` and passed
    with the question to ``llm_pipe``. The transcript and the prompt are
    appended to ``state['context']`` as Documents. ``state['answer']``
    receives the answer in FINAL ANSWER format, or an "Audio processing
    error" message on any failure, including a missing file.
    """
    audio_path = state.get('file_path')
    print(f"Processing audio file: {audio_path}")
    try:
        if not audio_path or not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # Step 1: transcribe. chunk_length_s enables the pipeline's chunked
        # mode so recordings longer than Whisper's 30 s window are covered.
        audio, sr = librosa.load(audio_path, sr=16000)
        asr_result = asr_pipe({"raw": audio, "sampling_rate": sr}, chunk_length_s=30)
        audio_transcript = asr_result['text']
        print(f"Audio transcript: {audio_transcript}")
        transcript_doc = Document(page_content=audio_transcript)

        # Step 2: the transcript is a single document, so a vector-store
        # lookup with k=1 would always return exactly this text. Use it
        # directly rather than loading an embedding model on every call.
        question = state['question']
        prompt = (
            f"You are an AI assistant that answers questions about audio content.\n\n"
            f"Audio transcript: {audio_transcript}\n\n"
            f"Question: {question}\n\n"
            f"Based only on the provided audio transcript, answer the question. "
            f"If the transcript does not contain relevant information, state that clearly.\n\n"
            f"End your response with 'FINAL ANSWER: ' followed by a concise answer."
        )

        # Step 3: generate. If the pipeline echoes the prompt, drop it so the
        # prompt's own "FINAL ANSWER:" instruction is not taken as the answer.
        llm_response = llm_pipe(prompt)
        answer_text = llm_response[0]['generated_text']
        if answer_text.startswith(prompt):
            answer_text = answer_text[len(prompt):]

        context = state.setdefault('context', [])
        context.append(transcript_doc)
        context.append(Document(page_content=prompt, metadata={"source": "audio_analysis_prompt"}))
        state['answer'] = ensure_final_answer_format(answer_text)
    except Exception as e:
        error_msg = f"Audio processing error: {str(e)}"
        print(error_msg)
        state['answer'] = ensure_final_answer_format(error_msg)
    return state
def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer a general-knowledge question directly with ``llm_pipe``.

    The full prompt is appended to ``state['context']`` as an ``llm_prompt``
    Document. ``state['answer']`` receives the model output in FINAL ANSWER
    format, or an error message if generation fails.
    """
    print("Running node_llm")
    question = state['question']
    prompt = (
        "You are an AI assistant that answers questions using your general knowledge. "
        "Follow these steps:\n\n"
        "1. If the question appears to be scrambled or jumbled, first try to unscramble or reconstruct the intended meaning.\n"
        "2. Analyze the question (unscrambled if needed) and use your own knowledge to answer it.\n"
        "3. If the question can't be answered with certainty, provide your best estimate and clearly explain any assumptions.\n"
        "4. Format your answer using these rules:\n"
        "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
        "   - Strings: Minimal words, no articles/abbreviations\n"
        "   - Lists: comma-separated values without extra formatting\n\n"
        "5. Always conclude with:\n"
        "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
        f"Current question: {question}"
    )
    # Keep the prompt for traceability.
    state.setdefault('context', []).append(
        Document(page_content=prompt, metadata={"source": "llm_prompt"})
    )
    try:
        result = llm_pipe(prompt)
        answer_text = result[0]['generated_text']
        # The pipeline may echo the prompt before the completion; the prompt's
        # "FINAL ANSWER: [your answer]" template line would then be taken as
        # the answer, so drop the echo.
        if answer_text.startswith(prompt):
            answer_text = answer_text[len(prompt):]
        state['answer'] = ensure_final_answer_format(answer_text)
    except Exception as e:
        print(f"Error in LLM processing: {str(e)}")
        error_msg = f"An error occurred while processing your question: {str(e)}"
        state['answer'] = ensure_final_answer_format(error_msg)
    print(f"LLM answer: {state['answer']}")
    return state
def retrieve(state: State) -> State:
    """Fetch wiki_dpr passages relevant to ``state["question"]``.

    The question is embedded with ``rag_model.question_encoder`` and looked
    up with ``retriever.retrieve``. Each returned passage is appended to
    ``state['context']`` as a Document containing title and text, with
    ``source``/``title``/``doc_id`` metadata. If only document ids come back,
    placeholder Documents naming the ids are added instead. On any exception
    a single ``retrieval_error`` Document is added. Always returns *state*.
    """
    print("Running retrieve")
    question = state["question"]
    if not state.get('context'):
        state['context'] = []
    try:
        inputs = tokenizer(question, return_tensors="pt", truncation=True)
        # Inference only: no autograd graph is needed for the query embedding.
        with torch.no_grad():
            question_hidden_states = rag_model.question_encoder(inputs["input_ids"])[0]

        # RagRetriever.__call__ returns tokenised context inputs, not passage
        # text. retrieve() additionally returns, per query, a dict of the
        # retrieved rows' columns (e.g. "title", "text").
        _, doc_ids, doc_dicts = retriever.retrieve(
            question_hidden_states.cpu().numpy(), n_docs=retriever.n_docs
        )
        ids = [int(doc_id) for doc_id in doc_ids[0]]
        rows = doc_dicts[0] if doc_dicts else {}
        texts = list(rows.get("text") or [])
        titles = list(rows.get("title") or [""] * len(texts))

        if texts:
            all_chunks = [
                Document(
                    page_content=f"{title}\n{text}".strip(),
                    metadata={"source": "wiki_dpr", "title": title, "doc_id": doc_id},
                )
                for title, text, doc_id in zip(titles, texts, ids)
            ]
            print(f"Retrieved {len(all_chunks)} documents")
        else:
            print(f"Retrieved doc_ids: {ids}")
            all_chunks = [
                Document(page_content=f"Information related to document ID: {doc_id}")
                for doc_id in ids
            ]
            print(f"Created {len(all_chunks)} document stubs from IDs")
        state['context'].extend(all_chunks)
    except Exception as e:
        print(f"Error in retrieval: {str(e)}")
        state['context'].append(Document(
            page_content=f"Error during retrieval: {str(e)}",
            metadata={"source": "retrieval_error"},
        ))
    return state
def generate(state: State) -> State:
    """Answer ``state["question"]`` with ``llm_pipe``, grounded in ``state["context"]``.

    All context Documents are concatenated into the prompt. ``state['answer']``
    is set in FINAL ANSWER format to one of:

    - a "No relevant information" message when context is empty;
    - the model's answer;
    - an error message if prompt building or generation fails.
    """
    print("Running generate")
    try:
        if not state.get('context'):
            state['answer'] = ensure_final_answer_format("No relevant information found to answer your question.")
            return state

        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        prompt_str = PromptTemplate(
            input_variables=["question", "context"],
            template=(
                "You are an AI assistant that answers questions using retrieved context. "
                "Follow these steps:\n\n"
                "1. Analyze the provided context:\n{context}\n\n"
                "2. If the context contains scrambled text, first attempt to reconstruct meaningful information\n"
                "3. If the question can't be answered from context alone, combine context with general knowledge "
                "but clearly state this limitation\n"
                "4. Format your answer using these rules:\n"
                "   - Numbers: Plain digits without commas/units (e.g. 1234567)\n"
                "   - Strings: Minimal words, no articles/abbreviations\n"
                "   - Lists: comma-separated values without extra formatting\n\n"
                "5. Always conclude with:\n"
                "FINAL ANSWER: [your answer] (replace bracketed text)\n\n"
                "Current question: {question}"
            )
        ).format(question=state["question"], context=docs_content)

        response = llm_pipe(prompt_str)
        answer_text = response[0]["generated_text"]
        # Drop an echoed prompt so its "FINAL ANSWER: [your answer]" template
        # line is not mistaken for the model's answer.
        if answer_text.startswith(prompt_str):
            answer_text = answer_text[len(prompt_str):]
        state['answer'] = ensure_final_answer_format(answer_text)
    except Exception as e:
        print(f"Error in generate node: {str(e)}")
        error_msg = f"Error generating answer: {str(e)}"
        state['answer'] = ensure_final_answer_format(error_msg)
    return state
# --- Define the edge condition function ---
def get_next_node(state: Dict[str, Any]) -> str:
    """Conditional-edge selector: the node name stored under ``state["next"]``.

    Raises KeyError if ``"next"`` has not been set (node_decide sets it).
    """
    next_node: str = state["next"]
    return next_node
def _build_graph() -> StateGraph:
    """Assemble the agent graph: START -> decide -> one handler -> END."""
    builder = StateGraph(State)

    handlers = {
        "decide": node_decide,
        "video": node_video,
        "llm": node_llm,
        "retrieve": retrieve,
        "generate": generate,
        "image": node_image,
        "audio": node_audio_rag,
    }
    for node_name, node_fn in handlers.items():
        builder.add_node(node_name, node_fn)

    builder.add_edge(START, "decide")
    # Wikipedia questions go through retrieval before answer generation.
    builder.add_edge("retrieve", "generate")

    # node_decide stores its route in state["next"]; each route name maps to
    # the node of the same name.
    routes = ("video", "llm", "retrieve", "image", "audio")
    builder.add_conditional_edges("decide", get_next_node, {route: route for route in routes})

    for terminal in ("video", "llm", "generate", "image", "audio"):
        builder.add_edge(terminal, END)
    return builder


# Create the StateGraph and compile it into the runnable agent.
graph = _build_graph()
agent = graph.compile()
# --- Intelligent Agent Function ---
def intelligent_agent(state: State) -> str:
    """Run the compiled agent graph on *state* and return its final answer.

    The caller's dict is not modified. The graph receives a shallow copy with
    its own ``context`` list, so nodes that append to ``context`` cannot grow
    the caller's list across calls.

    Returns:
        A string starting with ``"FINAL ANSWER:"``. It holds the graph's
        answer, or a message for invalid input or an execution error.
    """
    try:
        if not isinstance(state, dict):
            return "FINAL ANSWER: Error - input must be a valid State dictionary"
        if 'question' not in state:
            return "FINAL ANSWER: Error - question is required"

        run_state = dict(state)
        run_state['context'] = list(state.get('context') or [])
        print(f"Processing question: {run_state['question']}")

        final_state = agent.invoke(run_state)
        # A node may leave 'answer' as None; fall back before formatting.
        answer = final_state.get('answer') or "No answer found."
        return ensure_final_answer_format(answer)
    except Exception as e:
        print(f"Error in agent execution: {str(e)}")
        return f"FINAL ANSWER: An error occurred - {str(e)}"