import logging
from typing import List, Optional
from urllib.parse import unquote

import google.generativeai as genai
from adalflow.components.model_client.ollama_client import OllamaClient
from adalflow.core.types import ModelType
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from api.config import (
    get_model_config,
    configs,
    OPENROUTER_API_KEY,
    OPENAI_API_KEY,
    AWS_ACCESS_KEY_ID,
    AWS_SECRET_ACCESS_KEY,
)
from api.data_pipeline import count_tokens, get_file_content
from api.openai_client import OpenAIClient
from api.openrouter_client import OpenRouterClient
from api.bedrock_client import BedrockClient
from api.azureai_client import AzureAIClient
from api.rag import RAG
from api.prompts import (
    DEEP_RESEARCH_FIRST_ITERATION_PROMPT,
    DEEP_RESEARCH_FINAL_ITERATION_PROMPT,
    DEEP_RESEARCH_INTERMEDIATE_ITERATION_PROMPT,
    SIMPLE_CHAT_SYSTEM_PROMPT,
)

# Configure logging
from api.logging_config import setup_logging

setup_logging()
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Simple Chat API",
    description="Simplified API for streaming chat completions"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Models for the API
class ChatMessage(BaseModel):
    role: str  # 'user' or 'assistant'
    content: str

class ChatCompletionRequest(BaseModel):
    """
    Model for requesting a chat completion.
    """
    repo_url: str = Field(..., description="URL of the repository to query")
    messages: List[ChatMessage] = Field(..., description="List of chat messages")
    filePath: Optional[str] = Field(None, description="Optional path to a file in the repository to include in the prompt")
    token: Optional[str] = Field(None, description="Personal access token for private repositories")
    type: Optional[str] = Field("github", description="Type of repository (e.g., 'github', 'gitlab', 'bitbucket')")

    # Model parameters
    provider: str = Field("google", description="Model provider (google, openai, openrouter, ollama, bedrock, azure)")
    model: Optional[str] = Field(None, description="Model name for the specified provider")

    language: Optional[str] = Field("en", description="Language for content generation (e.g., 'en', 'ja', 'zh', 'es', 'kr', 'vi')")
    excluded_dirs: Optional[str] = Field(None, description="Newline-separated list of directories to exclude from processing")
    excluded_files: Optional[str] = Field(None, description="Newline-separated list of file patterns to exclude from processing")
    included_dirs: Optional[str] = Field(None, description="Newline-separated list of directories to include exclusively")
    included_files: Optional[str] = Field(None, description="Newline-separated list of file patterns to include exclusively")
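# A minimal example payload for the streaming endpoint below. The repo_url and
# message content are illustrative placeholders, not values the API requires:
#
#   POST /chat/completions/stream
#   {
#       "repo_url": "https://github.com/owner/repo",
#       "messages": [{"role": "user", "content": "What does this repo do?"}],
#       "provider": "google",
#       "language": "en"
#   }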
@app.post("/chat/completions/stream")
async def chat_completions_stream(request: ChatCompletionRequest):
    """Stream a chat completion response from the configured model provider"""
    try:
        # Check if request contains very large input
        input_too_large = False
        if request.messages and len(request.messages) > 0:
            last_message = request.messages[-1]
            if hasattr(last_message, 'content') and last_message.content:
                tokens = count_tokens(last_message.content, request.provider == "ollama")
                # Request size check
                if tokens > 8000:
                    logger.warning(f"Request exceeds recommended token limit ({tokens} > 8000)")
                    input_too_large = True

        # Create a new RAG instance for this request
        try:
            request_rag = RAG(provider=request.provider, model=request.model)

            # Extract custom file filter parameters if provided
            excluded_dirs = None
            excluded_files = None
            included_dirs = None
            included_files = None

            if request.excluded_dirs:
                excluded_dirs = [unquote(dir_path) for dir_path in request.excluded_dirs.split('\n') if dir_path.strip()]
            if request.excluded_files:
                excluded_files = [unquote(file_pattern) for file_pattern in request.excluded_files.split('\n') if file_pattern.strip()]
            if request.included_dirs:
                included_dirs = [unquote(dir_path) for dir_path in request.included_dirs.split('\n') if dir_path.strip()]
            if request.included_files:
                included_files = [unquote(file_pattern) for file_pattern in request.included_files.split('\n') if file_pattern.strip()]

            request_rag.prepare_retriever(request.repo_url, request.type, request.token,
                                          excluded_dirs, excluded_files, included_dirs, included_files)
            logger.info(f"Retriever prepared for {request.repo_url}")
        except ValueError as e:
            if "No valid documents with embeddings found" in str(e):
                logger.error(f"No valid embeddings found: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail="No valid document embeddings found. This may be due to embedding size "
                           "inconsistencies or API errors during document processing. Please try "
                           "again or check your repository content."
                )
            logger.error(f"ValueError preparing retriever: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error preparing retriever: {str(e)}")
        except Exception as e:
            logger.error(f"Error preparing retriever: {str(e)}")
            # Check for specific embedding-related errors
            if "All embeddings should be of the same size" in str(e):
                raise HTTPException(
                    status_code=500,
                    detail="Inconsistent embedding sizes detected. Some documents may have failed "
                           "to embed properly. Please try again."
                )
            raise HTTPException(status_code=500, detail=f"Error preparing retriever: {str(e)}")
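        # The filter fields above are newline-separated strings. An illustrative
        # (not prescriptive) example of what a client might send:
        #
        #   "excluded_dirs": "node_modules\n.venv\ndist"
        #   "excluded_files": "*.lock\n*.min.js"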
        # Validate request
        if not request.messages or len(request.messages) == 0:
            raise HTTPException(status_code=400, detail="No messages provided")

        last_message = request.messages[-1]
        if last_message.role != "user":
            raise HTTPException(status_code=400, detail="Last message must be from the user")

        # Process previous messages to build conversation history.
        # Messages are expected to alternate user/assistant, so walk them in pairs.
        for i in range(0, len(request.messages) - 1, 2):
            if i + 1 < len(request.messages):
                user_msg = request.messages[i]
                assistant_msg = request.messages[i + 1]
                if user_msg.role == "user" and assistant_msg.role == "assistant":
                    request_rag.memory.add_dialog_turn(
                        user_query=user_msg.content,
                        assistant_response=assistant_msg.content
                    )

        # Check if this is a Deep Research request
        is_deep_research = False
        research_iteration = 1

        for msg in request.messages:
            if hasattr(msg, 'content') and msg.content and "[DEEP RESEARCH]" in msg.content:
                is_deep_research = True
                # Only remove the tag from the last message
                if msg == request.messages[-1]:
                    msg.content = msg.content.replace("[DEEP RESEARCH]", "").strip()

        # Count research iterations: each previous assistant reply is one completed iteration
        if is_deep_research:
            research_iteration = sum(1 for msg in request.messages if msg.role == 'assistant') + 1
            logger.info(f"Deep Research request detected - iteration {research_iteration}")

            # Check if this is a continuation request
            if "continue" in last_message.content.lower() and "research" in last_message.content.lower():
                # Find the original topic from the first user message
                original_topic = None
                for msg in request.messages:
                    if msg.role == "user" and "continue" not in msg.content.lower():
                        original_topic = msg.content.replace("[DEEP RESEARCH]", "").strip()
                        break
                if original_topic:
                    # Replace the continuation message with the original topic
                    last_message.content = original_topic

        # Get the query from the last message
        query = last_message.content

        # Only retrieve documents if the input is not too large
        context_text = ""
        retrieved_documents = None

        if not input_too_large:
            try:
                # If filePath exists, modify the query for RAG to focus on the file
                rag_query = query
                if request.filePath:
                    rag_query = f"Contexts related to {request.filePath}"

                # Try to perform RAG retrieval
                try:
                    # This will use the actual RAG implementation
                    retrieved_documents = request_rag(rag_query, language=request.language)

                    if retrieved_documents and retrieved_documents[0].documents:
                        # Group retrieved documents by file path
                        documents = retrieved_documents[0].documents
                        docs_by_file = {}
                        for doc in documents:
                            file_path = doc.meta_data.get('file_path', 'unknown')
                            if file_path not in docs_by_file:
                                docs_by_file[file_path] = []
                            docs_by_file[file_path].append(doc)

                        # Format context text with file path grouping
                        context_parts = []
                        for file_path, docs in docs_by_file.items():
                            # Add file header with metadata, then the document content
                            header = f"## File Path: {file_path}\n\n"
                            content = "\n\n".join([doc.text for doc in docs])
                            context_parts.append(f"{header}{content}")

                        # Join all parts with a clear separator between files
                        separator = "\n\n" + "-" * 10 + "\n\n"
                        context_text = separator.join(context_parts)
                    else:
                        logger.warning("No documents retrieved from RAG")
                except Exception as e:
                    logger.error(f"Error in RAG retrieval: {str(e)}")
                    # Continue without RAG if there's an error
            except Exception as e:
                logger.error(f"Error retrieving documents: {str(e)}")
                context_text = ""
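        # With the grouping above, context_text ends up shaped roughly like this
        # (file paths and snippet text are illustrative):
        #
        #   ## File Path: src/main.py
        #
        #   <chunk text...>
        #
        #   ----------
        #
        #   ## File Path: README.md
        #
        #   <chunk text...>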
        # Get repository information
        repo_url = request.repo_url
        repo_name = repo_url.split("/")[-1] if "/" in repo_url else repo_url

        # Determine repository type
        repo_type = request.type

        # Get language information
        language_code = request.language or configs["lang_config"]["default"]
        supported_langs = configs["lang_config"]["supported_languages"]
        language_name = supported_langs.get(language_code, "English")

        # Create system prompt
        if is_deep_research:
            is_first_iteration = research_iteration == 1
            is_final_iteration = research_iteration >= 5

            if is_first_iteration:
                system_prompt = DEEP_RESEARCH_FIRST_ITERATION_PROMPT.format(
                    repo_type=repo_type,
                    repo_url=repo_url,
                    repo_name=repo_name,
                    language_name=language_name
                )
            elif is_final_iteration:
                system_prompt = DEEP_RESEARCH_FINAL_ITERATION_PROMPT.format(
                    repo_type=repo_type,
                    repo_url=repo_url,
                    repo_name=repo_name,
                    research_iteration=research_iteration,
                    language_name=language_name
                )
            else:
                system_prompt = DEEP_RESEARCH_INTERMEDIATE_ITERATION_PROMPT.format(
                    repo_type=repo_type,
                    repo_url=repo_url,
                    repo_name=repo_name,
                    research_iteration=research_iteration,
                    language_name=language_name
                )
        else:
            system_prompt = SIMPLE_CHAT_SYSTEM_PROMPT.format(
                repo_type=repo_type,
                repo_url=repo_url,
                repo_name=repo_name,
                language_name=language_name
            )

        # Fetch file content if a file path was provided
        file_content = ""
        if request.filePath:
            try:
                file_content = get_file_content(request.repo_url, request.filePath, request.type, request.token)
                logger.info(f"Successfully retrieved content for file: {request.filePath}")
            except Exception as e:
                logger.error(f"Error retrieving file content: {str(e)}")
                # Continue without file content if there's an error

        # Format conversation history, wrapping each stored turn in XML-style tags
        # so the model can tell the speakers apart
        conversation_history = ""
        for turn_id, turn in request_rag.memory().items():
            if not isinstance(turn_id, int) and hasattr(turn, 'user_query') and hasattr(turn, 'assistant_response'):
                conversation_history += f"<turn>\n<user>{turn.user_query.query_str}</user>\n<assistant>{turn.assistant_response.response_str}</assistant>\n</turn>\n\n"

        # Create the prompt with context
        prompt = f"/no_think {system_prompt}\n\n"

        if conversation_history:
            prompt += f"<conversation_history>\n{conversation_history}</conversation_history>\n\n"

        # Add the fetched file content to the prompt after conversation history
        if file_content:
            prompt += f"<currentFileContent path=\"{request.filePath}\">\n{file_content}\n</currentFileContent>\n\n"

        # Only include context if it's not empty
        CONTEXT_START = "<START_OF_CONTEXT>"
        CONTEXT_END = "<END_OF_CONTEXT>"
        if context_text.strip():
            prompt += f"{CONTEXT_START}\n{context_text}\n{CONTEXT_END}\n\n"
        else:
            # Note that RAG was skipped, either due to size constraints or retrieval failure
            prompt += "Answering without retrieval augmentation.\n\n"

        prompt += f"<query>\n{query}\n</query>\n\nAssistant: "
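        # The assembled prompt has this overall shape (sections are omitted when empty):
        #
        #   /no_think <system prompt>
        #   <conversation_history>...</conversation_history>
        #   <currentFileContent path="...">...</currentFileContent>
        #   <START_OF_CONTEXT>...<END_OF_CONTEXT>
        #   <query>...</query>
        #   Assistant: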
        model_config = get_model_config(request.provider, request.model)["model_kwargs"]

        if request.provider == "ollama":
            prompt += " /no_think"

            model = OllamaClient()
            model_kwargs = {
                "model": model_config["model"],
                "stream": True,
                "options": {
                    "temperature": model_config["temperature"],
                    "top_p": model_config["top_p"],
                    "num_ctx": model_config["num_ctx"]
                }
            }

            api_kwargs = model.convert_inputs_to_api_kwargs(
                input=prompt,
                model_kwargs=model_kwargs,
                model_type=ModelType.LLM
            )
        elif request.provider == "openrouter":
            # Check if OpenRouter API key is set; if not, let the OpenRouterClient
            # handle it and return a friendly error message
            if not OPENROUTER_API_KEY:
                logger.warning("OPENROUTER_API_KEY not configured")

            model = OpenRouterClient()
            model_kwargs = {
                "model": request.model,
                "stream": True,
                "temperature": model_config["temperature"]
            }
            # Only add top_p if it exists in the model config
            if "top_p" in model_config:
                model_kwargs["top_p"] = model_config["top_p"]

            api_kwargs = model.convert_inputs_to_api_kwargs(
                input=prompt,
                model_kwargs=model_kwargs,
                model_type=ModelType.LLM
            )
        elif request.provider == "openai":
            # Check if the OpenAI API key is set; if not, let the OpenAIClient
            # handle it and return an error message
            if not OPENAI_API_KEY:
                logger.warning("OPENAI_API_KEY not configured")

            # Initialize OpenAI client
            model = OpenAIClient()
            model_kwargs = {
                "model": request.model,
                "stream": True,
                "temperature": model_config["temperature"]
            }
            # Only add top_p if it exists in the model config
            if "top_p" in model_config:
                model_kwargs["top_p"] = model_config["top_p"]

            api_kwargs = model.convert_inputs_to_api_kwargs(
                input=prompt,
                model_kwargs=model_kwargs,
                model_type=ModelType.LLM
            )
        elif request.provider == "bedrock":
            # Check if AWS credentials are set; if not, let the BedrockClient
            # handle it and return an error message
            if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
                logger.warning("AWS credentials not configured")

            # Initialize Bedrock client
            model = BedrockClient()
            model_kwargs = {
                "model": request.model,
                "temperature": model_config["temperature"],
                "top_p": model_config["top_p"]
            }

            api_kwargs = model.convert_inputs_to_api_kwargs(
                input=prompt,
                model_kwargs=model_kwargs,
                model_type=ModelType.LLM
            )
        elif request.provider == "azure":
            # Initialize Azure AI client
            model = AzureAIClient()
            model_kwargs = {
                "model": request.model,
                "stream": True,
                "temperature": model_config["temperature"],
                "top_p": model_config["top_p"]
            }

            api_kwargs = model.convert_inputs_to_api_kwargs(
                input=prompt,
                model_kwargs=model_kwargs,
                model_type=ModelType.LLM
            )
        else:
            # Initialize Google Generative AI model
            model = genai.GenerativeModel(
                model_name=model_config["model"],
                generation_config={
                    "temperature": model_config["temperature"],
                    "top_p": model_config["top_p"],
                    "top_k": model_config["top_k"]
                }
            )
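        # One way to exercise the streaming endpoint from a terminal. The host and
        # port are assumptions; use whatever address this app is actually served on:
        #
        #   curl -N -X POST http://localhost:8001/chat/completions/stream \
        #     -H "Content-Type: application/json" \
        #     -d '{"repo_url": "https://github.com/owner/repo",
        #          "messages": [{"role": "user", "content": "Summarize this repo"}]}'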
        # Create a streaming response
        async def response_stream():
            try:
                if request.provider == "ollama":
                    # Get the response using the previously created api_kwargs
                    response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
                    # Handle streaming response from Ollama
                    async for chunk in response:
                        text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
                        if text and not text.startswith('model=') and not text.startswith('created_at='):
                            # Strip <think> reasoning tags that some models emit
                            text = text.replace('<think>', '').replace('</think>', '')
                            yield text
                elif request.provider == "openrouter":
                    try:
                        # Get the response using the previously created api_kwargs
                        response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
                        # Handle streaming response from OpenRouter
                        async for chunk in response:
                            yield chunk
                    except Exception as e_openrouter:
                        logger.error(f"Error with OpenRouter API: {str(e_openrouter)}")
                        yield f"\nError with OpenRouter API: {str(e_openrouter)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key."
                elif request.provider == "openai":
                    try:
                        # Get the response using the previously created api_kwargs
                        response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
                        # Handle streaming response from OpenAI
                        async for chunk in response:
                            choices = getattr(chunk, "choices", [])
                            if len(choices) > 0:
                                delta = getattr(choices[0], "delta", None)
                                if delta is not None:
                                    text = getattr(delta, "content", None)
                                    if text is not None:
                                        yield text
                    except Exception as e_openai:
                        logger.error(f"Error with OpenAI API: {str(e_openai)}")
                        yield f"\nError with OpenAI API: {str(e_openai)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key."
                elif request.provider == "bedrock":
                    try:
                        # Get the response using the previously created api_kwargs
                        response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
                        # Handle response from Bedrock (not streaming yet)
                        if isinstance(response, str):
                            yield response
                        else:
                            # Try to extract text from the response
                            yield str(response)
                    except Exception as e_bedrock:
                        logger.error(f"Error with AWS Bedrock API: {str(e_bedrock)}")
                        yield f"\nError with AWS Bedrock API: {str(e_bedrock)}\n\nPlease check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials."
                elif request.provider == "azure":
                    try:
                        # Get the response using the previously created api_kwargs
                        response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
                        # Handle streaming response from Azure AI
                        async for chunk in response:
                            choices = getattr(chunk, "choices", [])
                            if len(choices) > 0:
                                delta = getattr(choices[0], "delta", None)
                                if delta is not None:
                                    text = getattr(delta, "content", None)
                                    if text is not None:
                                        yield text
                    except Exception as e_azure:
                        logger.error(f"Error with Azure AI API: {str(e_azure)}")
                        yield f"\nError with Azure AI API: {str(e_azure)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values."
                else:
                    # Generate streaming response from Google Generative AI
                    response = model.generate_content(prompt, stream=True)
                    # Stream the response
                    for chunk in response:
                        if hasattr(chunk, 'text'):
                            yield chunk.text

            except Exception as e_outer:
                logger.error(f"Error in streaming response: {str(e_outer)}")
                error_message = str(e_outer)

                # Check for token limit errors
                if "maximum context length" in error_message or "token limit" in error_message or "too many tokens" in error_message:
                    # If we hit a token limit error, retry without the RAG context
                    logger.warning("Token limit exceeded, retrying without context")
                    try:
                        # Create a simplified prompt without context
                        simplified_prompt = f"/no_think {system_prompt}\n\n"

                        if conversation_history:
                            simplified_prompt += f"<conversation_history>\n{conversation_history}</conversation_history>\n\n"

                        # Include file content in the fallback prompt if it was retrieved
                        if request.filePath and file_content:
                            simplified_prompt += f"<currentFileContent path=\"{request.filePath}\">\n{file_content}\n</currentFileContent>\n\n"

                        simplified_prompt += "Answering without retrieval augmentation due to input size constraints.\n\n"
                        simplified_prompt += f"<query>\n{query}\n</query>\n\nAssistant: "

                        if request.provider == "ollama":
                            simplified_prompt += " /no_think"

                            # Create new api_kwargs with the simplified prompt
                            fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
                                input=simplified_prompt,
                                model_kwargs=model_kwargs,
                                model_type=ModelType.LLM
                            )

                            # Get the response using the simplified prompt
                            fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)

                            # Handle streaming fallback_response from Ollama
                            async for chunk in fallback_response:
                                text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
                                if text and not text.startswith('model=') and not text.startswith('created_at='):
                                    # Strip <think> reasoning tags that some models emit
                                    text = text.replace('<think>', '').replace('</think>', '')
                                    yield text
                        elif request.provider == "openrouter":
                            try:
                                # Create new api_kwargs with the simplified prompt
                                fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
                                    input=simplified_prompt,
                                    model_kwargs=model_kwargs,
                                    model_type=ModelType.LLM
                                )

                                # Get the response using the simplified prompt
                                fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)

                                # Handle streaming fallback_response from OpenRouter
                                async for chunk in fallback_response:
                                    yield chunk
                            except Exception as e_fallback:
                                logger.error(f"Error with OpenRouter API fallback: {str(e_fallback)}")
                                yield f"\nError with OpenRouter API fallback: {str(e_fallback)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key."
                        elif request.provider == "openai":
                            try:
                                # Create new api_kwargs with the simplified prompt
                                fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
                                    input=simplified_prompt,
                                    model_kwargs=model_kwargs,
                                    model_type=ModelType.LLM
                                )

                                # Get the response using the simplified prompt
                                fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)

                                # Handle streaming fallback_response from OpenAI
                                async for chunk in fallback_response:
                                    text = chunk if isinstance(chunk, str) else getattr(chunk, 'text', str(chunk))
                                    yield text
                            except Exception as e_fallback:
                                logger.error(f"Error with OpenAI API fallback: {str(e_fallback)}")
                                yield f"\nError with OpenAI API fallback: {str(e_fallback)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key."
                        elif request.provider == "bedrock":
                            try:
                                # Create new api_kwargs with the simplified prompt
                                fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
                                    input=simplified_prompt,
                                    model_kwargs=model_kwargs,
                                    model_type=ModelType.LLM
                                )

                                # Get the response using the simplified prompt
                                fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)

                                # Handle response from Bedrock (not streaming yet)
                                if isinstance(fallback_response, str):
                                    yield fallback_response
                                else:
                                    # Try to extract text from the response
                                    yield str(fallback_response)
                            except Exception as e_fallback:
                                logger.error(f"Error with AWS Bedrock API fallback: {str(e_fallback)}")
                                yield f"\nError with AWS Bedrock API fallback: {str(e_fallback)}\n\nPlease check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials."
                        elif request.provider == "azure":
                            try:
                                # Create new api_kwargs with the simplified prompt
                                fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
                                    input=simplified_prompt,
                                    model_kwargs=model_kwargs,
                                    model_type=ModelType.LLM
                                )

                                # Get the response using the simplified prompt
                                fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)

                                # Handle streaming fallback response from Azure AI
                                async for chunk in fallback_response:
                                    choices = getattr(chunk, "choices", [])
                                    if len(choices) > 0:
                                        delta = getattr(choices[0], "delta", None)
                                        if delta is not None:
                                            text = getattr(delta, "content", None)
                                            if text is not None:
                                                yield text
                            except Exception as e_fallback:
                                logger.error(f"Error with Azure AI API fallback: {str(e_fallback)}")
                                yield f"\nError with Azure AI API fallback: {str(e_fallback)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values."
                        else:
                            # Re-initialize the Google Generative AI model for the fallback
                            model_config = get_model_config(request.provider, request.model)
                            fallback_model = genai.GenerativeModel(
                                model_name=model_config["model_kwargs"]["model"],
                                generation_config={
                                    "temperature": model_config["model_kwargs"].get("temperature", 0.7),
                                    "top_p": model_config["model_kwargs"].get("top_p", 0.8),
                                    "top_k": model_config["model_kwargs"].get("top_k", 40)
                                }
                            )

                            # Get streaming response using the simplified prompt
                            fallback_response = fallback_model.generate_content(simplified_prompt, stream=True)
                            # Stream the fallback response
                            for chunk in fallback_response:
                                if hasattr(chunk, 'text'):
                                    yield chunk.text

                    except Exception as e2:
                        logger.error(f"Error in fallback streaming response: {str(e2)}")
                        yield "\nI apologize, but your request is too large for me to process. Please try a shorter query or break it into smaller parts."
                else:
                    # For other errors, return the error message
                    yield f"\nError: {error_message}"

        # Return streaming response
        return StreamingResponse(response_stream(), media_type="text/event-stream")

    except HTTPException:
        raise
    except Exception as e_handler:
        error_msg = f"Error in streaming chat completion: {str(e_handler)}"
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)

@app.get("/")
async def root():
    """Root endpoint to check if the API is running"""
    return {"status": "API is running", "message": "Navigate to /docs for API documentation"}
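# A minimal sketch for running this module directly. The port is an assumption;
# adjust it to match your deployment (requires uvicorn to be installed).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)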