# Government Schemes RAG API — FastAPI backend with multilingual (13 Indian
# languages) query, translation, and on-demand text-to-speech support.
# Standard library
import base64
import json
import os
import sys
from io import BytesIO
from typing import List, Optional

# Third-party
import pandas as pd
import requests
import uvicorn
from deep_translator import GoogleTranslator
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from gtts import gTTS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_groq import ChatGroq
from pydantic import BaseModel
from tqdm import tqdm
| # Load environment variables | |
| load_dotenv() | |
| # Supported Indian Languages | |
| SUPPORTED_LANGUAGES = { | |
| 'en': 'English', | |
| 'hi': 'Hindi', | |
| 'te': 'Telugu', | |
| 'ta': 'Tamil', | |
| 'ml': 'Malayalam', | |
| 'kn': 'Kannada', | |
| 'bn': 'Bengali', | |
| 'mr': 'Marathi', | |
| 'gu': 'Gujarati', | |
| 'pa': 'Punjabi', | |
| 'ur': 'Urdu', | |
| 'or': 'Odia', | |
| 'as': 'Assamese' | |
| } | |
| class TranslationService: | |
| """Service for translating text between languages""" | |
| def translate_text(text: str, source_lang: str, target_lang: str) -> str: | |
| """ | |
| Translate text from source language to target language | |
| Args: | |
| text: Text to translate | |
| source_lang: Source language code (e.g., 'hi', 'te') | |
| target_lang: Target language code (e.g., 'en') | |
| Returns: | |
| Translated text | |
| """ | |
| if source_lang == target_lang: | |
| return text | |
| try: | |
| translator = GoogleTranslator(source=source_lang, target=target_lang) | |
| translated = translator.translate(text) | |
| return translated | |
| except Exception as e: | |
| print(f"Translation error ({source_lang} -> {target_lang}): {e}") | |
| return text # Return original text if translation fails | |
| def text_to_speech(text: str, lang_code: str) -> str: | |
| """ | |
| Convert text to speech and return base64 encoded audio | |
| Args: | |
| text: Text to convert to speech | |
| lang_code: Language code for TTS | |
| Returns: | |
| Base64 encoded MP3 audio | |
| """ | |
| try: | |
| # Create TTS | |
| tts = gTTS(text=text, lang=lang_code, slow=False) | |
| # Save to BytesIO buffer | |
| audio_buffer = BytesIO() | |
| tts.write_to_fp(audio_buffer) | |
| audio_buffer.seek(0) | |
| # Encode to base64 | |
| audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8') | |
| return audio_base64 | |
| except Exception as e: | |
| print(f"TTS error for language {lang_code}: {e}") | |
| return "" | |
| class GovernmentSchemesRAG: | |
| def __init__(self): | |
| self.groq_api_key = os.getenv("GROQ_API_KEY") | |
| if not self.groq_api_key or self.groq_api_key == "your_groq_api_key_here": | |
| raise ValueError("Please set your GROQ_API_KEY in the .env file. Get it from https://console.groq.com/") | |
| # Initialize embeddings (free HuggingFace model) | |
| print("Loading embedding model...") | |
| self.embeddings = HuggingFaceEmbeddings( | |
| model_name="sentence-transformers/all-MiniLM-L6-v2", | |
| model_kwargs={'device': 'cpu'} | |
| ) | |
| # Initialize LLM (free Groq API) | |
| self.llm = ChatGroq( | |
| temperature=0.3, | |
| model_name="llama-3.3-70b-versatile", # Latest free tier model | |
| groq_api_key=self.groq_api_key | |
| ) | |
| self.vectorstore = None | |
| self.qa_chain = None | |
| def load_vectorstore(self): | |
| """ | |
| Load existing vector database from disk | |
| """ | |
| print("Loading vector database from ./chroma_db/...") | |
| if not os.path.exists("./chroma_db"): | |
| raise FileNotFoundError( | |
| "Vector database not found! Please run 'python setup_db.py' first to create the database." | |
| ) | |
| self.vectorstore = Chroma( | |
| persist_directory="./chroma_db", | |
| embedding_function=self.embeddings | |
| ) | |
| print("โ Vector database loaded successfully!") | |
| return self.vectorstore | |
| def setup_qa_chain(self): | |
| """ | |
| Setup the QA chain with custom prompt | |
| """ | |
| # Custom prompt template for government schemes | |
| prompt_template = """You are a helpful assistant that provides information about Indian government schemes. | |
| Use the following pieces of context to answer the question at the end. | |
| If you don't find the exact answer in the context, provide the most relevant schemes based on the available information. | |
| Context: {context} | |
| Question: {question} | |
| Instructions: | |
| 1. Identify the key requirements from the question (age group, education level, state, category, etc.) | |
| 2. List all relevant schemes that match the criteria | |
| 3. For each scheme, provide: | |
| - Scheme name | |
| - Eligibility criteria | |
| - Benefits | |
| - How to apply (if mentioned) | |
| 4. If the user mentions a state, prioritize schemes for that state, but also include national schemes | |
| 5. Be specific and helpful in your response | |
| Answer:""" | |
| PROMPT = PromptTemplate( | |
| template=prompt_template, | |
| input_variables=["context", "question"] | |
| ) | |
| self.qa_chain = RetrievalQA.from_chain_type( | |
| llm=self.llm, | |
| chain_type="stuff", | |
| retriever=self.vectorstore.as_retriever( | |
| search_kwargs={"k": 5} # Retrieve top 5 relevant documents | |
| ), | |
| chain_type_kwargs={"prompt": PROMPT}, | |
| return_source_documents=True | |
| ) | |
| print("QA Chain setup complete!") | |
| def initialize(self): | |
| """ | |
| Initialize the RAG system by loading existing vector database | |
| """ | |
| # Load existing vectorstore from disk | |
| self.load_vectorstore() | |
| # Setup QA chain | |
| self.setup_qa_chain() | |
| print("\nโ RAG System initialized successfully!") | |
| def query(self, question, state=None): | |
| """ | |
| Query the RAG system | |
| """ | |
| if state and state != "All States": | |
| question = f"{question} (User is from {state})" | |
| result = self.qa_chain.invoke({"query": question}) | |
| # Format response | |
| response = result['result'] | |
| # Add source information | |
| source_info = "\n\n๐ Sources:\n" | |
| for i, doc in enumerate(result['source_documents'][:3], 1): | |
| source_info += f"{i}. {doc.page_content[:150]}...\n" | |
| return response + source_info | |
| # Initialize the RAG system | |
| print("๐ Initializing Government Schemes RAG System...") | |
| try: | |
| rag_system = GovernmentSchemesRAG() | |
| rag_system.initialize() | |
| except FileNotFoundError as e: | |
| print("\n" + "="*80) | |
| print("โ ERROR: Vector database not found!") | |
| print("="*80) | |
| print("\n๐ Please run the setup script first:") | |
| print(" python setup_db.py") | |
| print("\nThis will create the vector database from updated_data.csv") | |
| print("="*80) | |
| exit(1) | |
| except Exception as e: | |
| print(f"\nโ Error initializing RAG system: {e}") | |
| exit(1) | |
| # Initialize FastAPI app | |
| app = FastAPI( | |
| title="Government Schemes RAG API", | |
| description="API for querying Indian Government Schemes using RAG", | |
| version="1.0.0" | |
| ) | |
| # Add CORS middleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], # In production, replace with specific origins | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Request models | |
| class QueryRequest(BaseModel): | |
| question: str | |
| state: Optional[str] = None | |
| language: Optional[str] = "en" # User's selected language (default: English) | |
| class QueryResponse(BaseModel): | |
| answer: str | |
| sources: List[str] | |
| class AudioRequest(BaseModel): | |
| text: str | |
| language: Optional[str] = "en" | |
| class AudioResponse(BaseModel): | |
| audio: str # Base64 encoded audio | |
| # Indian states list | |
| INDIAN_STATES = [ | |
| "All States", | |
| "Andhra Pradesh", "Arunachal Pradesh", "Assam", "Bihar", "Chhattisgarh", | |
| "Goa", "Gujarat", "Haryana", "Himachal Pradesh", "Jharkhand", "Karnataka", | |
| "Kerala", "Madhya Pradesh", "Maharashtra", "Manipur", "Meghalaya", "Mizoram", | |
| "Nagaland", "Odisha", "Punjab", "Rajasthan", "Sikkim", "Tamil Nadu", | |
| "Telangana", "Tripura", "Uttar Pradesh", "Uttarakhand", "West Bengal", | |
| "Andaman and Nicobar Islands", "Chandigarh", "Dadra and Nagar Haveli and Daman and Diu", | |
| "Delhi", "Jammu and Kashmir", "Ladakh", "Lakshadweep", "Puducherry" | |
| ] | |
| # API Endpoints | |
| async def root(): | |
| """Root endpoint - API information""" | |
| return { | |
| "message": "Government Schemes RAG API with Multilingual Support", | |
| "version": "2.0.0", | |
| "supported_languages": SUPPORTED_LANGUAGES, | |
| "endpoints": { | |
| "POST /query": "Query government schemes with translation support", | |
| "POST /generate-audio": "Generate audio from text (on-demand)", | |
| "GET /states": "Get list of Indian states", | |
| "GET /languages": "Get list of supported languages", | |
| "GET /health": "Health check" | |
| } | |
| } | |
| async def health_check(): | |
| """Health check endpoint""" | |
| return { | |
| "status": "healthy", | |
| "rag_system": "initialized" if rag_system.qa_chain is not None else "not initialized" | |
| } | |
| async def get_languages(): | |
| """Get list of supported languages""" | |
| return { | |
| "languages": SUPPORTED_LANGUAGES | |
| } | |
| async def get_states(): | |
| """Get list of Indian states""" | |
| return { | |
| "states": INDIAN_STATES | |
| } | |
| async def query_schemes(request: QueryRequest): | |
| """ | |
| Query government schemes with multilingual support | |
| - **question**: The question about government schemes (in any supported language) | |
| - **state**: Optional state filter (default: None) | |
| - **language**: Language code for input/output (default: 'en') | |
| Flow: | |
| 1. Translate user question from selected language to English | |
| 2. Query RAG system (in English) | |
| 3. Translate answer back to user's selected language | |
| 4. Optionally generate audio response | |
| """ | |
| if not request.question.strip(): | |
| raise HTTPException(status_code=400, detail="Question cannot be empty") | |
| # Validate language code | |
| if request.language not in SUPPORTED_LANGUAGES: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Unsupported language. Supported: {list(SUPPORTED_LANGUAGES.keys())}" | |
| ) | |
| try: | |
| # Step 1: Translate input question to English (if not already in English) | |
| if request.language != 'en': | |
| print(f"Translating question from {SUPPORTED_LANGUAGES[request.language]} to English...") | |
| english_question = TranslationService.translate_text( | |
| request.question, | |
| source_lang=request.language, | |
| target_lang='en' | |
| ) | |
| print(f"Original: {request.question}") | |
| print(f"English: {english_question}") | |
| else: | |
| english_question = request.question | |
| # Step 2: Query the RAG system (in English) | |
| print(f"Querying RAG system with: {english_question}") | |
| result = rag_system.qa_chain.invoke({"query": english_question}) | |
| # Extract English answer | |
| answer_english = result['result'] | |
| # Step 3: Translate answer back to user's language (if not English) | |
| if request.language != 'en': | |
| print(f"Translating answer to {SUPPORTED_LANGUAGES[request.language]}...") | |
| final_answer = TranslationService.translate_text( | |
| answer_english, | |
| source_lang='en', | |
| target_lang=request.language | |
| ) | |
| else: | |
| final_answer = answer_english | |
| # Step 4: Extract sources | |
| sources = [] | |
| for doc in result['source_documents'][:3]: | |
| sources.append(doc.page_content[:200] + "...") | |
| # Note: Audio is NOT generated automatically | |
| # User must call /generate-audio endpoint when they click the speaker button | |
| return QueryResponse( | |
| answer=final_answer, | |
| sources=sources | |
| ) | |
| except Exception as e: | |
| print(f"Error processing query: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}") | |
| async def generate_audio(request: AudioRequest): | |
| """ | |
| Generate audio from text (called when user clicks speaker button) | |
| - **text**: The text to convert to speech | |
| - **language**: Language code for TTS (default: 'en') | |
| This endpoint should be called ONLY when user explicitly clicks | |
| the "Play Audio" or speaker button on the UI. | |
| """ | |
| if not request.text.strip(): | |
| raise HTTPException(status_code=400, detail="Text cannot be empty") | |
| # Validate language code | |
| if request.language not in SUPPORTED_LANGUAGES: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Unsupported language. Supported: {list(SUPPORTED_LANGUAGES.keys())}" | |
| ) | |
| try: | |
| print(f"Generating audio for language: {SUPPORTED_LANGUAGES[request.language]}") | |
| # Generate audio | |
| audio_base64 = TranslationService.text_to_speech(request.text, request.language) | |
| if not audio_base64: | |
| raise HTTPException(status_code=500, detail="Failed to generate audio") | |
| return AudioResponse(audio=audio_base64) | |
| except Exception as e: | |
| print(f"Error generating audio: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") | |
| # Request/Response models for schemes | |
| class SchemeFilterRequest(BaseModel): | |
| state: Optional[str] = None | |
| category: Optional[str] = None | |
| level: Optional[str] = None # Central, State | |
| search_text: Optional[str] = None | |
| page: Optional[int] = 1 | |
| page_size: Optional[int] = 10 | |
| class SchemeDetail(BaseModel): | |
| scheme_name: str | |
| slug: str | |
| details: str | |
| benefits: str | |
| eligibility: str | |
| application: str | |
| documents: str | |
| level: str | |
| scheme_category: str | |
| tags: str | |
| class SchemeListResponse(BaseModel): | |
| total: int | |
| page: int | |
| page_size: int | |
| total_pages: int | |
| schemes: List[SchemeDetail] | |
| async def get_all_schemes( | |
| page: int = 1, | |
| page_size: int = 10, | |
| state: Optional[str] = None, | |
| category: Optional[str] = None, | |
| level: Optional[str] = None, | |
| search: Optional[str] = None | |
| ): | |
| """ | |
| Get all schemes with optional filtering and pagination | |
| - **page**: Page number (default: 1) | |
| - **page_size**: Number of schemes per page (default: 10, max: 100) | |
| - **state**: Filter by state (optional) | |
| - **category**: Filter by category (optional) | |
| - **level**: Filter by level - Central/State (optional) | |
| - **search**: Search in scheme name, details, benefits (optional) | |
| Returns paginated list of schemes with filtering options. | |
| **Recommendation**: Use BACKEND filtering for better performance and consistency. | |
| """ | |
| try: | |
| # Load the CSV file | |
| df = pd.read_csv('updated_data.csv') | |
| # Apply filters | |
| filtered_df = df.copy() | |
| # Filter by state (case-insensitive partial match) | |
| if state and state != "All States": | |
| filtered_df = filtered_df[ | |
| filtered_df['details'].str.contains(state, case=False, na=False) | | |
| filtered_df['eligibility'].str.contains(state, case=False, na=False) | |
| ] | |
| # Filter by category | |
| if category: | |
| filtered_df = filtered_df[ | |
| filtered_df['schemeCategory'].str.contains(category, case=False, na=False) | |
| ] | |
| # Filter by level | |
| if level: | |
| filtered_df = filtered_df[ | |
| filtered_df['level'].str.lower() == level.lower() | |
| ] | |
| # Search across multiple fields | |
| if search: | |
| search_mask = ( | |
| filtered_df['scheme_name'].str.contains(search, case=False, na=False) | | |
| filtered_df['details'].str.contains(search, case=False, na=False) | | |
| filtered_df['benefits'].str.contains(search, case=False, na=False) | | |
| filtered_df['tags'].str.contains(search, case=False, na=False) | |
| ) | |
| filtered_df = filtered_df[search_mask] | |
| # Calculate pagination | |
| total_schemes = len(filtered_df) | |
| page_size = min(page_size, 100) # Max 100 per page | |
| total_pages = (total_schemes + page_size - 1) // page_size | |
| # Get paginated data | |
| start_idx = (page - 1) * page_size | |
| end_idx = start_idx + page_size | |
| paginated_df = filtered_df.iloc[start_idx:end_idx] | |
| # Convert to list of dicts | |
| schemes = [] | |
| for _, row in paginated_df.iterrows(): | |
| schemes.append({ | |
| "scheme_name": str(row.get('scheme_name', '')), | |
| "slug": str(row.get('slug', '')), | |
| "details": str(row.get('details', '')), | |
| "benefits": str(row.get('benefits', '')), | |
| "eligibility": str(row.get('eligibility', '')), | |
| "application": str(row.get('application', '')), | |
| "documents": str(row.get('documents', '')), | |
| "level": str(row.get('level', '')), | |
| "scheme_category": str(row.get('schemeCategory', '')), | |
| "tags": str(row.get('tags', '')) | |
| }) | |
| return { | |
| "total": total_schemes, | |
| "page": page, | |
| "page_size": page_size, | |
| "total_pages": total_pages, | |
| "schemes": schemes | |
| } | |
| except Exception as e: | |
| print(f"Error fetching schemes: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error fetching schemes: {str(e)}") | |
| async def get_scheme_by_slug(slug: str, language: Optional[str] = "en"): | |
| """ | |
| Get detailed information about a specific scheme by slug | |
| - **slug**: The unique slug identifier of the scheme | |
| - **language**: Language code for translation (default: 'en') | |
| Returns detailed scheme information in the requested language. | |
| """ | |
| try: | |
| # Load the CSV file | |
| df = pd.read_csv('updated_data.csv') | |
| # Find scheme by slug | |
| scheme_row = df[df['slug'] == slug] | |
| if scheme_row.empty: | |
| raise HTTPException(status_code=404, detail="Scheme not found") | |
| scheme_row = scheme_row.iloc[0] | |
| # Prepare scheme details | |
| scheme_data = { | |
| "scheme_name": str(scheme_row.get('scheme_name', '')), | |
| "slug": str(scheme_row.get('slug', '')), | |
| "details": str(scheme_row.get('details', '')), | |
| "benefits": str(scheme_row.get('benefits', '')), | |
| "eligibility": str(scheme_row.get('eligibility', '')), | |
| "application": str(scheme_row.get('application', '')), | |
| "documents": str(scheme_row.get('documents', '')), | |
| "level": str(scheme_row.get('level', '')), | |
| "scheme_category": str(scheme_row.get('schemeCategory', '')), | |
| "tags": str(scheme_row.get('tags', '')) | |
| } | |
| # Translate if needed | |
| if language != 'en' and language in SUPPORTED_LANGUAGES: | |
| print(f"Translating scheme details to {SUPPORTED_LANGUAGES[language]}...") | |
| scheme_data['scheme_name'] = TranslationService.translate_text( | |
| scheme_data['scheme_name'], 'en', language | |
| ) | |
| scheme_data['details'] = TranslationService.translate_text( | |
| scheme_data['details'], 'en', language | |
| ) | |
| scheme_data['benefits'] = TranslationService.translate_text( | |
| scheme_data['benefits'], 'en', language | |
| ) | |
| scheme_data['eligibility'] = TranslationService.translate_text( | |
| scheme_data['eligibility'], 'en', language | |
| ) | |
| return scheme_data | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| print(f"Error fetching scheme: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error fetching scheme: {str(e)}") | |
| async def get_scheme_categories(): | |
| """ | |
| Get all unique scheme categories available | |
| Returns list of all unique categories from the database. | |
| """ | |
| try: | |
| df = pd.read_csv('updated_data.csv') | |
| # Get unique categories (may contain comma-separated values) | |
| all_categories = set() | |
| for cat in df['schemeCategory'].dropna(): | |
| # Split by comma and strip whitespace | |
| categories = [c.strip() for c in str(cat).split(',')] | |
| all_categories.update(categories) | |
| return { | |
| "categories": sorted(list(all_categories)) | |
| } | |
| except Exception as e: | |
| print(f"Error fetching categories: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error fetching categories: {str(e)}") | |
| async def get_scheme_statistics(): | |
| """ | |
| Get statistics about schemes in the database | |
| Returns: | |
| - Total number of schemes | |
| - Count by level (Central/State) | |
| - Count by categories | |
| - Count by states | |
| """ | |
| try: | |
| df = pd.read_csv('updated_data.csv') | |
| # Total schemes | |
| total = len(df) | |
| # Count by level | |
| level_counts = df['level'].value_counts().to_dict() | |
| # Count by category | |
| category_counts = {} | |
| for cat in df['schemeCategory'].dropna(): | |
| categories = [c.strip() for c in str(cat).split(',')] | |
| for c in categories: | |
| category_counts[c] = category_counts.get(c, 0) + 1 | |
| return { | |
| "total_schemes": total, | |
| "by_level": level_counts, | |
| "by_category": dict(sorted(category_counts.items(), key=lambda x: x[1], reverse=True)[:10]), | |
| "total_categories": len(category_counts) | |
| } | |
| except Exception as e: | |
| print(f"Error fetching statistics: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error fetching statistics: {str(e)}") | |
| # Launch the app | |
| if __name__ == "__main__": | |
| print("\n๐ Starting FastAPI server...") | |
| print("๐ API Documentation: http://127.0.0.1:7860/docs") | |
| print("๐ Alternative docs: http://127.0.0.1:7860/redoc") | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |