# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
| import gradio as gr | |
| from transformers import ( | |
| AutoTokenizer, | |
| WhisperProcessor, | |
| WhisperForConditionalGeneration, | |
| ) | |
| from auto_gptq import AutoGPTQForCausalLM | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.chains import RetrievalQA | |
| from langchain.schema import LLMResult | |
| from langchain.llms.base import LLM | |
| import torch | |
| import os | |
| import tempfile | |
| import logging | |
| import warnings | |
| from typing import Optional, Dict, Any, List | |
| from gtts import gTTS | |
| import numpy as np | |
# Quiet third-party warnings, keep HF tokenizers single-threaded, set up logging.
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

print("===== Application Startup =====")

# --------------------------
# Configuration
# --------------------------
# Settings shared by the LLM, embedding, speech, retrieval and chunking code.
CONFIG = dict(
    model_name="TheBloke/Llama-2-7B-Chat-GPTQ",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    whisper_model="openai/whisper-small",
    persist_dir="./chroma_db",
    max_new_tokens=200,
    temperature=0.7,
    top_p=0.9,
    chunk_size=500,
    chunk_overlap=50,
)
# --------------------------
# Import handling with fallbacks
# --------------------------
# Prefer the split-out LangChain packages; fall back to langchain_community;
# bind the name to None when neither is installed so callers can check it.
try:
    from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
    try:
        from langchain_community.embeddings import HuggingFaceEmbeddings
    except ImportError:
        print("β Failed to import HuggingFaceEmbeddings")
        HuggingFaceEmbeddings = None
    else:
        print("β Using legacy HuggingFaceEmbeddings")
else:
    print("β Using updated HuggingFaceEmbeddings")

try:
    from langchain_chroma import Chroma
except ImportError:
    try:
        from langchain_community.vectorstores import Chroma
    except ImportError:
        print("β Failed to import Chroma")
        Chroma = None
    else:
        print("β Using legacy Chroma")
else:
    print("β Using updated Chroma")
# --------------------------
# Text-to-Speech Manager
# --------------------------
class TTSManager:
    """Turn reply text into an MP3 file (via gTTS) that Gradio can play."""

    # Longest text (characters) sent to gTTS; longer text is cut and "..." appended.
    MAX_CHARS = 500

    @staticmethod
    def speak_text_to_audio(text: Optional[str], lang: str = 'en') -> Optional[str]:
        """Synthesize *text* to a temporary MP3 file.

        This is a staticmethod because the function takes no ``self``. As a plain
        method, calling it through an instance (``self.tts.speak_text_to_audio(x)``)
        bound the instance to ``text``, so every call failed and returned None.

        Args:
            text: Text to speak; ``None``, non-string or whitespace-only text
                yields ``None``.
            lang: gTTS language code.

        Returns:
            Path of the generated ``.mp3`` file (not deleted here; the caller owns
            it), or ``None`` if there was nothing to speak or synthesis failed.
        """
        if not isinstance(text, str) or not text.strip():
            return None
        # Limit text length to avoid TTS issues
        if len(text) > TTSManager.MAX_CHARS:
            text = text[:TTSManager.MAX_CHARS] + "..."

        path = None
        try:
            # Create and close the temp file first, then let gTTS write it by
            # name: an open NamedTemporaryFile cannot be reopened on Windows.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
                path = fp.name
            gTTS(text=text, lang=lang, slow=False).save(path)
            return path
        except Exception as e:
            logger.error(f"TTS error: {e}")
            # Remove the empty/partial file instead of leaking it.
            if path is not None and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temp TTS file {path}: {cleanup_error}")
            return None
