Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import spacy | |
| import tempfile | |
| import glob | |
| import yt_dlp | |
| import shutil | |
| import cv2 | |
| import librosa | |
| import wikipedia | |
| from typing import TypedDict, List, Optional, Dict, Any | |
| from langchain.docstore.document import Document | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langgraph.graph import START, END, StateGraph | |
| from langchain_core.messages import AnyMessage, HumanMessage, AIMessage # If you are using it | |
| from langchain_community.retrievers import BM25Retriever # If you are using it | |
| from langgraph.prebuilt import ToolNode, tools_condition # If you are using it | |
| from langchain.vectorstores import FAISS | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.schema import Document | |
| from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline | |
| from io import BytesIO | |
| from sentence_transformers import SentenceTransformer | |
| import os | |
| import re | |
| from PIL import Image # This is correctly imported, but was being used incorrectly | |
| import numpy as np | |
| from collections import Counter | |
| import torch | |
| from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline | |
| from typing import TypedDict, List, Optional, Dict, Any, Literal, Tuple | |
| from langgraph.graph import StateGraph, START, END | |
| from langchain.docstore.document import Document | |
# Load spaCy's small English pipeline once at import time. It is used by
# extract_keywords() and extract_entities() below to pull named entities and
# nouns out of a question when building the Wikipedia search query.
nlp = spacy.load("en_core_web_sm")
# File-extension lookup tables consumed by get_file_type().
# Every entry is lower-case and carries its leading dot.

# Raster image formats.
PICTURE_EXTENSIONS = {
    '.bmp', '.gif', '.jpeg', '.jpg', '.png', '.tiff', '.webp',
}

# Audio formats.
AUDIO_EXTENSIONS = {
    '.aac', '.flac', '.m4a', '.mp3', '.ogg', '.wav', '.wma',
}

# Source-code and web-markup formats.
CODE_EXTENSIONS = {
    '.c', '.cpp', '.cs', '.css', '.go', '.html',
    '.java', '.js', '.php', '.py', '.rb', '.ts',
}

