import os
import io
import tempfile
from pathlib import Path
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.requests import Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import google.generativeai as genai
from gtts import gTTS
from deep_translator import GoogleTranslator

from app import rag

load_dotenv()

asr_model = None
model_loaded = False
model_loading = False
conversation_history = []

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)

LOCAL_MODEL_PATH = Path(__file__).resolve().parent.parent / "final_model"
HUGGINGFACE_MODEL_ID = "seniruk/whisper-small-si"
IS_HF_SPACE = bool(os.getenv("SPACE_ID"))
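# Note: Hugging Face Spaces set the SPACE_ID environment variable automatically,
# so IS_HF_SPACE distinguishes Space deployments (which skip the heavy startup
# preloads in lifespan() below) from local runs.
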
def load_asr_model():
    """Load the ASR model - tries local model first, falls back to Hugging Face."""
    global asr_model, model_loaded, model_loading
    if model_loaded:
        return asr_model
    model_loading = True
    try:
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        import torch
    except Exception as import_error:
        model_loading = False
        raise RuntimeError(
            "ASR dependencies are not installed. Install transformers and torch to enable speech input."
        ) from import_error

    processor = None
    model = None
    model_source = None
    if LOCAL_MODEL_PATH.exists():
        print("=" * 50)
        print(f"Loading ASR model from local path: {LOCAL_MODEL_PATH}")
        print("=" * 50)
        try:
            processor = WhisperProcessor.from_pretrained(str(LOCAL_MODEL_PATH))
            model = WhisperForConditionalGeneration.from_pretrained(
                str(LOCAL_MODEL_PATH), torch_dtype=torch.float32
            )
            model_source = "local"
            print("Local model loaded successfully.")
        except Exception as e:
            print(f"Failed to load local model: {str(e)}")
            print("Falling back to Hugging Face model...")
            processor = None
            model = None
    else:
        print(f"Local model not found at: {LOCAL_MODEL_PATH}")
        print("Falling back to Hugging Face model...")

    if model is None:
        print("=" * 50)
        print(f"Loading ASR model from Hugging Face: {HUGGINGFACE_MODEL_ID}")
        print("This may take a minute on first run...")
        print("=" * 50)
        processor = WhisperProcessor.from_pretrained(HUGGINGFACE_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(
            HUGGINGFACE_MODEL_ID, torch_dtype=torch.float32
        )
        model_source = "huggingface"
        print("Hugging Face model loaded successfully.")

    model.eval()
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        model = model.half()
        model = model.to("cuda")
        print("Using GPU with float16 for faster inference.")
    else:
        print("Running on CPU.")

    asr_model = {
        "processor": processor,
        "model": model,
        "device": device,
        "source": model_source,
    }
    model_loaded = True
    model_loading = False
    print(f"Model ready. (Source: {model_source})")
    return asr_model

def transcribe_audio(audio_path: str) -> str:
    """Transcribe audio file to text - optimized."""
    global asr_model
    try:
        import soundfile as sf
        import numpy as np
        from scipy import signal
        import torch
    except Exception as import_error:
        raise RuntimeError(
            "Audio dependencies are not installed. Install soundfile, numpy, and scipy."
        ) from import_error

    processor = asr_model["processor"]
    model = asr_model["model"]
    device = asr_model["device"]

    audio_array, sample_rate = sf.read(audio_path)
    # Downmix multi-channel audio to mono and resample to the 16 kHz rate
    # Whisper expects.
    if len(audio_array.shape) > 1:
        audio_array = audio_array.mean(axis=1)
    if sample_rate != 16000:
        num_samples = int(len(audio_array) * 16000 / sample_rate)
        audio_array = signal.resample(audio_array, num_samples)
    audio_array = audio_array.astype(np.float32)

    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features
    if device == "cuda":
        inputs = inputs.half().to("cuda")
    # Greedy decoding (num_beams=1, no sampling) keeps latency low.
    with torch.no_grad():
        predicted_ids = model.generate(
            inputs,
            max_length=225,
            num_beams=1,
            do_sample=False,
            use_cache=True,
        )
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

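# Illustrative usage sketch (the model must be loaded first; the speech
# endpoint below does this on demand; the file path is hypothetical):
#   load_asr_model()
#   text = transcribe_audio("recording.wav")
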
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model at startup."""
    print("\nStarting Sinhala Chatbot Server...")
    if IS_HF_SPACE:
        print("Hugging Face Space detected. Skipping heavy startup preloads.")
    else:
        load_asr_model()
        loaded = rag.load_vector_store()
        if not loaded:
            rag.rebuild_vector_store_from_pdfs()
    print("Server ready.\n")
    yield
    print("\nShutting down...")

app = FastAPI(title="Sinhala Chatbot", version="1.0.0", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

BASE_DIR = Path(__file__).resolve().parent
app.mount("/static", StaticFiles(directory=BASE_DIR / "static"), name="static")
templates = Jinja2Templates(directory=BASE_DIR / "templates")
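# The static mount and Jinja2 templates expect static/ and templates/
# directories next to this file (i.e. inside the app package, given the
# `from app import rag` import above).
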
# Route paths on the decorators below are assumed conventions; adjust them to
# match whatever endpoints the frontend actually calls.
@app.get("/")
async def home(request: Request):
    """Render the main chatbot interface."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/api/model-status")
async def get_model_status():
    """Check if ASR model is loaded."""
    source = asr_model.get("source", None) if asr_model else None
    return JSONResponse({"loaded": model_loaded, "loading": model_loading, "source": source})

@app.post("/api/speech-to-text")
async def speech_to_text(audio: UploadFile = File(...)):
    """Convert speech to text using Whisper ASR model."""
    if not model_loaded:
        try:
            load_asr_model()
        except Exception as load_error:
            raise HTTPException(status_code=503, detail=str(load_error)) from load_error
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_file.write(audio_bytes)
            tmp_path = tmp_file.name
        try:
            transcription = transcribe_audio(tmp_path)
            return JSONResponse({"success": True, "text": transcription})
        finally:
            os.unlink(tmp_path)
    except Exception as e:
        print(f"ASR Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Speech recognition failed: {str(e)}")

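# Example client call (illustrative sketch; assumes a local server on port
# 8000 and the assumed route path above):
#   import requests
#   with open("clip.wav", "rb") as f:
#       r = requests.post("http://localhost:8000/api/speech-to-text",
#                         files={"audio": ("clip.wav", f, "audio/wav")})
#   print(r.json()["text"])
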
@app.post("/api/chat")
async def chat(request: Request):
    """
    Send text to RAG system (retrieves from documents first, then falls back to Gemini/HF).
    Automatically translates non-English questions to English before RAG processing.
    """
    global conversation_history
    try:
        data = await request.json()
        user_message = data.get("message", "")
        if not user_message:
            raise HTTPException(status_code=400, detail="Message is required")
        english_question = user_message
        try:
            translator = GoogleTranslator(source="auto", target="en")
            english_question = translator.translate(user_message)
            print(f"Original Question: {user_message}")
            print(f"English Question: {english_question}")
        except Exception as trans_error:
            print(f"Translation failed, using original: {trans_error}")
            english_question = user_message
        rag_result = rag.rag_answer(english_question)
        assistant_message = rag_result.get("answer", "")
        conversation_history.append({
            "role": "user",
            "parts": [user_message],
        })
        conversation_history.append({
            "role": "model",
            "parts": [assistant_message],
        })
        # Keep only the most recent 20 messages (10 exchanges).
        if len(conversation_history) > 20:
            conversation_history = conversation_history[-20:]
        return JSONResponse({
            "success": True,
            "response": assistant_message,
            "source": rag_result.get("source", "none"),
            "context_found": rag_result.get("context_found", False),
        })
    except HTTPException:
        # Let deliberate 4xx responses pass through instead of becoming 500s.
        raise
    except Exception as e:
        print(f"Chat Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}")

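# Example client call (illustrative sketch; route path assumed):
#   import requests
#   r = requests.post("http://localhost:8000/api/chat",
#                     json={"message": "ආයුබෝවන්"})
#   print(r.json()["response"], r.json()["source"])
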
@app.post("/api/text-to-speech")
async def text_to_speech(request: Request):
    """Convert text to speech using Google TTS."""
    try:
        data = await request.json()
        text = data.get("text", "")
        lang = data.get("lang", "si")
        if not text:
            raise HTTPException(status_code=400, detail="Text is required")
        tts = gTTS(text=text, lang=lang, slow=False)
        audio_buffer = io.BytesIO()
        tts.write_to_fp(audio_buffer)
        audio_buffer.seek(0)
        return StreamingResponse(
            audio_buffer,
            media_type="audio/mpeg",
            headers={"Content-Disposition": "inline; filename=speech.mp3"},
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"TTS Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Text-to-speech failed: {str(e)}")

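# Example client call (illustrative sketch; route path assumed) that saves
# the streamed MP3 to disk:
#   import requests
#   r = requests.post("http://localhost:8000/api/text-to-speech",
#                     json={"text": "ආයුබෝවන්", "lang": "si"})
#   with open("speech.mp3", "wb") as f:
#       f.write(r.content)
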
@app.post("/api/clear-history")
async def clear_history():
    """Clear conversation history."""
    global conversation_history
    conversation_history = []
    return JSONResponse({"success": True, "message": "Conversation history cleared"})


@app.get("/api/health")
async def health_check():
    """Health check endpoint."""
    return JSONResponse({
        "status": "healthy",
        "gemini_configured": GEMINI_API_KEY is not None,
    })

@app.post("/api/translate")
async def translate_to_english(request: Request):
    """Translate Sinhala/mixed language question to full English using Google Translate."""
    # Initialise early so the fallback response below is safe even if
    # request.json() itself raises.
    question = ""
    try:
        data = await request.json()
        question = data.get("question", "")
        if not question:
            raise HTTPException(status_code=400, detail="Question is required")
        translator = GoogleTranslator(source="auto", target="en")
        english_question = translator.translate(question)
        print(f"Original: {question}")
        print(f"Translated: {english_question}")
        return JSONResponse({"success": True, "english_question": english_question, "translated": True})
    except HTTPException:
        raise
    except Exception as e:
        print(f"Translation Error: {str(e)}")
        return JSONResponse({
            "success": False,
            "english_question": question,
            "translated": False,
            "error": str(e),
        })

@app.post("/api/rag/upload")
async def upload_pdf(file: UploadFile = File(...)):
    """Upload a PDF file for RAG processing."""
    if not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")
    try:
        rag_data_dir = Path(__file__).resolve().parent.parent / "rag_data"
        rag_data_dir.mkdir(parents=True, exist_ok=True)
        pdf_path = rag_data_dir / file.filename
        content = await file.read()
        with open(pdf_path, "wb") as f:
            f.write(content)
        chunks = rag.load_and_process_pdf(str(pdf_path))
        if not chunks:
            raise HTTPException(status_code=400, detail="Could not extract text from PDF")
        success = rag.create_vector_store(chunks)
        if success:
            status = rag.get_rag_status()
            return JSONResponse({
                "success": True,
                "message": f"PDF '{file.filename}' uploaded and processed successfully",
                "chunks_created": len(chunks),
                "total_documents": status.get("documents_count", 0),
            })
        raise HTTPException(status_code=500, detail="Failed to create vector store")
    except HTTPException:
        # Preserve deliberate 400/500 responses raised above.
        raise
    except Exception as e:
        print(f"RAG Upload Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to process PDF: {str(e)}")

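# Example client call (illustrative sketch; route path and file name assumed):
#   import requests
#   with open("handbook.pdf", "rb") as f:
#       r = requests.post("http://localhost:8000/api/rag/upload",
#                         files={"file": ("handbook.pdf", f, "application/pdf")})
#   print(r.json()["chunks_created"])
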
@app.post("/api/rag/ask")
async def rag_ask(request: Request):
    """Ask a question using RAG - first checks database, then falls back to Gemini/HF."""
    try:
        data = await request.json()
        question = data.get("question", "")
        response_lang = data.get("response_lang", "en")
        print(f"Question: {question}")
        print(f"Response language: {response_lang}")
        if not question:
            raise HTTPException(status_code=400, detail="Question is required")
        result = rag.rag_answer(question)
        answer = result["answer"]
        print(f"Original answer length: {len(answer) if answer else 0}")
        if response_lang == "si-en" and answer:
            print("Translating to Sinhala+English...")
            try:
                translator = GoogleTranslator(source="en", target="si")
                sinhala_answer = translator.translate(answer)
                answer = f"**Sinhala:**\n{sinhala_answer}\n\n---\n\n**English:**\n{answer}"
                print("Translated successfully.")
            except Exception as trans_err:
                print(f"Translation to Sinhala failed: {trans_err}")
                answer = f"Translation failed: {trans_err}\n\n**English:** {answer}"
        return JSONResponse({
            "success": True,
            "question": question,
            "answer": answer,
            "source": result["source"],
            "context_found": result["context_found"],
            "relevance_score": result["relevance_score"],
            "response_lang": response_lang,
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"RAG Ask Error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}")

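# Example client call (illustrative sketch; route path assumed).
# response_lang="si-en" requests the bilingual Sinhala + English answer:
#   import requests
#   r = requests.post("http://localhost:8000/api/rag/ask",
#                     json={"question": "What does the document cover?",
#                           "response_lang": "si-en"})
#   print(r.json()["answer"])
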
@app.get("/api/rag/status")
async def rag_status():
    """Get RAG system status."""
    return JSONResponse(rag.get_rag_status())


@app.post("/api/rag/clear")
async def clear_rag():
    """Clear all RAG data."""
    rag.clear_rag_data()
    return JSONResponse({"success": True, "message": "RAG data cleared"})


@app.post("/api/rag/rebuild")
async def rebuild_rag():
    """Rebuild vector store from all PDFs in rag_data directory."""
    success = rag.rebuild_vector_store_from_pdfs()
    if not success:
        return JSONResponse(
            {
                "success": False,
                "message": "No valid PDFs found to rebuild vector store.",
            }
        )
    return JSONResponse({"success": True, "message": "RAG vector store rebuilt successfully."})

if __name__ == "__main__":
    import uvicorn
    # reload=True requires an import string rather than an app object;
    # "app.main:app" assumes this module is app/main.py.
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)