# --------------------------
# Model Manager with Robust Error Handling
# --------------------------
class ModelManager:
    """Load and serve the GPTQ LLaMA chat model and the Whisper speech recognizer.

    Load failures are recorded in ``model_loaded`` / ``whisper_loaded`` instead
    of being raised, so the UI can start with reduced functionality.
    """

    # Rate (Hz) that audio is resampled to before Whisper feature extraction;
    # Whisper checkpoints are trained on 16 kHz audio.
    WHISPER_SAMPLING_RATE = 16000

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = None
        self.model = None
        self.whisper_processor = None
        self.whisper_model = None
        self.llm = None
        self.model_loaded = False
        self.whisper_loaded = False
        print(f"Device: {self.device}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"PyTorch version: {torch.__version__}")
        # Optional Hugging Face token, forwarded to the LLaMA downloads below.
        self.hf_token = os.getenv("HF_TOKEN")
        if self.hf_token:
            print("β HF_TOKEN found")
        else:
            print("! HF_TOKEN not found, proceeding without authentication")
        # Initialize components
        self._load_main_model()
        self._load_whisper_model()
        self._create_llm_wrapper()

    def _load_main_model(self):
        """Load the LLaMA tokenizer and GPTQ model; sets ``model_loaded``.

        Any exception is printed and leaves ``model_loaded`` False.
        """
        try:
            print("Loading LLaMA tokenizer...")
            tokenizer_kwargs = {
                "trust_remote_code": True,
                "legacy": False,
            }
            if self.hf_token:
                tokenizer_kwargs["token"] = self.hf_token
            self.tokenizer = AutoTokenizer.from_pretrained(
                CONFIG["model_name"],
                **tokenizer_kwargs
            )
            # generate() is called with pad_token_id, so make sure one exists.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # When a prompt exceeds max_length, drop tokens from the START so the
            # question and the trailing "[/INST]" survive truncation.
            self.tokenizer.truncation_side = "left"
            print("β Tokenizer loaded successfully")

            # Load model with conservative settings (no Triton/ExLlama kernels,
            # no fused attention/MLP injection).
            print("Loading quantized LLaMA model...")
            model_kwargs = {
                "use_safetensors": True,
                "trust_remote_code": True,
                "device": self.device,
                "use_triton": False,
                "disable_exllama": True,
                "disable_exllamav2": True,
                "inject_fused_attention": False,
                "inject_fused_mlp": False,
            }
            if self.hf_token:
                model_kwargs["token"] = self.hf_token
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_name_or_path=CONFIG["model_name"],
                **model_kwargs
            )
            self.model.eval()
            print("β LLaMA model loaded successfully")
            self.model_loaded = True
        except Exception as e:
            print(f"β Failed to load main model: {e}")
            self.model_loaded = False

    def _load_whisper_model(self):
        """Load the Whisper processor and model on CPU; sets ``whisper_loaded``."""
        try:
            print("Loading Whisper models...")
            self.whisper_processor = WhisperProcessor.from_pretrained(CONFIG["whisper_model"])
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained(CONFIG["whisper_model"])
            # Keep Whisper on CPU to save GPU memory
            self.whisper_model.to("cpu")
            print("β Whisper models loaded successfully")
            self.whisper_loaded = True
        except Exception as e:
            print(f"β Failed to load Whisper: {e}")
            self.whisper_loaded = False

    def _create_llm_wrapper(self):
        """Wrap this manager in a LangChain LLM (only if the main model loaded)."""
        if self.model_loaded:
            try:
                self.llm = CustomLlamaLLM(self)
                print("β LLM wrapper created successfully")
            except Exception as e:
                print(f"β Failed to create LLM wrapper: {e}")

    def _decode_new_tokens(self, sequence, prompt_length: int) -> str:
        """Decode only the tokens generated after the prompt.

        Decoder-only ``generate`` returns prompt + completion; decoding the whole
        sequence echoed the prompt, which for RetrievalQA prompts (no "[/INST]"
        marker) meant the entire retrieved context came back as the answer.
        """
        response = self.tokenizer.decode(
            sequence[prompt_length:], skip_special_tokens=True
        ).strip()
        # Defensive: if the model emits a stray "[/INST]", keep what follows it.
        if "[/INST]" in response:
            response = response.split("[/INST]")[-1].strip()
        return response

    def generate_text(self, prompt: str) -> str:
        """Sample a completion for *prompt*; always returns a user-facing string.

        On an autocast/scalar ``RuntimeError`` it retries with greedy decoding
        via ``_generate_fallback``; other failures become apology messages.
        """
        if not self.model_loaded:
            return "Sorry, the AI model is currently unavailable. Please try again later."
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
                padding=True,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_length = inputs["input_ids"].shape[-1]
            with torch.no_grad():
                # Turn off TF32 matmuls. These are process-wide torch flags; they
                # do not change autocast.
                torch.backends.cuda.matmul.allow_tf32 = False
                torch.backends.cudnn.allow_tf32 = False
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=CONFIG["max_new_tokens"],
                    do_sample=True,
                    temperature=CONFIG["temperature"],
                    top_p=CONFIG["top_p"],
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                    use_cache=True,
                )
            response = self._decode_new_tokens(outputs[0], prompt_length)
            return response if response else "I'm sorry, I couldn't generate a proper response."
        except RuntimeError as e:
            if "autocast" in str(e).lower() or "scalar" in str(e).lower():
                print(f"Autocast error detected, trying fallback: {e}")
                return self._generate_fallback(prompt)
            logger.error(f"Runtime error: {e}")
            return "I encountered a technical issue. Please try rephrasing your question."
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return f"Sorry, I encountered an error: {str(e)[:100]}..."

    def _generate_fallback(self, prompt: str) -> str:
        """Greedy, shorter generation used after an autocast/scalar RuntimeError."""
        try:
            with torch.no_grad():
                input_ids = self.tokenizer.encode(
                    prompt, return_tensors="pt", max_length=1024, truncation=True
                ).to(self.device)
                prompt_length = input_ids.shape[-1]
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=100,
                    do_sample=False,  # Greedy decoding
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            response = self._decode_new_tokens(outputs[0], prompt_length)
            return response if response else "I apologize, but I'm having technical difficulties."
        except Exception as e:
            logger.error(f"Fallback generation failed: {e}")
            return "I'm experiencing technical difficulties. Please try again later."

    @staticmethod
    def _prepare_audio(audio: Any, target_rate: int = WHISPER_SAMPLING_RATE) -> Optional[tuple]:
        """Normalize recorded audio to a mono float32 waveform at *target_rate*.

        Accepts the ``(sample_rate, samples)`` tuple produced by
        ``gr.Audio(type="numpy")``. It also accepts a dict with ``"array"`` and
        ``"sampling_rate"`` keys, the form the original code expected.

        Returns:
            ``(waveform, target_rate)``, or ``None`` when *audio* has no samples.

        Raises:
            KeyError: dict input without ``"sampling_rate"``.
            ValueError: non-positive sampling rate.
        """
        if audio is None:
            return None
        if isinstance(audio, dict):
            if "array" not in audio:
                return None
            samples, rate = audio["array"], audio["sampling_rate"]
        elif isinstance(audio, (tuple, list)) and len(audio) == 2:
            rate, samples = audio
        else:
            return None
        if samples is None:
            return None

        waveform = np.asarray(samples)
        if waveform.size == 0:
            return None
        # Integer PCM (e.g. int16 from a microphone) -> floats in [-1.0, 1.0].
        if np.issubdtype(waveform.dtype, np.integer):
            scale = float(np.iinfo(waveform.dtype).max)
            waveform = waveform.astype(np.float32) / scale
        else:
            waveform = waveform.astype(np.float32)
        # Multi-channel audio presumably arrives as (samples, channels); average to mono.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=-1)

        rate = int(rate)
        if rate <= 0:
            raise ValueError(f"Invalid sampling rate: {rate}")
        if rate != target_rate:
            # Linear-interpolation resampling (no anti-alias filter). This keeps
            # the dependency footprint to numpy and is adequate for speech.
            target_len = max(1, int(round(waveform.shape[0] * target_rate / rate)))
            source_t = np.arange(waveform.shape[0]) / rate
            target_t = np.arange(target_len) / target_rate
            waveform = np.interp(target_t, source_t, waveform).astype(np.float32)
        return waveform, target_rate

    def transcribe_audio(self, audio: Any) -> str:
        """Transcribe speech with Whisper.

        Args:
            audio: ``(sample_rate, samples)`` from ``gr.Audio(type="numpy")``, or a
                dict with ``"array"`` / ``"sampling_rate"`` keys.

        Returns:
            The transcript, or one of the messages
            ``"Audio transcription unavailable - Whisper model not loaded"``,
            ``"No audio detected"``, ``"Error transcribing audio: ..."``.
        """
        if not self.whisper_loaded:
            return "Audio transcription unavailable - Whisper model not loaded"
        try:
            prepared = self._prepare_audio(audio, self.WHISPER_SAMPLING_RATE)
            if prepared is None:
                return "No audio detected"
            waveform, sampling_rate = prepared
            audio_input = self.whisper_processor(
                waveform,
                sampling_rate=sampling_rate,
                return_tensors="pt",
            )
            with torch.no_grad():
                result = self.whisper_model.generate(**audio_input)
            transcription = self.whisper_processor.batch_decode(result, skip_special_tokens=True)[0]
            return transcription.strip()
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return f"Error transcribing audio: {str(e)}"
# --------------------------
# Custom LangChain LLM Wrapper
# --------------------------
class CustomLlamaLLM(LLM):
    """LangChain ``LLM`` adapter that forwards prompts to ``ModelManager.generate_text``."""

    # LangChain LLMs are pydantic models. Instance attributes must be declared
    # fields; assigning an undeclared ``self.model_manager`` after
    # ``super().__init__()`` fails validation.
    model_manager: Any = None

    def __init__(self, model_manager: "ModelManager", **kwargs: Any):
        """Wrap *model_manager*; extra keyword arguments go to the LangChain base."""
        super().__init__(model_manager=model_manager, **kwargs)

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """Generate a completion for *prompt*, cut at the first *stop* sequence.

        Errors are logged and turned into an apology string rather than raised.
        """
        try:
            output = self.model_manager.generate_text(prompt)
        except Exception as e:
            logger.error(f"LLM call error: {e}")
            return f"I apologize, but I encountered an error: {str(e)}"
        # LangChain passes stop sequences and expects output to end before them.
        # Truncating at each one in turn leaves the prefix before the earliest match.
        for token in stop or []:
            if token:
                index = output.find(token)
                if index != -1:
                    output = output[:index]
        return output

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters LangChain uses to identify this LLM (must be a property)."""
        return {"model_name": CONFIG["model_name"]}

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM type (LangChain declares it a property)."""
        return "custom_llama_gptq"
