Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from typing import List, Optional | |
| from urllib.parse import unquote | |
| import google.generativeai as genai | |
| from adalflow.components.model_client.ollama_client import OllamaClient | |
| from adalflow.core.types import ModelType | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from api.config import get_model_config, configs, OPENROUTER_API_KEY, OPENAI_API_KEY, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY | |
| from api.data_pipeline import count_tokens, get_file_content | |
| from api.openai_client import OpenAIClient | |
| from api.openrouter_client import OpenRouterClient | |
| from api.bedrock_client import BedrockClient | |
| from api.azureai_client import AzureAIClient | |
| from api.rag import RAG | |
| from api.prompts import ( | |
| DEEP_RESEARCH_FIRST_ITERATION_PROMPT, | |
| DEEP_RESEARCH_FINAL_ITERATION_PROMPT, | |
| DEEP_RESEARCH_INTERMEDIATE_ITERATION_PROMPT, | |
| SIMPLE_CHAT_SYSTEM_PROMPT | |
| ) | |
# Configure logging
# setup_logging() runs at import time, before the module-level logger is created.
from api.logging_config import setup_logging
setup_logging()
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Simple Chat API",
    description="Simplified API for streaming chat completions"
)

# Configure CORS
# NOTE(review): every origin, method and header is allowed while credentials
# are also enabled; confirm this is intentional before exposing the API
# outside a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # Allows all origins
    allow_credentials=True,
    allow_methods=["*"], # Allows all methods
    allow_headers=["*"], # Allows all headers
)
| # Models for the API | |
class ChatMessage(BaseModel):
    """A single message of the conversation sent with a chat completion request."""

    role: str # 'user' or 'assistant'
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    """
    Model for requesting a chat completion.

    The four file-filter fields (``excluded_dirs``, ``excluded_files``,
    ``included_dirs``, ``included_files``) are consumed by the streaming chat
    handler in this module, which splits each value on newlines and
    URL-decodes every entry. Their descriptions state that format so the
    generated API documentation matches what the server actually parses.
    """

    repo_url: str = Field(..., description="URL of the repository to query")
    messages: List[ChatMessage] = Field(..., description="List of chat messages")
    filePath: Optional[str] = Field(None, description="Optional path to a file in the repository to include in the prompt")
    token: Optional[str] = Field(None, description="Personal access token for private repositories")
    type: Optional[str] = Field("github", description="Type of repository (e.g., 'github', 'gitlab', 'bitbucket')")

    # model parameters
    provider: str = Field("google", description="Model provider (google, openai, openrouter, ollama, bedrock, azure)")
    model: Optional[str] = Field(None, description="Model name for the specified provider")
    language: Optional[str] = Field("en", description="Language for content generation (e.g., 'en', 'ja', 'zh', 'es', 'kr', 'vi')")

    # File filters: one entry per line, each entry optionally URL-encoded.
    excluded_dirs: Optional[str] = Field(None, description="Newline-separated list of directories to exclude from processing (entries may be URL-encoded)")
    excluded_files: Optional[str] = Field(None, description="Newline-separated list of file patterns to exclude from processing (entries may be URL-encoded)")
    included_dirs: Optional[str] = Field(None, description="Newline-separated list of directories to include exclusively (entries may be URL-encoded)")
    included_files: Optional[str] = Field(None, description="Newline-separated list of file patterns to include exclusively (entries may be URL-encoded)")
# Token count of the last message above which RAG retrieval is skipped.
_MAX_INPUT_TOKENS = 8000

# Substrings that identify a "prompt too long" failure reported by a provider.
_TOKEN_LIMIT_MARKERS = ("maximum context length", "token limit", "too many tokens")

# provider -> (label used in log lines, label used in streamed messages, remediation hint)
_PROVIDER_ERROR_INFO = {
    "openrouter": (
        "OpenRouter",
        "OpenRouter",
        "Please check that you have set the OPENROUTER_API_KEY environment variable with a valid API key.",
    ),
    "openai": (
        "OpenAI",
        "Openai",
        "Please check that you have set the OPENAI_API_KEY environment variable with a valid API key.",
    ),
    "bedrock": (
        "AWS Bedrock",
        "AWS Bedrock",
        "Please check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials.",
    ),
    "azure": (
        "Azure AI",
        "Azure AI",
        "Please check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values.",
    ),
}