# Spreadsheet and tabular-data formats (note: also claims '.txt' and '.xml').
SPREADSHEET_EXTENSIONS = {
    # Microsoft Excel
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    # OpenDocument / StarOffice / Uniform Office
    '.ods', '.ots', '.fods', '.sxc', '.stc', '.uos',
    # Delimited and plain text
    '.csv', '.tsv', '.dif', '.txt', '.xml',
    # Google Sheets / Apple Numbers
    '.gsheet', '.numbers', '.numbers-tef', '.nmbtemplate',
    # Lotus 1-2-3 / Works
    '.123', '.wk1', '.wk2', '.wks', '.wku', '.wr1',
    # Gnumeric / PlanMaker
    '.gnumeric', '.gnm', '.pmvx', '.pmdx', '.pmv',
}
def get_file_type(filename: str) -> str:
    """Classify a file by the extension after its last dot.

    Args:
        filename: File name or path; matching is case-insensitive.

    Returns:
        'picture', 'audio', 'code' or 'spreadsheet' when the extension is in
        the corresponding extension set, 'unknown' for any other extension,
        and '' when the name is empty or contains no dot at all.
    """
    if not filename or '.' not in filename:
        return ''

    # Everything after the final dot, re-prefixed to match the set entries.
    _, _, suffix = filename.lower().rpartition('.')
    dotted_suffix = '.' + suffix

    # Checked in the same order as the original if/elif chain.
    categories = (
        (PICTURE_EXTENSIONS, 'picture'),
        (AUDIO_EXTENSIONS, 'audio'),
        (CODE_EXTENSIONS, 'code'),
        (SPREADSHEET_EXTENSIONS, 'spreadsheet'),
    )
    for known_extensions, label in categories:
        if dotted_suffix in known_extensions:
            return label
    return 'unknown'
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """Write bytes to a file under /tmp and return the full path.

    Only the final path component of ``file_name`` is used, so a name such
    as ``"../x.png"`` or ``"/etc/x.png"`` cannot place the file outside
    /tmp. The file persists until it is deleted manually or the OS cleans
    the directory.

    Args:
        file_bytes: Raw content to write.
        file_name: Desired file name; any directory part is discarded.

    Returns:
        The full path of the written file.

    Raises:
        ValueError: If ``file_name`` has no usable final component
            (empty, ends with a separator, or is "." / "..").
    """
    temp_dir = "/tmp"  # /tmp is always writable in Hugging Face Spaces

    # Drop any directory components to prevent path traversal.
    safe_name = os.path.basename(file_name) if file_name else ""
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid file name: {file_name!r}")

    os.makedirs(temp_dir, exist_ok=True)
    file_path = os.path.join(temp_dir, safe_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
| # 1. Define the State type | |
class State(TypedDict, total=False):
    """Shared state dictionary passed between the graph nodes.

    Declared with ``total=False``, so every key is optional.
    """
    question: str  # User question; read by router() and by every answering node
    task_id: str  # Not read by any node in this file
    input_file: bytes  # Not read by any node in this file
    file_type: str  # router() compares this with 'picture' and 'audio'
    context: List[Document]  # LangChain Documents; written by retrieve, node_video and node_llm
    file_path: Optional[str]  # Local path read by node_image and node_audio_rag
    youtube_url: Optional[str]  # Set by router() when a YouTube link is found; read by node_video
    answer: Optional[str]  # Final answer text, set by the answering nodes
    frame_answers: Optional[list]  # Per-frame answers stored by node_video
    next: Optional[str]  # Route name chosen by node_decide; read by get_next_node
| # --- LLM pipeline for general questions --- | |
# Text-generation pipeline shared by node_llm, node_audio_rag and generate.
# It is created once at import time, so importing this module loads the model.
# The commented-out lines are alternative models and placement options.
llm_pipe = pipeline("text-generation",
                    #model="meta-llama/Llama-3.3-70B-Instruct",
                    #model="meta-llama/Meta-Llama-3-8B-Instruct",
                    #model="Qwen/Qwen2-7B-Instruct",
                    #model="microsoft/Phi-4-reasoning",
                    model="microsoft/Phi-3-mini-4k-instruct",
                    device_map="auto",
                    #device_map={ "": 0 }, # "" means the whole model
                    #max_memory={0: "10GiB"},
                    torch_dtype="auto",
                    max_new_tokens=256)  # caps the length of every generated answer
| # Speech-to-text pipeline | |
# Whisper speech-to-text pipeline used by node_audio_rag to transcribe audio.
# Created once at import time; device=-1 keeps it on the CPU.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=-1
    # Alternative GPU placements that were tried:
    #device_map={"", 0},
    #max_memory = {0: "4.5GiB"},
    #device_map="auto"
)
# --- BLIP VQA setup ---
# The processor and model are loaded once at import time and shared by
# answer_question_on_frame(); image_qa() reuses vqa_model_name as its default.
# Automatic device selection is disabled; the model is pinned to the CPU.
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
vqa_model_name = "Salesforce/blip-vqa-base"
processor_vqa = BlipProcessor.from_pretrained(vqa_model_name)
# Load the model onto `device`, falling back to CPU on a CUDA out-of-memory error.
# NOTE(review): with device hard-coded to "cpu" above, the except branch looks
# unreachable; it only matters if the CUDA selection line is re-enabled.
try:
    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
except torch.cuda.OutOfMemoryError:
    print("WARNING: Loading model to CPU due to insufficient GPU memory.")
    device = "cpu"  # Switch device to CPU
    model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to(device)
| # --- Helper: Answer question on a single frame --- | |
def answer_question_on_frame(image_path, question):
    """Answer a question about a single image with the BLIP VQA model.

    Args:
        image_path: Path to the image file to analyse.
        question: Natural-language question about the image.

    Returns:
        The decoded answer string produced by the model.
    """
    # Close the file handle as soon as the pixels are copied. This function
    # runs once per extracted frame, so unclosed handles would pile up.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    inputs = processor_vqa(image, question, return_tensors="pt").to(device)
    out = model_vqa.generate(**inputs)
    return processor_vqa.decode(out[0], skip_special_tokens=True)