# --------------------------
# Knowledge Base Manager
# --------------------------
class KnowledgeBaseManager:
    """Chroma vector store of indexed PDF chunks, queried through a RetrievalQA chain.

    ``kb_loaded`` is True only when embeddings, the vector store and the QA chain
    (which needs an LLM) were all created. Both public methods return status
    strings instead of raising.
    """

    def __init__(self, llm):
        self.kb_loaded = False
        if HuggingFaceEmbeddings is None or Chroma is None:
            print("β Knowledge base unavailable - missing dependencies")
            return
        try:
            print("Initializing knowledge base...")
            self.embedding_model = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])
            self.persist_dir = CONFIG["persist_dir"]
            os.makedirs(self.persist_dir, exist_ok=True)
            self.vector_db = Chroma(
                persist_directory=self.persist_dir,
                embedding_function=self.embedding_model,
            )
            self.retriever = self.vector_db.as_retriever()
            if llm:
                self.qa_chain = RetrievalQA.from_chain_type(
                    llm=llm,
                    retriever=self.retriever,
                    return_source_documents=True,
                )
                self.kb_loaded = True
                print("β Knowledge base initialized successfully")
            else:
                print("β Knowledge base unavailable - no LLM provided")
        except Exception as e:
            print(f"β Knowledge base initialization failed: {e}")
            self.kb_loaded = False

    @staticmethod
    def _resolve_path(pdf_file) -> str:
        """Return a filesystem path for a Gradio ``gr.File`` value.

        Newer Gradio passes a path string; older Gradio passes a tempfile-like
        object exposing ``.name``. Both are accepted.
        """
        if isinstance(pdf_file, (str, os.PathLike)):
            return os.fspath(pdf_file)
        return pdf_file.name

    def upload_and_index_pdf(self, pdf_file) -> str:
        """Split a PDF into chunks and add them to the vector store.

        Returns a human-readable status line (success, or the error text).
        """
        if not self.kb_loaded:
            return "Knowledge base unavailable - initialization failed"
        try:
            if pdf_file is None:
                return "No file uploaded."
            pdf_path = self._resolve_path(pdf_file)
            print(f"Processing PDF: {pdf_path}")
            pages = PyPDFLoader(pdf_path).load()
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CONFIG["chunk_size"],
                chunk_overlap=CONFIG["chunk_overlap"],
            )
            docs = text_splitter.split_documents(pages)
            if not docs:
                # Image-only/scanned PDFs produce no text; indexing an empty
                # batch is pointless (and may be rejected by the vector store).
                result = (
                    f"No extractable text found in: {os.path.basename(pdf_path)} "
                    "(scanned/image-only PDFs are not supported)"
                )
                print(result)
                return result
            self.vector_db.add_documents(docs)
            result = f"β Successfully indexed: {os.path.basename(pdf_path)} ({len(docs)} chunks)"
            print(result)
            return result
        except Exception as e:
            error_msg = f"β Error processing PDF: {str(e)}"
            print(error_msg)
            return error_msg

    def query_knowledge_base(self, query: str) -> str:
        """Answer *query* from the indexed documents via the RetrievalQA chain."""
        if not self.kb_loaded:
            return "Knowledge base unavailable. Please upload PDFs first or check initialization."
        try:
            if not query or not query.strip():
                return "Please provide a question."
            # Newer LangChain chains expose invoke(); older ones are called directly.
            # hasattr avoids treating an AttributeError raised *inside* the chain
            # as "old API" and silently running the chain a second time.
            if hasattr(self.qa_chain, "invoke"):
                result = self.qa_chain.invoke({"query": query})
            else:
                result = self.qa_chain({"query": query})
            return result["result"]
        except Exception as e:
            error_msg = f"Error querying knowledge base: {str(e)}"
            logger.error(error_msg)
            return error_msg