def _split_filter(raw):
    """Split a newline-separated filter string into URL-decoded entries (None when unset)."""
    if not raw:
        return None
    return [unquote(entry) for entry in raw.split('\n') if entry.strip()]


def _format_context(documents):
    """Group retrieved documents by file path and render them as one context string."""
    docs_by_file = {}
    for doc in documents:
        docs_by_file.setdefault(doc.meta_data.get('file_path', 'unknown'), []).append(doc)

    context_parts = []
    for file_path, docs in docs_by_file.items():
        header = f"## File Path: {file_path}\n\n"
        content = "\n\n".join([doc.text for doc in docs])
        context_parts.append(f"{header}{content}")

    # The whole separator is built first: without the explicit grouping,
    # str.join binds tighter than "+" and the dash rule would only be
    # emitted once, in front of the first file.
    separator = "\n\n" + "-" * 10 + "\n\n"
    return separator.join(context_parts)


def _retrieve_context(rag, query, file_path, language):
    """Run RAG retrieval and return the formatted context, or "" when nothing usable comes back."""
    # When a file is targeted, retrieval is steered towards that file instead
    # of the literal user query.
    rag_query = f"Contexts related to {file_path}" if file_path else query
    try:
        retrieved_documents = rag(rag_query, language=language)
        if retrieved_documents and retrieved_documents[0].documents:
            return _format_context(retrieved_documents[0].documents)
        print("Warning: No documents retrieved from RAG")
    except Exception as e:
        # Retrieval is best-effort: the request continues without context.
        print(f"Error in RAG retrieval: {str(e)}")
    return ""


def _select_system_prompt(is_deep_research, research_iteration, repo_type, repo_url, repo_name, language_name):
    """Pick and fill the system prompt template matching the conversation mode and iteration."""
    fields = {
        "repo_type": repo_type,
        "repo_url": repo_url,
        "repo_name": repo_name,
        "language_name": language_name,
    }
    if not is_deep_research:
        return SIMPLE_CHAT_SYSTEM_PROMPT.format(**fields)
    if research_iteration == 1:
        return DEEP_RESEARCH_FIRST_ITERATION_PROMPT.format(**fields)
    if research_iteration >= 5:
        template = DEEP_RESEARCH_FINAL_ITERATION_PROMPT
    else:
        template = DEEP_RESEARCH_INTERMEDIATE_ITERATION_PROMPT
    return template.format(research_iteration=research_iteration, **fields)


def _build_prompt(system_prompt, conversation_history, file_path, file_content, query, provider, context_text, empty_context_note):
    """Assemble the full prompt text; empty_context_note is used when context_text is blank."""
    CONTEXT_START = "<START_OF_CONTEXT>"
    CONTEXT_END = "<END_OF_CONTEXT>"

    parts = [f"/no_think {system_prompt}\n\n"]
    if conversation_history:
        parts.append(f"<conversation_history>\n{conversation_history}</conversation_history>\n\n")
    if file_content:
        parts.append(f"<currentFileContent path=\"{file_path}\">\n{file_content}\n</currentFileContent>\n\n")
    if context_text.strip():
        parts.append(f"{CONTEXT_START}\n{context_text}\n{CONTEXT_END}\n\n")
    else:
        parts.append(f"<note>{empty_context_note}</note>\n\n")
    parts.append(f"<query>\n{query}\n</query>\n\nAssistant: ")
    if provider == "ollama":
        parts.append(" /no_think")
    return "".join(parts)