| # --- Helper: Answer question about the whole video --- | |
def answer_video_question(frames_dir, question):
    """Ask the same question of every frame image and take a majority vote.

    Args:
        frames_dir: Directory containing extracted frames (.jpg/.jpeg/.png).
        question: Question to ask about each frame.

    Returns:
        A dict with:
          - "most_common_answer": the most frequent per-frame answer, or an
            explanatory message when nothing could be analysed;
          - "all_answers": per-frame answers in frame order;
          - "answer_counts": a Counter of the answers.
    """
    valid_exts = ('.jpg', '.jpeg', '.png')

    def empty_result(message):
        # Same result shape as the success path, with no answers.
        return {
            "most_common_answer": message,
            "all_answers": [],
            "answer_counts": Counter()
        }

    # isdir (not exists): a path that points at a regular file would
    # otherwise make os.listdir raise NotADirectoryError.
    if not os.path.isdir(frames_dir):
        return empty_result("No frames found to analyze.")

    def get_frame_number(filename):
        # Sort by the first run of digits in the file name (0 if none).
        match = re.search(r'(\d+)', os.path.basename(filename))
        return int(match.group(1)) if match else 0

    frame_files = sorted(
        (os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
         if f.lower().endswith(valid_exts)),
        key=get_frame_number,
    )
    if not frame_files:
        return empty_result("No valid image frames found.")

    answers = []
    for frame_path in frame_files:
        # Best effort: one unreadable frame must not abort the whole video.
        try:
            ans = answer_question_on_frame(frame_path, question)
            answers.append(ans)
            print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")
        except Exception as e:
            print(f"Error processing frame {frame_path}: {str(e)}")

    if not answers:
        return empty_result("Could not analyze any frames successfully.")

    counted = Counter(answers)
    most_common_answer = counted.most_common(1)[0][0]
    return {
        "most_common_answer": most_common_answer,
        "all_answers": answers,
        "answer_counts": counted
    }