# --------------------------
# Chat Handler
# --------------------------
class ChatHandler:
    """Connect the Gradio chat callbacks to the model manager and text-to-speech."""

    # Exact non-transcript messages returned by ModelManager.transcribe_audio.
    _TRANSCRIPTION_FAILURES = (
        "Audio transcription unavailable - Whisper model not loaded",
        "No audio detected",
    )
    # Prefix of ModelManager.transcribe_audio's exception message.
    _TRANSCRIPTION_ERROR_PREFIX = "Error transcribing audio:"

    def __init__(self, model_manager: "ModelManager"):
        self.model_manager = model_manager
        self.tts = TTSManager()

    @classmethod
    def _is_transcription_failure(cls, user_text: str) -> bool:
        """True if *user_text* is a transcriber status message, not real speech.

        Matches the exact messages instead of substrings, so a user who says
        "error" or "unavailable" is still answered.
        """
        return (
            user_text in cls._TRANSCRIPTION_FAILURES
            or user_text.startswith(cls._TRANSCRIPTION_ERROR_PREFIX)
        )

    def handle_user_input(self, text: str) -> str:
        """Answer a text message with the chat model; returns the reply text."""
        if not self.model_manager.model_loaded:
            return "AI chat is currently unavailable. The language model failed to load."
        if not text or not text.strip():
            return "Please provide a message."
        # Llama-2-chat instruction format. No literal "<s>": the tokenizer adds
        # the BOS token itself, and a literal one would duplicate it.
        prompt = f"[INST] {text.strip()} [/INST]"
        return self.model_manager.generate_text(prompt)

    def handle_voice_chat(self, audio):
        """Transcribe *audio*, answer it, and speak the answer.

        Returns:
            ``(audio_path_or_None, transcript_text)`` for the two Gradio outputs.
        """
        try:
            if audio is None:
                return None, "No audio detected."
            if not self.model_manager.whisper_loaded:
                return None, "Voice transcription unavailable - Whisper model not loaded"
            # Transcribe user speech
            user_text = self.model_manager.transcribe_audio(audio)
            if self._is_transcription_failure(user_text):
                return None, user_text
            # Get bot response, then convert it to speech
            response = self.handle_user_input(user_text)
            audio_path = self.tts.speak_text_to_audio(response)
            combined_text = f"You said: {user_text}\n\nBot: {response}"
            return audio_path, combined_text
        except Exception as e:
            error_msg = f"Voice chat error: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    def get_response_and_speak(self, text: str):
        """Answer *text*; returns ``(reply_text, audio_path_or_None)``."""
        try:
            response = self.handle_user_input(text)
            audio_path = self.tts.speak_text_to_audio(response)
            return response, audio_path
        except Exception as e:
            error_msg = f"Response error: {str(e)}"
            logger.error(error_msg)
            return error_msg, None