def _create_model(request, prompt, model_config):
    """Instantiate the provider client and return (model, model_kwargs, api_kwargs).

    For the default (Google) branch there is no client-style call, so
    model_kwargs and api_kwargs are returned as None.
    """
    provider = request.provider
    if provider == "ollama":
        model = OllamaClient()
        model_kwargs = {
            "model": model_config["model"],
            "stream": True,
            "options": {
                "temperature": model_config["temperature"],
                "top_p": model_config["top_p"],
                "num_ctx": model_config["num_ctx"]
            }
        }
    elif provider == "openrouter":
        if not OPENROUTER_API_KEY:
            # The client is still created; the failure surfaces when the call is made.
            print("Warning: OPENROUTER_API_KEY not configured")
        model = OpenRouterClient()
        model_kwargs = {
            "model": request.model,
            "stream": True,
            "temperature": model_config["temperature"]
        }
        if "top_p" in model_config:
            model_kwargs["top_p"] = model_config["top_p"]
    elif provider == "openai":
        if not OPENAI_API_KEY:
            print("Warning: OPENAI_API_KEY not configured")
        model = OpenAIClient()
        model_kwargs = {
            "model": request.model,
            "stream": True,
            "temperature": model_config["temperature"]
        }
        if "top_p" in model_config:
            model_kwargs["top_p"] = model_config["top_p"]
    elif provider == "bedrock":
        if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
            print("Warning: AWS credentials not configured")
        model = BedrockClient()
        model_kwargs = {
            "model": request.model,
            "temperature": model_config["temperature"],
            "top_p": model_config["top_p"]
        }
    elif provider == "azure":
        model = AzureAIClient()
        model_kwargs = {
            "model": request.model,
            "stream": True,
            "temperature": model_config["temperature"],
            "top_p": model_config["top_p"]
        }
    else:
        model = genai.GenerativeModel(
            model_name=model_config["model"],
            generation_config={
                "temperature": model_config["temperature"],
                "top_p": model_config["top_p"],
                "top_k": model_config["top_k"]
            }
        )
        return model, None, None

    api_kwargs = model.convert_inputs_to_api_kwargs(
        input=prompt,
        model_kwargs=model_kwargs,
        model_type=ModelType.LLM
    )
    return model, model_kwargs, api_kwargs


def _ollama_text(chunk):
    """Extract displayable text from an Ollama stream chunk, or None for metadata-only chunks."""
    text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
    if text and not text.startswith('model=') and not text.startswith('created_at='):
        return text.replace('<think>', '').replace('</think>', '')
    return None


def _delta_text(chunk):
    """Return choices[0].delta.content of a chat-completion stream chunk, or None."""
    choices = getattr(chunk, "choices", None) or []
    if len(choices) > 0:
        delta = getattr(choices[0], "delta", None)
        if delta is not None:
            return getattr(delta, "content", None)
    return None


async def _stream_attempt(provider, model, api_kwargs, prompt):
    """Yield response text for one provider call; any provider error propagates to the caller."""
    if provider not in ("ollama", "openrouter", "openai", "bedrock", "azure"):
        # Default (Google) branch.
        # NOTE(review): generate_content is a synchronous call made from an
        # async generator; consider the SDK's async variant if event-loop
        # blocking becomes a problem.
        response = model.generate_content(prompt, stream=True)
        for chunk in response:
            if hasattr(chunk, 'text'):
                yield chunk.text
        return

    response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)

    if provider == "bedrock":
        # Bedrock is not streamed: the complete response is emitted in one piece.
        yield response if isinstance(response, str) else str(response)
        return

    async for chunk in response:
        if provider == "ollama":
            text = _ollama_text(chunk)
        elif provider == "openrouter":
            text = chunk
        else:
            # openai and azure share the chat-completion chunk layout.
            text = _delta_text(chunk)
        if text is not None:
            yield text


