| """ | |
| Fast API endpoint for RAG system - optimized for <3 second responses | |
| """ | |
| import modal | |
| app = modal.App("insurance-rag-api") | |
| # Reference your specific volume | |
| vol = modal.Volume.from_name("mcp-hack-ins-products", create_if_missing=True) | |
| # Model configuration | |
| LLM_MODEL = "microsoft/Phi-3-mini-4k-instruct" | |
| EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5" | |
| # Build image with dependencies | |
| image = ( | |
| modal.Image.debian_slim(python_version="3.11") | |
| .pip_install( | |
| # Core ML dependencies (compatible versions) | |
| "torch>=2.0.0", | |
| "transformers>=4.30.0", | |
| "sentence-transformers>=2.2.0", | |
| "huggingface_hub>=0.15.0", | |
| # LangChain (compatible versions) | |
| "langchain>=0.1.0", | |
| "langchain-community>=0.0.13", | |
| # Document processing | |
| "pypdf>=4.0.0", | |
| "python-docx>=1.1.0", | |
| "openpyxl>=3.1.0", | |
| "pandas>=2.0.0", | |
| "xlrd>=2.0.0", | |
| # Vector database | |
| "chromadb>=0.4.0", | |
| # Web framework | |
| "fastapi>=0.100.0", | |
| "uvicorn[standard]>=0.20.0", | |
| # LLM inference (vLLM - latest stable) | |
| "vllm>=0.4.0", | |
| # Utilities | |
| "cryptography>=41.0.0", | |
| ) | |
| ) | |
# RAG service class - loads models once per container and serves queries from memory.
# NOTE: the GPU type and the volume mount path below are assumptions; adjust them to
# match your deployment.
@app.cls(image=image, gpu="A10G", volumes={"/data": vol})
class FastRAGService:
    """Optimized RAG service for fast API responses"""

    @modal.enter()
    def enter(self):
        from langchain_community.embeddings import HuggingFaceEmbeddings
        from vllm import LLM, SamplingParams
        from langchain.schema import Document

        print("Initializing Fast RAG Service...")
        # Initialize embeddings (faster model)
        self.embeddings = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Connect to Chroma
        self.chroma_service = modal.Cls.from_name("chroma-server-v2", "ChromaDB")()

        # Custom retriever that queries the remote Chroma service
        class RemoteChromaRetriever:
            def __init__(self, chroma_service, embeddings, k=5):
                self.chroma_service = chroma_service
                self.embeddings = embeddings
                self.k = k

            def get_relevant_documents(self, query: str):
                query_embedding = self.embeddings.embed_query(query)
                results = self.chroma_service.query.remote(
                    collection_name="product_design",
                    query_embeddings=[query_embedding],
                    n_results=self.k
                )
                docs = []
                if results and 'documents' in results and len(results['documents']) > 0:
                    for i, doc_text in enumerate(results['documents'][0]):
                        metadata = results.get('metadatas', [[{}]])[0][i] if 'metadatas' in results else {}
                        docs.append(Document(page_content=doc_text, metadata=metadata))
                return docs

        self.Retriever = RemoteChromaRetriever
        # Load LLM with optimized settings for speed
        print("Loading LLM (optimized for speed)...")
        self.llm_engine = LLM(
            model=LLM_MODEL,
            dtype="float16",
            gpu_memory_utilization=0.9,   # Higher utilization for speed
            max_model_len=4096,
            trust_remote_code=True,
            enforce_eager=True,
            enable_prefix_caching=True,   # Cache prefixes for faster generation
        )

        # Optimized sampling params for speed
        self.default_sampling_params = SamplingParams(
            temperature=0.7,
            max_tokens=1024,              # Reduced from 1536 for faster responses
            top_p=0.9,
            stop=["\n\n\n", "Question:", "Context:", "<|end|>"]
        )

        print("Fast RAG Service ready!")
    @modal.method()
    def query(self, question: str, top_k: int = 5, max_tokens: int = 1024):
| """Fast query method optimized for <3 second responses""" | |
| import time | |
| start_time = time.time() | |
| # Retrieve documents | |
| retrieval_start = time.time() | |
| retriever = self.Retriever( | |
| chroma_service=self.chroma_service, | |
| embeddings=self.embeddings, | |
| k=top_k | |
| ) | |
| docs = retriever.get_relevant_documents(question) | |
| retrieval_time = time.time() - retrieval_start | |
| if not docs: | |
| return { | |
| "answer": "No relevant information found in the product design document.", | |
| "retrieval_time": retrieval_time, | |
| "generation_time": 0, | |
| "total_time": time.time() - start_time, | |
| "sources": [], | |
| "success": False | |
| } | |
| # Build context (limit size for speed) | |
| context = "\n\n".join([doc.page_content[:800] for doc in docs[:3]]) # Limit to top 3 docs, 800 chars each | |
| # Create prompt | |
| prompt = f"""<|system|> | |
| You are a helpful AI assistant. Answer questions about the TokyoDrive Insurance product design document concisely and accurately.<|end|> | |
| <|user|> | |
| Context: | |
| {context} | |
| Question: | |
| {question}<|end|> | |
| <|assistant|>""" | |
| # Generate with optimized params | |
| from vllm import SamplingParams | |
| sampling_params = SamplingParams( | |
| temperature=0.7, | |
| max_tokens=max_tokens, | |
| top_p=0.9, | |
| stop=["\n\n\n", "Question:", "Context:", "<|end|>"] | |
| ) | |
| gen_start = time.time() | |
| outputs = self.llm_engine.generate(prompts=[prompt], sampling_params=sampling_params) | |
| answer = outputs[0].outputs[0].text.strip() | |
| generation_time = time.time() - gen_start | |
| # Prepare sources (limited for speed) | |
| sources = [] | |
| for doc in docs[:3]: # Limit to 3 sources | |
| sources.append({ | |
| "content": doc.page_content[:300], | |
| "metadata": doc.metadata | |
| }) | |
| total_time = time.time() - start_time | |
| return { | |
| "answer": answer, | |
| "retrieval_time": retrieval_time, | |
| "generation_time": generation_time, | |
| "total_time": total_time, | |
| "sources": sources, | |
| "success": True | |
| } | |
# Deploy as web endpoint
@app.function(image=image)
@modal.asgi_app()
def fastapi_app():
| """Deploy FastAPI app - all imports inside to avoid local dependency issues""" | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| # Request/Response models | |
| class QueryRequest(BaseModel): | |
| question: str | |
| top_k: int = 5 | |
| max_tokens: int = 1024 # Reduced for faster responses | |
| class QueryResponse(BaseModel): | |
| answer: str | |
| retrieval_time: float | |
| generation_time: float | |
| total_time: float | |
| sources: list | |
| success: bool | |
| # FastAPI app | |
| web_app = FastAPI(title="Product Design RAG API", version="1.0.0") | |
| # CORS | |
| web_app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Initialize RAG service | |
| rag_service = FastRAGService() | |
    @web_app.get("/health")
    async def health():
        """Health check endpoint"""
        return {"status": "healthy", "service": "rag-api"}
    @web_app.post("/query", response_model=QueryResponse)
    async def query_rag(request: QueryRequest):
        """
        Query the RAG system - optimized for <3 second responses

        Args:
            question: The question to ask
            top_k: Number of documents to retrieve (default: 5)
            max_tokens: Maximum tokens in response (default: 1024)

        Returns:
            QueryResponse with answer, timing, and sources
        """
        try:
            result = rag_service.query.remote(
                question=request.question,
                top_k=request.top_k,
                max_tokens=request.max_tokens
            )
            if not result.get("success", True):
                raise HTTPException(status_code=404, detail="No relevant information found")
            return QueryResponse(**result)
        except HTTPException:
            # Re-raise HTTP errors (e.g. the 404 above) instead of wrapping them in a 500
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
    @web_app.get("/")
    async def root():
        """API root endpoint"""
        return {
            "service": "Product Design RAG API",
            "version": "1.0.0",
            "endpoints": {
                "health": "/health",
                "query": "/query (POST)"
            },
            "target_response_time": "<3 seconds"
        }
    return web_app
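
# --- Local smoke test (sketch) ---
# A minimal sketch of how the service could be exercised end to end with `modal run`.
# The entrypoint name, the sample question, and the reduced top_k/max_tokens values are
# illustrative assumptions, not part of the original service.
@app.local_entrypoint()
def smoke_test():
    result = FastRAGService().query.remote(
        question="What coverages does the TokyoDrive Insurance product include?",
        top_k=3,
        max_tokens=256,
    )
    print(f"Answer ({result['total_time']:.2f}s): {result['answer']}")

# After `modal deploy`, the ASGI endpoint can also be queried over HTTP. The URL shape
# below is an assumption; Modal prints the actual URL when the app is deployed:
#   curl -X POST https://<workspace>--insurance-rag-api-fastapi-app.modal.run/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Summarize the TokyoDrive product design", "top_k": 3}'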