# --------------------------
# Utility Functions
# --------------------------
def generate_test_questions_from_pdf(pdf_file):
    """Return a numbered, newline-separated list of generic test questions.

    The PDF content is not read. The argument only decides whether the list
    is shown (``None`` -> "No PDF uploaded.").
    """
    if pdf_file is None:
        return "No PDF uploaded."
    questions = (
        "What are the main services offered in this document?",
        "What is the refund and cancellation policy?",
        "How can customers contact support?",
        "What are the business operating hours?",
        "What payment methods are accepted?",
        "Are there any special offers or discounts mentioned?",
        "What are the terms and conditions?",
        "How does the booking process work?",
        "What documentation is required?",
        "Are there any age restrictions or limitations?",
        "What is the privacy policy?",
        "How are complaints handled?",
    )
    numbered = []
    for position, question in enumerate(questions, start=1):
        numbered.append(f"{position}. {question}")
    return "\n".join(numbered)
# --------------------------
# Initialize Application Components
# --------------------------
# Runs at import time: all models are loaded before the UI is built.
print("Initializing application components...")
model_manager = ModelManager()
chat_handler = ChatHandler(model_manager)
# The knowledge base only receives an LLM when the main model actually loaded.
kb_manager = KnowledgeBaseManager(
    model_manager.llm if model_manager.model_loaded else None
)