async def chat_completions_stream(request: ChatCompletionRequest):
    """Stream a chat completion for a repository question using the requested provider.

    The handler validates the conversation, prepares a per-request RAG
    retriever, optionally retrieves context and file content, builds the
    prompt and returns a streaming response produced by the selected provider
    (ollama, openrouter, openai, bedrock, azure, or Google Generative AI for
    any other provider value).

    If the provider rejects the prompt with a token-limit style error, the
    call is retried once with a prompt that omits the retrieved context.
    Other provider errors are reported as text inside the stream.

    Args:
        request: The chat completion request. The content of its last message
            may be rewritten in place (deep-research tag removal and
            continuation handling).

    Returns:
        A StreamingResponse whose body is the generated text.

    Raises:
        HTTPException: 400 when no messages are supplied or the last message
            is not from the user; 500 when the retriever cannot be prepared
            or another unexpected error occurs before streaming starts.
    """
    try:
        # Validate first so malformed requests fail before the retriever is built.
        if not request.messages:
            raise HTTPException(status_code=400, detail="No messages provided")

        last_message = request.messages[-1]
        if last_message.role != "user":
            raise HTTPException(status_code=400, detail="Last message must be from the user")

        # Very large inputs skip retrieval to leave room for the query itself.
        input_too_large = False
        if last_message.content:
            tokens = count_tokens(last_message.content, request.provider == "ollama")
            if tokens > _MAX_INPUT_TOKENS:
                print(f"Warning: Request exceeds recommended token limit ({tokens} > {_MAX_INPUT_TOKENS})")
                input_too_large = True

        # Create a new RAG instance for this request
        try:
            request_rag = RAG(provider=request.provider, model=request.model)
            request_rag.prepare_retriever(
                request.repo_url,
                request.type,
                request.token,
                _split_filter(request.excluded_dirs),
                _split_filter(request.excluded_files),
                _split_filter(request.included_dirs),
                _split_filter(request.included_files),
            )
            print(f"Retriever prepared for {request.repo_url}")
        except ValueError as e:
            if "No valid documents with embeddings found" in str(e):
                print(f"Error: No valid embeddings found: {str(e)}")
                raise HTTPException(status_code=500, detail="No valid document embeddings found. This may be due to embedding size inconsistencies or API errors during document processing. Please try again or check your repository content.") from e
            print(f"Error: ValueError preparing retriever: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error preparing retriever: {str(e)}") from e
        except Exception as e:
            print(f"Error preparing retriever: {str(e)}")
            if "All embeddings should be of the same size" in str(e):
                raise HTTPException(status_code=500, detail="Inconsistent embedding sizes detected. Some documents may have failed to embed properly. Please try again.") from e
            raise HTTPException(status_code=500, detail=f"Error preparing retriever: {str(e)}") from e

        # Replay earlier (user, assistant) pairs into the RAG memory. Pairs are
        # taken at indices (0, 1), (2, 3), ...; mismatched pairs are skipped.
        for i in range(0, len(request.messages) - 1, 2):
            user_msg = request.messages[i]
            assistant_msg = request.messages[i + 1]
            if user_msg.role == "user" and assistant_msg.role == "assistant":
                request_rag.memory.add_dialog_turn(
                    user_query=user_msg.content,
                    assistant_response=assistant_msg.content
                )

        # A "[DEEP RESEARCH]" tag anywhere in the conversation enables deep
        # research; the tag itself is removed from the last message only.
        is_deep_research = any(
            msg.content and "[DEEP RESEARCH]" in msg.content
            for msg in request.messages
        )
        if last_message.content and "[DEEP RESEARCH]" in last_message.content:
            last_message.content = last_message.content.replace("[DEEP RESEARCH]", "").strip()

        research_iteration = 1
        if is_deep_research:
            research_iteration = sum(1 for msg in request.messages if msg.role == 'assistant') + 1
            print(f"Deep Research request detected - iteration {research_iteration}")

            # A "continue ... research" message carries no topic of its own,
            # so the first user message that is not a continuation is reused.
            lowered = last_message.content.lower()
            if "continue" in lowered and "research" in lowered:
                original_topic = next(
                    (
                        msg.content.replace("[DEEP RESEARCH]", "").strip()
                        for msg in request.messages
                        if msg.role == "user" and "continue" not in msg.content.lower()
                    ),
                    None,
                )
                if original_topic:
                    last_message.content = original_topic

        # Get the query from the last message
        query = last_message.content

        # Only retrieve documents if input is not too large
        context_text = ""
        if not input_too_large:
            context_text = _retrieve_context(request_rag, query, request.filePath, request.language)

        # Repository and language information for the system prompt
        repo_url = request.repo_url
        repo_name = repo_url.split("/")[-1]
        language_code = request.language or configs["lang_config"]["default"]
        supported_langs = configs["lang_config"]["supported_languages"]
        language_name = supported_langs.get(language_code, "English")

        system_prompt = _select_system_prompt(
            is_deep_research,
            research_iteration,
            repo_type=request.type,
            repo_url=repo_url,
            repo_name=repo_name,
            language_name=language_name,
        )

        # Fetch file content if provided
        file_content = ""
        if request.filePath:
            try:
                file_content = get_file_content(request.repo_url, request.filePath, request.type, request.token)
            except Exception as e:
                # Best-effort: continue without file content.
                print(f"Error retrieving file content: {str(e)}")

        # Format conversation history
        conversation_history = "".join(
            f"<turn>\n<user>{turn.user_query.query_str}</user>\n<assistant>{turn.assistant_response.response_str}</assistant>\n</turn>\n"
            for turn_id, turn in request_rag.memory().items()
            if not isinstance(turn_id, int) and hasattr(turn, 'user_query') and hasattr(turn, 'assistant_response')
        )

        provider = request.provider
        prompt_fields = {
            "system_prompt": system_prompt,
            "conversation_history": conversation_history,
            "file_path": request.filePath,
            "file_content": file_content,
            "query": query,
            "provider": provider,
        }
        prompt = _build_prompt(
            context_text=context_text,
            empty_context_note="Answering without retrieval augmentation.",
            **prompt_fields,
        )

        model_config = get_model_config(request.provider, request.model)["model_kwargs"]
        model, model_kwargs, api_kwargs = _create_model(request, prompt, model_config)

        async def response_stream():
            """Yield the provider output, retrying once without context on token-limit errors."""
            try:
                async for text in _stream_attempt(provider, model, api_kwargs, prompt):
                    yield text
            except Exception as e_outer:
                error_message = str(e_outer)
                error_info = _PROVIDER_ERROR_INFO.get(provider)
                if error_info:
                    print(f"Error with {error_info[0]} API: {error_message}")
                else:
                    print(f"Error in streaming response: {error_message}")

                if not any(marker in error_message for marker in _TOKEN_LIMIT_MARKERS):
                    # Not a size problem: report the error inside the stream.
                    if error_info:
                        yield f"\nError with {error_info[1]} API: {error_message}\n\n{error_info[2]}"
                    else:
                        yield f"\nError: {error_message}"
                    return

                print("Warning: Token limit exceeded, retrying without context")
                try:
                    simplified_prompt = _build_prompt(
                        context_text="",
                        empty_context_note="Answering without retrieval augmentation due to input size constraints.",
                        **prompt_fields,
                    )
                    fallback_api_kwargs = None
                    if model_kwargs is not None:
                        fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
                            input=simplified_prompt,
                            model_kwargs=model_kwargs,
                            model_type=ModelType.LLM
                        )
                    async for text in _stream_attempt(provider, model, fallback_api_kwargs, simplified_prompt):
                        yield text
                except Exception as e_fallback:
                    if error_info:
                        print(f"Error with {error_info[0]} API fallback: {str(e_fallback)}")
                        yield f"\nError with {error_info[1]} API fallback: {str(e_fallback)}\n\n{error_info[2]}"
                    else:
                        print(f"Error in fallback streaming response: {str(e_fallback)}")
                        yield "\nI apologize, but your request is too large for me to process. Please try a shorter query or break it into smaller parts."

        # NOTE(review): chunks are sent as raw text although the media type is
        # text/event-stream; kept as-is for existing clients.
        return StreamingResponse(response_stream(), media_type="text/event-stream")

    except HTTPException:
        raise
    except Exception as e_handler:
        error_msg = f"Error in streaming chat completion: {str(e_handler)}"
        print(f"Error: {error_msg}")
        raise HTTPException(status_code=500, detail=error_msg)
async def root():
    """Health check: report that the API is up and point callers at /docs."""
    payload = {
        "status": "API is running",
        "message": "Navigate to /docs for API documentation",
    }
    return payload