def download_youtube_video(url, output_dir='/content/video/', output_filename='downloaded_video.mp4'):
    """Download a YouTube video as MP4 into ``output_dir`` using yt-dlp.

    Existing files in ``output_dir`` are deleted first so a stale download
    is never mistaken for the new one.

    Args:
        url: Video URL to download.
        output_dir: Directory to store the video in (created if missing).
        output_filename: File name to use for the downloaded video.

    Returns:
        The path of the downloaded file, or None if yt-dlp reported a
        download error. The caller checks for a falsy result.
    """
    # NOTE(review): the default '/content/video/' looks like a Colab path,
    # while the rest of this file writes under /tmp; confirm it is writable.
    os.makedirs(output_dir, exist_ok=True)

    # Delete all files in the output directory (best effort).
    for stale_path in glob.glob(os.path.join(output_dir, '*')):
        try:
            os.remove(stale_path)
        except OSError as e:
            print(f"Error deleting {stale_path}: {str(e)}")

    output_path = os.path.join(output_dir, output_filename)
    ydl_opts = {
        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
        'outtmpl': output_path,
        'quiet': True,
        'merge_output_format': 'mp4',  # Ensures merged output is mp4
        'postprocessors': [{
            'key': 'FFmpegVideoConvertor',
            'preferedformat': 'mp4',  # Recode if needed (yt-dlp's own spelling)
        }]
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
    except yt_dlp.utils.DownloadError as e:
        # Report failure through the return value instead of raising.
        print(f"Error downloading {url}: {e}")
        return None
    return output_path
| # --- Helper: Extract frames from video --- | |
def extract_frames(video_path, output_dir, frame_interval_seconds=10):
    """Save one frame every ``frame_interval_seconds`` seconds as a JPEG.

    ``output_dir`` is emptied first (or created if missing) so that frames
    from a previous video are never mixed with the new ones. Frames are
    written as ``frame_<index>.jpg`` with a zero-padded frame index.

    Args:
        video_path: Path to the video file to read.
        output_dir: Directory that receives the extracted frames.
        frame_interval_seconds: Sampling interval in seconds.

    Returns:
        True if at least one frame was saved, False otherwise, including
        when the video cannot be opened or an error occurs.
    """
    # --- Clean output directory before extracting new frames ---
    if os.path.exists(output_dir):
        for filename in os.listdir(output_dir):
            file_path = os.path.join(output_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        os.makedirs(output_dir, exist_ok=True)

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return False

        fps = cap.get(cv2.CAP_PROP_FPS)
        # The reported FPS can be 0 (or NaN) when the container gives none;
        # fall back to a nominal rate rather than dividing by zero below.
        if not (fps > 0):
            fps = 30.0
        # Always at least 1 so the modulo below is well defined.
        frame_interval = max(1, int(fps * frame_interval_seconds))

        count = 0
        saved = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
                cv2.imwrite(frame_filename, frame)
                saved += 1
            count += 1

        print(f"Extracted {saved} frames.")
        return saved > 0
    except Exception as e:
        print(f"Exception during frame extraction: {e}")
        return False
    finally:
        # Release the capture on every exit path, including errors.
        if cap is not None:
            cap.release()
# Cache of visual-question-answering pipelines keyed by model name, so each
# model is loaded once instead of on every image_qa() call.
_VQA_PIPELINES: Dict[str, Any] = {}


def image_qa(image_path: str, question: str, model_name: str = vqa_model_name) -> str:
    """Answer a question about an image with a Hugging Face VQA pipeline.

    Args:
        image_path: Path to a local image file, or a URL.
        question: Natural-language question about the image.
        model_name: Pretrained VQA model to use.

    Returns:
        The model's top-ranked answer.
    """
    # Building a pipeline loads the model weights, so reuse it across calls.
    vqa_pipeline = _VQA_PIPELINES.get(model_name)
    if vqa_pipeline is None:
        vqa_pipeline = pipeline("visual-question-answering", model=model_name)
        _VQA_PIPELINES[model_name] = vqa_pipeline

    # The pipeline accepts local files and URLs.
    results = vqa_pipeline(image=image_path, question=question, top_k=1)
    return results[0]['answer']
def router(state: Dict[str, Any]) -> str:
    """Choose the next graph node from the question text and attached file.

    Priority order: encyclopedia reference -> "retrieve", YouTube link ->
    "video", picture file -> "image", audio file -> "audio", else "llm".

    Side effect: when routing to "video", the matched YouTube URL is stored
    in ``state['youtube_url']``, with an ``https://`` scheme added if the
    question omitted it and trailing sentence punctuation removed.

    Args:
        state: Graph state; reads 'question' and 'file_type'.

    Returns:
        One of "retrieve", "video", "image", "audio" or "llm".
    """
    question = state.get('question', '') or ''

    # Wikipedia and similar sources. The last alternative covers the
    # spellings encyclopedia / encyclopaedia / encyclopædia.
    wiki_pattern = r"(wikipedia\.org|wiki|britannica\.com|encyclop(?:ae|æ|e)dia)"
    has_wiki = re.search(wiki_pattern, question, re.IGNORECASE) is not None

    # YouTube link, with or without a scheme.
    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
    yt_match = re.search(yt_pattern, question)
    has_youtube = yt_match is not None

    has_image = state.get('file_type') == 'picture'
    has_audio = state.get('file_type') == 'audio'

    print(f"Has Wikipedia reference: {has_wiki}")
    print(f"Has YouTube link: {has_youtube}")
    print(f"Has picture file: {has_image}")
    print(f"Has audio file: {has_audio}")

    if has_wiki:
        return "retrieve"
    if has_youtube:
        # Use the YouTube match itself, not the first URL of any kind, so an
        # unrelated link earlier in the question is not downloaded instead.
        youtube_url = yt_match.group(0).rstrip('.,;:!?)"\'')
        if not youtube_url.lower().startswith(('http://', 'https://')):
            youtube_url = 'https://' + youtube_url
        state['youtube_url'] = youtube_url
        return "video"
    if has_image:
        return "image"
    if has_audio:
        return "audio"
    return "llm"
| # --- Node Implementation --- | |
def node_image(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer the question about the attached image file.

    Reads ``state['file_path']`` and ``state['question']``, runs image_qa()
    and stores the result in ``state['answer']``.

    Args:
        state: Graph state.

    Returns:
        The same state dict with 'answer' set.
    """
    print("Running node_image")
    file_path = state.get('file_path')
    if not file_path:
        # Give a clear message instead of a KeyError when no file is attached.
        state['answer'] = "No image file was provided."
        return state
    # image_qa() loads the image itself, so the file is not opened here.
    state['answer'] = image_qa(file_path, state['question'])
    return state
def node_decide(state: Dict[str, Any]) -> Dict[str, Any]:
    """Record the route chosen by router() under ``state["next"]``.

    Args:
        state: Graph state; router() may also set 'youtube_url' on it.

    Returns:
        The same state dict with 'next' set to the chosen node name.
    """
    print("Running node_decide")
    chosen_route = router(state)
    state["next"] = chosen_route
    print(f"Routing to: {state['next']}")
    return state
def node_video(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer a question about a YouTube video by voting over sampled frames.

    Downloads the video from ``state['youtube_url']``, extracts one frame
    every 10 seconds into /tmp/frames, asks the question of each frame and
    stores the most common answer in ``state['answer']``. Per-frame answers
    go to ``state['frame_answers']`` and are appended to ``state['context']``
    as Documents.

    Args:
        state: Graph state; reads 'youtube_url' and 'question'.

    Returns:
        The same state dict, with 'answer' always set.
    """
    print("Running node_video")
    youtube_url = state.get('youtube_url')
    if not youtube_url:
        state['answer'] = "No YouTube URL found in the question."
        return state

    question = state['question']
    # Extract the actual question part (remove the URL).
    question_text = re.sub(r'https?://[^\s]+', '', question).strip()
    if not question_text.endswith('?'):
        question_text += '?'

    # A failed download must produce an answer, not an unhandled exception.
    try:
        video_file = download_youtube_video(youtube_url)
    except Exception as e:
        print(f"Error downloading video: {e}")
        video_file = None
    if not video_file or not os.path.exists(video_file):
        state['answer'] = "Failed to download the video."
        return state

    frames_dir = "/tmp/frames"
    os.makedirs(frames_dir, exist_ok=True)
    success = extract_frames(video_path=video_file, output_dir=frames_dir, frame_interval_seconds=10)
    if not success:
        state['answer'] = "Failed to extract frames from the video."
        return state

    result = answer_video_question(frames_dir, question_text)
    state['answer'] = result['most_common_answer']
    state['frame_answers'] = result['all_answers']

    # One Document per analysed frame, for traceability.
    frame_documents = [
        Document(
            page_content=f"Frame {i}: {ans}",
            metadata={"frame_number": i, "source": "video_analysis"}
        )
        for i, ans in enumerate(result['all_answers'])
    ]
    # Also covers a 'context' key that is present but None.
    if not state.get('context'):
        state['context'] = []
    state['context'].extend(frame_documents)

    print(f"Video answer: {state['answer']}")
    return state
def node_audio_rag(state: Dict[str, Any]) -> Dict[str, Any]:
    """Transcribe the attached audio file and answer the question from it.

    Steps: load the audio at 16 kHz, transcribe it with the ASR pipeline,
    index the transcript in a FAISS store, retrieve it for the question and
    ask the LLM to answer using that context.

    Args:
        state: Graph state; reads 'file_path' and 'question'.

    Returns:
        The same state dict with 'answer' set to the generated answer, or
        to an error message if any step fails.
    """
    file_path = state.get('file_path')
    if not file_path:
        # Clear message instead of a KeyError raised outside the try block.
        state['answer'] = "No audio file was provided."
        return state

    print(f"Processing audio file: {file_path}")
    try:
        # Step 1: Transcribe audio
        audio, sr = librosa.load(file_path, sr=16000)
        asr_result = asr_pipe({"raw": audio, "sampling_rate": sr})
        audio_transcript = asr_result['text']
        print(f"Audio transcript: {audio_transcript}")

        # Step 2: Store ONLY the transcript in the vector store
        transcript_doc = [Document(page_content=audio_transcript)]
        embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')
        vector_db = FAISS.from_documents(transcript_doc, embedding=embeddings)

        # Step 3: Retrieve relevant docs for the user's question
        question = state['question']
        similar_docs = vector_db.similarity_search(question, k=1)  # Only one doc in store
        retrieved_context = "\n".join(doc.page_content for doc in similar_docs)

        # Step 4: Augment prompt and generate answer
        prompt = (
            f"Use the following context to answer the question.\n"
            f"Context:\n{retrieved_context}\n\n"
            f"Question: {question}\nAnswer:"
        )
        # return_full_text=False: keep only the completion, otherwise the
        # prompt (transcript included) is echoed back as part of the answer.
        llm_response = llm_pipe(prompt, return_full_text=False)
        state['answer'] = llm_response[0]['generated_text']
    except Exception as e:
        error_msg = f"Audio processing error: {str(e)}"
        print(error_msg)
        state['answer'] = error_msg
    return state
def node_llm(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer the question directly with the text-generation pipeline.

    Background text is taken from ``state['article_content']`` if present,
    otherwise from the Documents in ``state['context']``. The full prompt is
    appended to ``state['context']`` as a Document for traceability.

    Args:
        state: Graph state; reads 'question', 'article_content', 'context'.

    Returns:
        The same state dict with 'answer' set to the generated text, or to
        an error message if generation fails.
    """
    print("Running node_llm")
    question = state['question']

    # Optional background information (e.g. Wikipedia content).
    background = ""
    if state.get('article_content'):
        background = state['article_content']
    elif state.get('context'):
        background = "\n\n".join(doc.page_content for doc in state['context'])
    # Same header and spacing for both sources, so the background never
    # runs straight into the question text.
    context_text = f"\n\nBackground Information:\n{background}\n" if background else ""

    # Each instruction ends with ". " so the adjacent literals do not fuse.
    prompt = (
        "You are an expert researcher. Answer the user's question as accurately as possible. "
        "If the text appears to be scrambled, try to unscramble the text for the user. "
        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
        f"Question: {question}"
        f"{context_text}\n"
        "Answer:"
    )

    # Add document to state for traceability.
    query_doc = Document(
        page_content=prompt,
        metadata={"source": "llm_prompt"}
    )
    if not state.get('context'):
        state['context'] = []
    state['context'].append(query_doc)

    try:
        # return_full_text=False: store only the completion, not the prompt
        # echoed back in front of it.
        result = llm_pipe(prompt, return_full_text=False)
        state['answer'] = result[0]['generated_text']
    except Exception as e:
        print(f"Error in LLM processing: {str(e)}")
        state['answer'] = f"An error occurred while processing your question: {str(e)}"

    print(f"LLM answer: {state['answer']}")
    return state
| # --- Define the edge condition function --- | |
def get_next_node(state: Dict[str, Any]) -> str:
    """Return the route name stored under ``state["next"]``.

    Used as the condition function for the conditional edges out of the
    "decide" node.
    """
    next_node = state["next"]
    return next_node
| # 2. Improved Wikipedia Retrieval Node | |
def extract_keywords(question: str) -> List[str]:
    """Return the text of every token spaCy tags as a noun or proper noun.

    Args:
        question: Text to analyse.

    Returns:
        Token texts in document order; may be empty.
    """
    wanted_pos = ("PROPN", "NOUN")
    parsed = nlp(question)
    return [tok.text for tok in parsed if tok.pos_ in wanted_pos]
def extract_entities(question: str) -> List[str]:
    """Return the named entities in the question, falling back to nouns.

    Args:
        question: Text to analyse.

    Returns:
        Entity texts found by spaCy; if there are none, the texts of all
        tokens tagged as nouns or proper nouns instead.
    """
    parsed = nlp(question)
    named = [span.text for span in parsed.ents]
    if named:
        return named
    # No named entities: fall back to plain nouns / proper nouns.
    return [tok.text for tok in parsed if tok.pos_ in ("PROPN", "NOUN")]
def retrieve(state: State) -> dict:
    """Fetch the best-matching Wikipedia article and split it into chunks.

    The search query is built from the entities in the question, falling
    back to the raw question when none are found. The top search hit is
    loaded and split into overlapping chunks.

    Args:
        state: Graph state; reads 'question'.

    Returns:
        ``{"context": [...]}`` holding the chunk Documents (article metadata
        preserved on each chunk), or an empty list when nothing is found.
    """
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    question = state["question"]
    # Fall back to the raw question so the search query is never empty.
    query = " ".join(extract_entities(question)).strip() or question.strip()
    if not query:
        return {"context": []}

    search_results = wikipedia.search(query)
    if not search_results:
        return {"context": []}

    loader = WikipediaLoader(
        query=search_results[0],
        lang="en",
        load_max_docs=1,
        doc_content_chars_max=100000,
        load_all_available_meta=True
    )
    docs = loader.load()

    # Chunk the article for finer retrieval. split_documents keeps each
    # article's metadata on its chunks instead of discarding it.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    all_chunks = splitter.split_documents(docs)
    # Optionally: re-rank or filter chunks here
    return {"context": all_chunks}
| # 3. Prompt Template for General QA | |
# Prompt used by generate() to answer from retrieved Wikipedia context.
# Each instruction ends with ". " and the instruction block ends with a blank
# line, so the adjacent string literals do not fuse into run-on text such as
# "...for the userIf the information..." or "...each step.Context:".
prompt = PromptTemplate(
    input_variables=["question", "context"],
    template=(
        "You are an expert researcher. Given the following context from Wikipedia, answer the user's question as accurately as possible. "
        "If the text appears to be scrambled, try to unscramble the text for the user. "
        "If the information is incomplete or ambiguous, provide your best estimate based on the available evidence, and clearly explain any assumptions or reasoning you use. "
        "If the answer requires multiple steps or deeper analysis, break down the question into sub-questions and answer them step by step, citing the relevant context for each step.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "Best Estimate Answer:"
    )
)
def generate(state: dict) -> dict:
    """Answer the question from the retrieved context with the LLM pipeline.

    Args:
        state: Graph state; reads 'context' (Documents) and 'question'.

    Returns:
        ``{"answer": <generated text>}``.
    """
    # Concatenate all context documents into a single string.
    # NOTE(review): retrieve() may supply up to 100000 characters of article
    # text, presumably more than the configured model can attend to; confirm
    # whether chunks should be ranked or truncated first.
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])

    # Format the prompt for the LLM.
    prompt_str = prompt.format(question=state["question"], context=docs_content)

    # return_full_text=False: return only the completion, otherwise the whole
    # prompt (article text included) is echoed back inside the answer.
    response = llm_pipe(prompt_str, return_full_text=False)
    return {"answer": response[0]["generated_text"]}
| # Create the StateGraph | |
# Build the agent graph: START -> decide -> one answering branch -> END.
graph = StateGraph(State)

# Register every node under its route name.
_node_handlers = {
    "decide": node_decide,
    "video": node_video,
    "llm": node_llm,
    "retrieve": retrieve,
    "generate": generate,
    "image": node_image,
    "audio": node_audio_rag,
}
for _node_name, _handler in _node_handlers.items():
    graph.add_node(_node_name, _handler)

# Fixed edges: entry point, and the two-step Wikipedia path.
graph.add_edge(START, "decide")
graph.add_edge("retrieve", "generate")

# "decide" branches on the route name stored in state["next"]; each route
# name maps to the node of the same name.
_branch_targets = ("video", "llm", "retrieve", "image", "audio")
graph.add_conditional_edges(
    "decide",
    get_next_node,
    {_target: _target for _target in _branch_targets},
)

# Every answering node terminates the graph.
for _terminal in ("video", "llm", "generate", "image", "audio"):
    graph.add_edge(_terminal, END)

# Compile the graph into the runnable agent.
agent = graph.compile()
| # --- Usage Example --- | |
def intelligent_agent(state: State) -> str:
    """Run the compiled agent graph on ``state`` and return the answer text.

    Args:
        state: Initial graph state (question plus any file information).

    Returns:
        The 'answer' value of the final state, "No answer found." if the
        graph produced none, or an error message if execution raised.
    """
    try:
        outcome = agent.invoke(state)
        answer = outcome.get('answer', "No answer found.")
    except Exception as exc:
        print(f"Error in agent execution: {str(exc)}")
        return f"An error occurred: {str(exc)}"
    return answer