# Startup status summary, one line per component.
print("\n===== Initialization Summary =====")
for _component, _is_ready, _ready_label in (
    ("Main Model", model_manager.model_loaded, "β Loaded"),
    ("Whisper", model_manager.whisper_loaded, "β Loaded"),
    ("Knowledge Base", kb_manager.kb_loaded, "β Ready"),
):
    print(f"{_component}: {_ready_label if _is_ready else 'β Failed'}")
print("==================================\n")
# --------------------------
# Gradio Interface
# --------------------------
def create_gradio_interface():
    """Build the tabbed Gradio UI and connect each control to its handler.

    Reads the module-level ``model_manager``, ``chat_handler`` and
    ``kb_manager`` created at startup. Returns the ``gr.Blocks`` app
    without launching it.
    """
    # (loaded?, badge) pairs for the header's status line.
    feature_badges = [
        (model_manager.model_loaded, "π€ AI Chat"),
        (model_manager.whisper_loaded, "π€ Voice Recognition"),
        (kb_manager.kb_loaded, "π Knowledge Base"),
    ]
    active_badges = [badge for loaded, badge in feature_badges if loaded]
    status_text = " | ".join(active_badges) if active_badges else "β οΈ Limited functionality"
    header_markdown = (
        "\n"
        "# π€ LLaMA 2 Customer Support Chatbot\n"
        f"**Status**: {status_text}\n"
        "Welcome to your AI-powered customer support assistant! "
        "Choose from the available features below.\n"
    )

    with gr.Blocks(
        title="GenAI Customer Support",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden} .gradio-container {max-width: 1200px; margin: auto;}",
    ) as demo:
        gr.Markdown(header_markdown)

        # Voice chat: record -> transcribe -> answer -> speak.
        with gr.Tab("π Voice Chat"):
            gr.Markdown("### π€ Speak your question and get an audio response")
            with gr.Row():
                with gr.Column():
                    user_audio = gr.Audio(type="numpy", label="π€ Record your question")
                    submit_voice = gr.Button("π£οΈ Process Voice", variant="primary", size="lg")
                with gr.Column():
                    bot_audio = gr.Audio(label="π Bot Response", type="filepath")
                    bot_text = gr.Textbox(label="π Conversation Transcript", lines=8)
            submit_voice.click(
                fn=chat_handler.handle_voice_chat,
                inputs=user_audio,
                outputs=[bot_audio, bot_text],
            )

        # Text chat: Enter key and the Send button share one handler.
        with gr.Tab("π¬ Text Chat"):
            gr.Markdown("### π Type your question and get text + audio response")
            with gr.Row():
                with gr.Column(scale=4):
                    user_input = gr.Textbox(
                        placeholder="Ask about services, policies, booking, etc.",
                        label="π Your Question",
                        lines=3,
                    )
                with gr.Column(scale=1):
                    chat_submit = gr.Button("π¬ Send", variant="primary", size="lg")
            bot_response = gr.Textbox(label="π€ Bot Response", lines=6)
            bot_audio_tab = gr.Audio(label="π Spoken Response", type="filepath")
            for register in (user_input.submit, chat_submit.click):
                register(
                    fn=chat_handler.get_response_and_speak,
                    inputs=user_input,
                    outputs=[bot_response, bot_audio_tab],
                )

        # Knowledge-base Q&A over the indexed PDFs.
        with gr.Tab("π Knowledge Base"):
            gr.Markdown("### π Query your uploaded PDF documents")
            with gr.Row():
                with gr.Column(scale=4):
                    pdf_input = gr.Textbox(
                        placeholder="Ask questions about your uploaded PDFs...",
                        label="π Your Question",
                        lines=3,
                    )
                with gr.Column(scale=1):
                    pdf_submit = gr.Button("π Search", variant="primary", size="lg")
            pdf_response = gr.Textbox(label="π Knowledge Base Answer", lines=8)
            for register in (pdf_input.submit, pdf_submit.click):
                register(
                    fn=kb_manager.query_knowledge_base,
                    inputs=pdf_input,
                    outputs=pdf_response,
                )

        # PDF upload + indexing.
        with gr.Tab("π Upload Documents"):
            gr.Markdown("### π Add new PDF documents to your knowledge base")
            with gr.Column():
                pdf_file = gr.File(
                    label="π Select PDF File",
                    file_types=[".pdf"],
                    file_count="single",
                )
                upload_button = gr.Button("β¬οΈ Upload & Index PDF", variant="primary", size="lg")
                upload_result = gr.Textbox(label="π Upload Status", lines=4)
            upload_button.click(
                fn=kb_manager.upload_and_index_pdf,
                inputs=pdf_file,
                outputs=upload_result,
            )

        # Canned sample questions for testing.
        with gr.Tab("π Sample Questions"):
            gr.Markdown("### β¨ Generate sample questions for testing your knowledge base")
            with gr.Row():
                with gr.Column():
                    pdf_file_for_qs = gr.File(
                        label="π Upload PDF (optional)",
                        file_types=[".pdf"],
                    )
                    gen_qs_button = gr.Button("β¨ Generate Questions", variant="primary")
                with gr.Column():
                    generated_questions = gr.Textbox(
                        label="β Sample Questions",
                        lines=15,
                        placeholder="Generated questions will appear here...",
                    )
            gen_qs_button.click(
                fn=generate_test_questions_from_pdf,
                inputs=pdf_file_for_qs,
                outputs=generated_questions,
            )

    return demo
