# Hugging Face Spaces status header (extraction artifact): Spaces: Running
| from fastapi import FastAPI, Request, Depends, HTTPException, Header, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| from helpmate_ai import get_system_msg, retreive_results, rerank_with_cross_encoder, generate_response, intro_message | |
| import google.generativeai as genai | |
| import os | |
| from dotenv import load_dotenv | |
| import re | |
| import speech_recognition as sr | |
| from io import BytesIO | |
| import wave | |
| import google.generativeai as genai | |
# Load environment variables from a local .env file (GEMINI_API_KEY, API_KEY).
load_dotenv()
# Configure the Gemini client with the key from the environment.
gemini_api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=gemini_api_key)
# Shared secret checked by verify_api_key; None if the env var is unset — TODO confirm it is set in production.
API_KEY = os.getenv("API_KEY")
# FastAPI application instance (routes are registered against this object).
app = FastAPI()
| # # Enable CORS | |
| # app.add_middleware( | |
| # CORSMiddleware, | |
| # allow_origins=["*"], | |
| # allow_credentials=True, | |
| # allow_methods=["*"], | |
| # allow_headers=["*"], | |
| # ) | |
# Pydantic models for request/response validation
class Message(BaseModel):
    """A single chat turn exchanged between the user and the bot."""

    # Sender of the turn; this file uses "user" and "bot".
    role: str
    # Plain-text message body.
    content: str
class ChatRequest(BaseModel):
    """Request body for the chat endpoint: the user's latest message."""

    message: str
class ChatResponse(BaseModel):
    """Response payload: the bot's reply plus the full conversation so far."""

    response: str
    conversation: List[Message]
class Report(BaseModel):
    """Feedback report submitted by a client about a bot response."""

    # The bot response being reported.
    response: str
    # The user message that produced it.
    message: str
    # Client-supplied timestamp string — format not enforced here; TODO confirm.
    timestamp: str
# Module-level chat history shared by ALL requests (single-process, not per-user).
conversation_bot = []
# System prompt text produced by the helpmate_ai helper.
conversation = get_system_msg()
# Gemini model primed with the system instruction.
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=conversation)
# Speech recognizer instance reused across voice requests.
recognizer = sr.Recognizer()
# FastAPI dependency: reject any request whose X-API-Key header is wrong.
async def verify_api_key(x_api_key: str = Header(...)):
    """Allow the request through only when the X-API-Key header equals API_KEY."""
    if x_api_key == API_KEY:
        return
    raise HTTPException(status_code=403, detail="Unauthorized")
def get_gemini_completions(conversation: str) -> str:
    """Send *conversation* to the shared Gemini model and return its text reply."""
    return model.generate_content(conversation).text
| # @app.get("/secure-endpoint", dependencies=[Depends(verify_api_key)]) | |
| # async def secure_endpoint(): | |
| # return {"message": "Access granted!"} | |
# Initialize conversation endpoint
# NOTE(review): no @app route decorator is visible here — confirm whether one
# was lost during extraction before deploying.
async def initialize_chat():
    """Reset the shared history to just the canned intro message and return it."""
    global conversation_bot
    greeting = Message(role="bot", content=intro_message)
    conversation_bot = [greeting]
    return ChatResponse(response=intro_message, conversation=conversation_bot)
# Chat endpoint
# NOTE(review): no @app route decorator is visible here — confirm one exists upstream.
async def chat(request: ChatRequest):
    """Answer the user's message via the RAG pipeline and record both turns.

    Appends the user turn and the bot turn to the module-level history, then
    returns the reply together with the full transcript.
    """
    global conversation_bot

    # Record the incoming user turn.
    conversation_bot.append(Message(role="user", content=request.message))

    # Retrieve -> rerank -> build prompt -> generate.
    retrieved = retreive_results(request.message)
    reranked = rerank_with_cross_encoder(request.message, retrieved)
    prompt = generate_response(request.message, reranked)
    answer = get_gemini_completions(prompt)

    # Record the bot turn and return the updated transcript.
    conversation_bot.append(Message(role="bot", content=answer))
    return ChatResponse(response=answer, conversation=conversation_bot)
# Voice processing endpoint
# NOTE(review): no @app route decorator is visible here — confirm one exists upstream.
async def process_voice(
    audio_file: UploadFile = File(...),
    # BUG FIX: the original `dependencies=[Depends(verify_api_key)]` written as a
    # *function parameter* is not a FastAPI dependency — it was an ordinary
    # parameter whose default list was never evaluated, so the API-key check
    # never ran. Declaring the dependency as a parameter default enforces it.
    _auth: None = Depends(verify_api_key),
):
    """Transcribe an uploaded audio file and answer it through the RAG pipeline.

    Returns a dict with the transcribed text and the model's response, or a
    best-effort ``{"error": ...}`` dict on any failure (never raises).
    """
    try:
        # Read the whole upload into memory and wrap it for speech_recognition.
        contents = await audio_file.read()
        audio_data = BytesIO(contents)

        # speech_recognition expects a file-like audio source (WAV/AIFF/FLAC).
        with sr.AudioFile(audio_data) as source:
            audio = recognizer.record(source)

        # Transcribe via Google's speech web API.
        text = recognizer.recognize_google(audio)

        # Same retrieve -> rerank -> generate pipeline as the text chat endpoint.
        results_df = retreive_results(text)
        top_docs = rerank_with_cross_encoder(text, results_df)
        messages = generate_response(text, top_docs)
        response_assistant = get_gemini_completions(messages)

        return {
            "transcribed_text": text,
            "response": response_assistant,
        }
    except Exception as e:
        # Deliberate best-effort: surface the error in the payload, not as a 500.
        return {"error": f"Error processing voice input: {str(e)}"}
# Feedback endpoint
# NOTE(review): no @app route decorator is visible here — confirm one exists upstream.
async def handle_feedback(
    request: Report,
    # BUG FIX: `dependencies=[Depends(verify_api_key)]` as a function parameter
    # was inert — it never enforced the API key. A parameter whose default is
    # Depends(...) actually runs the check on every request.
    _auth: None = Depends(verify_api_key),
):
    """Accept a feedback Report from the client.

    Storage is not implemented yet; the report is currently discarded.
    TODO: persist `request` (response, message, timestamp) to a database.
    """
    return {"status": "success"}
# Reset conversation endpoint
# NOTE(review): no @app route decorator is visible here — confirm one exists upstream.
async def reset_conversation():
    """Drop the chat history and start over with the canned intro message."""
    global conversation_bot, conversation
    conversation_bot = [Message(role="bot", content=intro_message)]
    return {"status": "success", "message": "Conversation reset"}
# Run a development server when executed directly (not under an external ASGI runner).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)