# Hugging Face Spaces status at time of capture: Runtime error
| from huggingface_hub import login | |
| from fastapi import FastAPI, Depends, HTTPException | |
| import logging | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModel | |
| from sentence_transformers import models, SentenceTransformer | |
| from services.qdrant_searcher import QdrantSearcher | |
| from services.openai_service import generate_rag_response | |
| from utils.auth import token_required | |
| from dotenv import load_dotenv | |
| import os | |
# Load variables from a local .env file into the process environment so the
# os.getenv() calls further down can see them.
load_dotenv()
# The FastAPI application object; passed to uvicorn.run() when this module is
# executed as a script.
app = FastAPI()
# Cache directory for Hugging Face downloads (unconditionally overrides HF_HOME).
# NOTE(review): huggingface_hub and transformers are imported above this point
# and may already have resolved their cache locations from the environment, so
# this assignment is not guaranteed to redirect their default cache. Setting it
# before those imports (or passing an explicit cache_dir) makes it reliable.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
hf_home_dir = os.environ["HF_HOME"]
# exist_ok=True instead of check-then-create: avoids a FileExistsError race when
# several worker processes import this module at the same time.
os.makedirs(hf_home_dir, exist_ok=True)
# Configure root logging for the process at INFO level.
logging.basicConfig(level=logging.INFO)
# Log into the Hugging Face Hub with the token from HUGGINGFACE_HUB_TOKEN.
# Startup is aborted if the token is missing or the login call fails.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")
try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    logging.exception("Failed to log into Hugging Face Hub: %s", e)
    # This runs at import time, outside any request, so an HTTPException has no
    # client to reach; abort startup with a plain exception that keeps the cause.
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logging.info("Successfully logged into Hugging Face Hub.")
# Qdrant connection settings, read from the environment. Both values are
# required; startup fails with ValueError if either is unset or empty.
qdrant_url, access_token = map(os.getenv, ('QDRANT_URL', 'QDRANT_ACCESS_TOKEN'))
if not all((qdrant_url, access_token)):
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
# Build the sentence encoder: the nomic-embed-text-v1.5 transformer followed by
# a models.Pooling layer (constructed with default settings, as before).
EMBEDDING_MODEL_NAME = 'nomic-ai/nomic-embed-text-v1.5'
try:
    # Download into an explicit folder under hf_home_dir rather than relying on
    # HF_HOME, which was set after the Hugging Face libraries were imported.
    cache_folder = os.path.join(hf_home_dir, "transformers_cache")
    # models.Transformer loads the config, model and tokenizer itself; it has no
    # `model=` / `tokenizer=` parameters, so passing pre-built objects raised a
    # TypeError here. The model ships custom code, so remote code must be
    # trusted at every loading step.
    # NOTE(review): `config_args` requires a sentence-transformers release that
    # supports it (2.3+) — confirm the pinned version.
    word_embedding_model = models.Transformer(
        model_name_or_path=EMBEDDING_MODEL_NAME,
        cache_dir=cache_folder,
        model_args={"trust_remote_code": True},
        tokenizer_args={"trust_remote_code": True},
        config_args={"trust_remote_code": True},
    )
    # Keep the previous module-level names bound to the loaded objects.
    model = word_embedding_model.auto_model
    tokenizer = word_embedding_model.tokenizer
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
except Exception as e:
    logging.exception("Failed to load the SentenceTransformer model: %s", e)
    # Import-time failure: there is no request to answer, so an HTTPException
    # is meaningless here; abort startup with the original cause chained.
    raise RuntimeError("Failed to load the SentenceTransformer model.") from e
logging.info("Successfully loaded the SentenceTransformer model.")
# Shared searcher used by both endpoints below, constructed from the sentence
# encoder and the Qdrant URL / access token read from the environment.
searcher = QdrantSearcher(encoder, qdrant_url, access_token)
# Request body models (JSON bodies validated by pydantic).
# Body for the document-search endpoint.
class SearchDocumentsRequest(BaseModel):
    query: str  # free-text query passed to searcher.search_documents
    limit: int = 3  # passed as the result limit to searcher.search_documents (presumably max hits)
# Body for the RAG-generation endpoint.
class GenerateRAGRequest(BaseModel):
    search_query: str  # used both for retrieval and as the prompt for generation
# Define the search documents endpoint.
# NOTE(review): no `@app.post(...)` / `@app.get(...)` decorator registers this
# coroutine on `app`, so FastAPI never routes requests to it — confirm the
# intended path and add the decorator.
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Search the "documents" collection on behalf of the authenticated user.

    Args:
        body: The query text and the result limit.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        The hits returned by ``searcher.search_documents``.

    Raises:
        HTTPException: 401 if either id is missing; 500 if the search reports
            an error or raises unexpectedly.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    logging.info("Received request to search documents")
    # NOTE(review): search_documents is a synchronous call inside an async
    # endpoint, so it blocks the event loop while it runs.
    try:
        hits, error = searcher.search_documents("documents", body.query, user_id, body.limit)
        if error:
            logging.error("Search documents error: %s", error)
            raise HTTPException(status_code=500, detail=error)
        return hits
    except HTTPException:
        # Deliberate HTTP error from above: pass it through unchanged instead of
        # re-wrapping it (which turned the detail into "500: <error>").
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Define the generate RAG response endpoint.
# NOTE(review): no `@app.post(...)` / `@app.get(...)` decorator registers this
# coroutine on `app`, so FastAPI never routes requests to it — confirm the
# intended path and add the decorator.
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Retrieve the user's documents for a query and generate a RAG answer.

    Args:
        body: Holds ``search_query``, used for retrieval and for generation.
        credentials: ``(customer_id, user_id)`` pair supplied by the
            ``token_required`` dependency.

    Returns:
        ``{"response": <generated response>}``.

    Raises:
        HTTPException: 401 if either id is missing; 500 if retrieval or
            generation reports an error or raises unexpectedly.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    logging.info("Received request to generate RAG response")
    try:
        # No limit argument is passed, so the searcher's own default applies.
        hits, error = searcher.search_documents("documents", body.search_query, user_id)
        if error:
            logging.error("Search documents error: %s", error)
            raise HTTPException(status_code=500, detail=error)
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error("Generate RAG response error: %s", error)
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Deliberate HTTP error from above: pass it through unchanged instead of
        # re-wrapping it (which turned the detail into "500: <error>").
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == '__main__':
    # Script entry point: serve `app` with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host='0.0.0.0')