# --------------------------
# Launch Application
# --------------------------
if __name__ == "__main__":
    print("π Starting Gradio interface...")
    try:
        demo = create_gradio_interface()
        print("β Interface created successfully")
        print("π Launching application...")
        # Gradio 4 removed launch(enable_queue=...), and passing it raises
        # TypeError. Enable queuing with queue(), which exists in Gradio 3.x and
        # 4.x. Its default concurrency limits presumably serialize calls into
        # the single shared model (confirm for the installed Gradio version).
        demo.queue()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,  # Set to True if you want a public link
            show_error=True,
            max_threads=10,
        )
    except Exception as e:
        print(f"β Failed to launch application: {e}")
        logger.error(f"Launch failed: {e}")
        # Emergency fallback interface
        try:
            print("π Attempting emergency fallback...")
            fallback_demo = gr.Interface(
                fn=lambda x: "Application is in recovery mode. Please check the logs and restart.",
                inputs=gr.Textbox(label="Input", placeholder="Application in recovery mode"),
                outputs=gr.Textbox(label="Output"),
                title="Customer Support Bot - Recovery Mode",
            )
            fallback_demo.launch(server_name="0.0.0.0", server_port=7860)
        except Exception as fallback_error:
            # Exception (not a bare except) so Ctrl+C / SystemExit still propagate.
            print("π₯ Complete failure - unable to start any interface")
            logger.error(f"Fallback launch failed: {fallback_error}")