| |
| from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks |
| from fastapi.middleware.cors import CORSMiddleware |
| import gradio as gr |
| from services.chat_service import ChatService |
| from services.model_service import ModelService |
| from services.pdf_service import PDFService |
| from services.data_service import DataService |
| from services.faq_service import FAQService |
| from auth.auth_handler import get_api_key |
| from models.base_models import UserInput, SearchQuery |
| import logging |
| import asyncio |
|
|
| |
# Configure root logging at import time: records at INFO and above go both to
# ``chatbot.log`` (relative to the current working directory) and to stderr.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII log text (e.g. exception messages that
        # may echo user input) is written independently of the platform's
        # locale encoding instead of failing with a logging error.
        logging.FileHandler('chatbot.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
|
|
| |
# ASGI application that serves the REST endpoints defined below.
app = FastAPI(
    title="Bofrost Chat API",
    version="2.0.0",
)

# CORS policy: any origin, method and header is accepted, with credentials.
# NOTE(review): Starlette documents that credentials require explicit (non-"*")
# origins/methods/headers; if cookie-based auth is ever added, replace the
# wildcards with an explicit allow-list. Confirm the intended deployment policy.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
| |
# Module-level service instances shared by the REST endpoints and the Gradio UI.
# ModelService is created first because every other service is handed it;
# ChatService is created last because it is composed from all the others.
model_service: ModelService = ModelService()
data_service: DataService = DataService(model_service)
pdf_service: PDFService = PDFService(model_service)
faq_service: FAQService = FAQService(model_service)
chat_service: ChatService = ChatService(
    model_service,
    data_service,
    pdf_service,
    faq_service,
)
|
|
| |
@app.post("/api/chat")
async def chat_endpoint(
    background_tasks: BackgroundTasks,
    user_input: UserInput,
    api_key: str = Depends(get_api_key)
):
    """Answer one chat turn through ``chat_service``.

    Args:
        background_tasks: Injected by FastAPI; not used by this handler.
        user_input: Request body; ``user_input.user_input`` is the new message
            and ``user_input.chat_history`` the prior conversation.
        api_key: Value produced by the ``get_api_key`` dependency (presumably
            API-key authentication — see ``auth.auth_handler``); unused here.

    Returns:
        dict with ``status``, ``response``, the updated ``chat_history`` and
        ``search_results``, as returned by ``chat_service.chat``.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other exception is logged with its traceback and converted into a
            500 whose ``detail`` is the exception message.
    """
    try:
        response, updated_history, search_results = await chat_service.chat(
            user_input.user_input,
            user_input.chat_history,
        )
    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "response": response,
        "chat_history": updated_history,
        "search_results": search_results,
    }
|
|
@app.post("/api/search")
async def search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key)
):
    """Run a search through ``data_service`` and return its results.

    Args:
        query: Request body; ``query.query`` is the search text and
            ``query.top_k`` is forwarded to ``data_service.search``.
        api_key: Value produced by the ``get_api_key`` dependency (presumably
            API-key authentication — see ``auth.auth_handler``); unused here.

    Returns:
        ``{"results": ...}`` wrapping whatever ``data_service.search`` returns.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other exception is logged with its traceback and converted into a
            500 whose ``detail`` is the exception message.
    """
    try:
        results = await data_service.search(query.query, query.top_k)
    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error in search endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"results": results}
|
|
@app.post("/api/faq/search")
async def faq_search_endpoint(
    query: SearchQuery,
    api_key: str = Depends(get_api_key)
):
    """Search the FAQ entries through ``faq_service`` and return the results.

    Args:
        query: Request body; ``query.query`` is the search text and
            ``query.top_k`` is forwarded to ``faq_service.search_faqs``.
        api_key: Value produced by the ``get_api_key`` dependency (presumably
            API-key authentication — see ``auth.auth_handler``); unused here.

    Returns:
        ``{"results": ...}`` wrapping whatever ``faq_service.search_faqs`` returns.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other exception is logged with its traceback and converted into a
            500 whose ``detail`` is the exception message.
    """
    try:
        results = await faq_service.search_faqs(query.query, query.top_k)
    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error in FAQ search endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"results": results}
|
|
| |
def create_gradio_interface():
    """Build the Gradio Blocks UI for the chat assistant.

    Each submitted message is passed to the module-level ``chat_service``.
    The conversation is shown in a ``gr.Chatbot``, and the latest search
    results appear in a collapsible JSON panel. The raw history returned by
    the service is kept in a ``gr.State`` so that it can be passed back on
    the next turn.

    Returns:
        gr.Blocks: the configured demo with queueing enabled, ready for
        ``launch()``.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🦙 * Chat Assistant\nFragen Sie nach Produkten, Rezepten und mehr!")

        with gr.Row():
            with gr.Column(scale=4):
                chat_display = gr.Chatbot(label="Chat-Verlauf", height=400)
                user_input = gr.Textbox(
                    label="Ihre Nachricht",
                    placeholder="Stellen Sie Ihre Frage...",
                    lines=2,
                )

            with gr.Column(scale=2):
                with gr.Accordion("Zusätzliche Informationen", open=False):
                    product_info = gr.JSON(label="Produktdetails")

        with gr.Row():
            submit_btn = gr.Button("Senden", variant="primary")
            clear_btn = gr.Button("Chat löschen")

        # Raw history exactly as returned by chat_service, one copy per session.
        chat_history = gr.State([])

        def to_chatbot_pairs(history):
            """Convert service history entries into (user, bot) pairs for gr.Chatbot.

            Each entry is converted on its own, so an empty history gives ``[]``
            and a history that mixes dict- and pair-shaped entries still works.
            Dict entries must have ``user_input`` and ``response`` keys; tuple or
            list entries are read as ``(user, bot, ...)``.
            """
            pairs = []
            for item in history:
                if isinstance(item, dict):
                    pairs.append((item['user_input'], item['response']))
                elif isinstance(item, (tuple, list)):
                    pairs.append((item[0], item[1]))
                else:
                    raise TypeError("Unexpected structure for updated_history")
            return pairs

        async def respond(message, history):
            """Run one chat turn and map the result onto the UI outputs."""
            _response, updated_history, search_results = await chat_service.chat(message, history)
            return to_chatbot_pairs(updated_history), updated_history, search_results

        submit_btn.click(
            respond,
            inputs=[user_input, chat_history],
            outputs=[chat_display, chat_history, product_info],
        )

        # Reset the visible chat, the stored history and the details panel.
        clear_btn.click(
            lambda: ([], [], None),
            outputs=[chat_display, chat_history, product_info],
        )

    demo.queue()
    return demo
|
|
if __name__ == "__main__":
    import uvicorn

    # Gradio UI on port 7860. When run as a script, launch() blocks the calling
    # thread by default, so the uvicorn.run() call below was never reached and
    # the REST API never started. prevent_thread_lock=True makes launch() return
    # while the Gradio server keeps running in the background.
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)

    # FastAPI REST API on port 8000. uvicorn.run() blocks, which keeps the
    # process (and with it the background Gradio server) alive.
    uvicorn.run(app, host="0.0.0.0", port=8000)