Asish Karthikeya Gogineni
Fix: Removed duplicate chat method causing HTML leakage
3f9d83d
from typing import List, Tuple, Any, Optional
import logging
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.retrievers import BaseRetriever
# Simplified implementation that works with current langchain version
# We'll implement history-aware retrieval manually
from code_chatbot.retrieval.reranker import Reranker
from code_chatbot.retrieval.retriever_wrapper import build_enhanced_retriever
import os
# Configure logging
# NOTE: basicConfig() runs at import time, so merely importing this module
# installs a root handler at INFO level unless the application has already
# configured the root logger (in which case basicConfig() is a no-op).
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Gemini models fallback list (tried in order)
# ChatEngine walks this list by position when the active Gemini model has to
# be replaced, so the ORDER is significant: earlier entries are preferred.
# Keep it a list of plain model identifiers (no "models/" prefix).
GEMINI_FALLBACK_MODELS = [
"gemini-3-flash-preview",
"gemini-3-pro-preview",
"gemini-2.5-flash",
"gemini-2.5-pro",
"gemini-2.5-flash-preview-09-2025",
"gemini-2.5-flash-lite",
"gemini-2.5-flash-lite-preview-09-2025",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-pro",
]
class ChatEngine:
    """Conversational question answering over an indexed code repository.

    The engine answers a question in one of two ways:

    * **Agentic mode** - a graph built by ``create_agent_graph`` is invoked
      with a system prompt, the last few history messages and the question.
    * **Linear RAG mode** - documents are retrieved, packed into a system
      prompt and sent to the LLM directly.  This is also the fallback when the
      agent fails with a tool-calling or rate-limit error.

    Only the ``"gemini"`` and ``"groq"`` providers are supported by
    ``_get_llm``; any other value (including the constructor default
    ``"openai"``) raises ``ValueError`` during construction.
    """

    # Maximum number of messages retained in ``chat_history``.
    _MAX_HISTORY = 20
    # Maximum number of retrieved documents packed into the prompt.
    _MAX_CONTEXT_DOCS = 30
    # Substrings (case-sensitive) that mark a rate-limit / quota error.
    _RATE_LIMIT_MARKERS = ("429", "RESOURCE_EXHAUSTED", "quota")
    # Substrings (case-sensitive) that mark an agent tool-calling failure.
    _TOOL_ERROR_MARKERS = ("tool_use_failed", "invalid_request_error", "400")

    def __init__(
        self,
        retriever: "BaseRetriever",
        model_name: str = "gpt-4o",
        provider: str = "openai",
        api_key: Optional[str] = None,
        repo_name: Optional[str] = None,
        use_agent: bool = True,
        use_multi_query: bool = False,
        use_reranking: bool = True,
        repo_files: Optional[List[str]] = None,
        repo_dir: str = ".",
    ):
        """Build the LLM, the retriever stack and (optionally) the agent graph.

        Args:
            retriever: Base vector retriever to wrap.
            model_name: Preferred model identifier for the chosen provider.
            provider: ``"gemini"`` or ``"groq"``.
            api_key: Explicit API key; when omitted the environment is used.
            repo_name: Display name used in prompts (defaults to "codebase").
            use_agent: Build and use the agent graph when possible.
            use_multi_query: Enable LLM-based query expansion in the retriever.
            use_reranking: Enable reranking in the retriever.
            repo_files: File paths of the repository; enables the LLM
                retriever and the file tree in the prompt.
            repo_dir: Repository root, used for relative paths and for
                locating ``ast_graph.graphml``.

        Raises:
            ValueError: If the provider is unsupported or no API key is found.
        """
        self.base_retriever = retriever
        self.model_name = model_name
        self.provider = provider
        self.api_key = api_key
        self.repo_name = repo_name or "codebase"
        self.use_agent = use_agent
        self.use_multi_query = use_multi_query
        self.use_reranking = use_reranking
        self.repo_files = repo_files
        self.repo_dir = repo_dir

        # Position in GEMINI_FALLBACK_MODELS used by _try_next_gemini_model.
        self._gemini_model_index = 0

        self.llm = self._get_llm()
        self.chat_history = []

        # Enhanced vector retriever; the LLM is only handed over when it is
        # needed for query expansion.
        self.vector_retriever = build_enhanced_retriever(
            base_retriever=retriever,
            llm=self.llm if use_multi_query else None,
            use_multi_query=use_multi_query,
            use_reranking=use_reranking,
        )

        # Default to the vector retriever; upgrade to an ensemble with the
        # LLM retriever when the repository file list is available.
        self.llm_retriever = None
        self.retriever = self.vector_retriever
        if self.repo_files:
            try:
                from code_chatbot.retrieval.llm_retriever import LLMRetriever
                from langchain.retrievers import EnsembleRetriever

                logger.info(f"Initializing LLMRetriever with {len(self.repo_files)} files.")
                self.llm_retriever = LLMRetriever(
                    llm=self.llm,
                    repo_files=self.repo_files,
                    top_k=3,
                )
                self.retriever = EnsembleRetriever(
                    retrievers=[self.vector_retriever, self.llm_retriever],
                    weights=[0.6, 0.4],
                )
            except ImportError as e:
                logger.warning(f"Could not load EnsembleRetriever or LLMRetriever: {e}")
                self.retriever = self.vector_retriever

        # Agent graph (optional).  Any failure disables agent mode instead of
        # failing construction.
        self.agent_executor = None
        self.code_analyzer = None
        if self.use_agent:
            try:
                from code_chatbot.analysis.ast_analysis import EnhancedCodeAnalyzer

                logger.info(f"Building Agentic Workflow Graph for {self.repo_dir}...")

                # Try to load the code analyzer from a previously saved graph.
                graph_path = (
                    os.path.join(self.repo_dir, "ast_graph.graphml")
                    if self.repo_dir
                    else None
                )
                if graph_path and os.path.exists(graph_path):
                    try:
                        import networkx as nx

                        self.code_analyzer = EnhancedCodeAnalyzer()
                        self.code_analyzer.graph = nx.read_graphml(graph_path)
                        logger.info(
                            f"Loaded code analyzer with {self.code_analyzer.graph.number_of_nodes()} nodes"
                        )
                    except Exception as e:
                        logger.warning(f"Failed to load code analyzer: {e}")

                self.agent_executor = self._build_agent_executor()
            except Exception as e:
                logger.error(f"Failed to build Agent Graph: {e}")
                self.use_agent = False

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    def _resolve_api_key(self) -> Optional[str]:
        """Return the API key from the constructor argument or the environment.

        Looks at ``<PROVIDER>_API_KEY`` first and, for Gemini, additionally at
        ``GOOGLE_API_KEY``.  Returns ``None`` when nothing is configured.
        """
        key = self.api_key or os.getenv(f"{self.provider.upper()}_API_KEY")
        if not key and self.provider == "gemini":
            key = os.getenv("GOOGLE_API_KEY")
        return key

    @staticmethod
    def _make_gemini_llm(model_name: str, api_key: Optional[str]):
        """Construct a Gemini chat model with the engine's fixed settings."""
        return ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=api_key,
            temperature=0.2,
            convert_system_message_to_human=True,
        )

    def _build_agent_executor(self):
        """Create the agent graph from the engine's current LLM and retriever.

        Used both at construction time and after a model switch so that the
        two call sites always pass the same arguments.
        """
        from code_chatbot.agents.agent_workflow import create_agent_graph

        return create_agent_graph(
            self.llm, self.retriever, self.repo_name,
            self.repo_dir, self.provider, self.code_analyzer
        )

    @classmethod
    def _is_rate_limit_error(cls, error_str: str) -> bool:
        """Return True when the error text looks like a rate-limit/quota error."""
        return any(marker in error_str for marker in cls._RATE_LIMIT_MARKERS)

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content into a plain string.

        Content may be a string or a list of blocks (strings or dicts with
        ``type == "text"``).  Non-text blocks are ignored; ``None`` yields "".
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(block.get("text") or "")
            return "".join(parts)
        return str(content)

    def _trim_history(self) -> None:
        """Keep only the most recent ``_MAX_HISTORY`` messages."""
        if len(self.chat_history) > self._MAX_HISTORY:
            self.chat_history = self.chat_history[-self._MAX_HISTORY:]

    def _record_turn(self, question: str, answer: str) -> None:
        """Append a user/assistant exchange to the history and trim it."""
        self.chat_history.append(HumanMessage(content=question))
        self.chat_history.append(AIMessage(content=answer))
        self._trim_history()

    # ------------------------------------------------------------------
    # LLM construction and fallback
    # ------------------------------------------------------------------
    def _get_llm(self):
        """Initialize the LLM based on provider (only Groq and Gemini supported).

        For Gemini the preferred model (``self.model_name``, with an optional
        leading ``models/`` removed) is tried first, followed by the entries of
        ``GEMINI_FALLBACK_MODELS``.  The model is not test-called here because
        that would consume quota; the first one that constructs is returned.

        Raises:
            ValueError: Missing API key, unsupported provider, or every Gemini
                model failed to construct.
        """
        api_key = self._resolve_api_key()

        if self.provider == "gemini":
            if not api_key:
                raise ValueError("Google API Key is required for Gemini")

            # Work on a copy so the module-level list is never mutated.
            candidates = list(GEMINI_FALLBACK_MODELS)
            if self.model_name:
                preferred = self.model_name
                if preferred.startswith("models/"):
                    preferred = preferred[len("models/"):]
                if preferred in candidates:
                    candidates.remove(preferred)
                candidates.insert(0, preferred)

            last_error = None
            for model_name in candidates:
                try:
                    logger.info(f"Attempting to use Gemini model: {model_name}")
                    llm = self._make_gemini_llm(model_name, api_key)
                    logger.info(f"Initialized Gemini model: {model_name}")
                    return llm
                except Exception as e:
                    error_str = str(e).lower()
                    if "not_found" in error_str or "404" in error_str:
                        logger.warning(f"Model {model_name} not found, trying next...")
                    elif "resource_exhausted" in error_str or "429" in error_str or "quota" in error_str:
                        logger.warning(f"Model {model_name} rate limited, trying next...")
                    else:
                        logger.warning(f"Model {model_name} failed: {str(e)[:100]}")
                    last_error = e

            raise ValueError(f"All Gemini models failed. Last error: {last_error}") from last_error

        if self.provider == "groq":
            if not api_key:
                raise ValueError("Groq API Key is required")
            return ChatGroq(
                model=self.model_name or "llama-3.3-70b-versatile",
                groq_api_key=api_key,
                temperature=0.2,
            )

        raise ValueError(f"Provider {self.provider} not supported. Only 'groq' and 'gemini' are supported.")

    def _try_next_gemini_model(self) -> bool:
        """
        Try to switch to the next Gemini model in the fallback list.
        Returns True if a new model was set, False if all models exhausted.

        On success ``self.llm`` and ``self.model_name`` are replaced and, when
        agent mode is on, the agent graph is rebuilt around the new LLM.  If
        the rebuild fails the previous agent executor is kept and a warning is
        logged.
        """
        if self.provider != "gemini":
            return False

        api_key = self._resolve_api_key()

        while True:
            self._gemini_model_index += 1
            if self._gemini_model_index >= len(GEMINI_FALLBACK_MODELS):
                logger.error("All Gemini models exhausted!")
                return False

            next_model = GEMINI_FALLBACK_MODELS[self._gemini_model_index]
            logger.info(f"Switching to next Gemini model: {next_model} (index {self._gemini_model_index})")

            try:
                new_llm = self._make_gemini_llm(next_model, api_key)
            except Exception as e:
                logger.error(f"Failed to switch to model {next_model}: {e}")
                continue  # move on to the following candidate

            self.llm = new_llm
            self.model_name = next_model

            if self.use_agent:
                try:
                    self.agent_executor = self._build_agent_executor()
                except Exception as e:
                    logger.warning(f"Could not rebuild agent: {e}")
            return True

    # ------------------------------------------------------------------
    # Placeholders kept for interface compatibility
    # ------------------------------------------------------------------
    def _build_rag_chain(self):
        """Return ``None``; retrieval is performed manually by the chat methods."""
        return None

    def _contextualize_query(self, question: str, history: List) -> str:
        """Return the question unchanged.

        Placeholder for history-aware query rewriting: no rewriting is
        implemented, so ``history`` does not influence the result.
        """
        return question

    # ------------------------------------------------------------------
    # Chat entry points
    # ------------------------------------------------------------------
    def chat(self, question: str) -> Tuple[str, List[dict]]:
        """
        Ask a question to the chatbot.
        Uses Agentic Workflow if enabled, otherwise falls back to Linear RAG.

        Returns:
            ``(answer, sources)``.  Agent answers carry an empty source list.
            Unhandled errors are reported as ``("Error: ...", [])`` rather
            than raised.
        """
        try:
            if self.use_agent and self.agent_executor:
                logger.info("Executing Agentic Workflow...")
                inputs = self._build_agent_inputs(question)
                try:
                    return self._run_agent(inputs, question)
                except Exception as e:
                    error_str = str(e)

                    if self._is_rate_limit_error(error_str):
                        logger.warning(f"Rate limit hit on {self.model_name}: {error_str[:100]}")
                        if self.provider == "gemini" and self._try_next_gemini_model():
                            logger.info(f"Switched to {self.model_name}, retrying...")
                            return self.chat(question)  # retry with the new model
                        logger.warning("No more models to try, falling back to Linear RAG")
                        return self._linear_chat(question)

                    if any(marker in error_str for marker in self._TOOL_ERROR_MARKERS):
                        logger.warning(f"Agent failed ({error_str}), falling back to Linear RAG.")
                        return self._linear_chat(question)

                    raise

            return self._linear_chat(question)

        except Exception as e:
            error_str = str(e)
            if self._is_rate_limit_error(error_str):
                if self.provider == "gemini" and self._try_next_gemini_model():
                    logger.info(f"Switched to {self.model_name} after outer error, retrying...")
                    return self.chat(question)

            logger.error(f"Error during chat: {e}", exc_info=True)
            return f"Error: {str(e)}", []

    def _build_agent_inputs(self, question: str) -> dict:
        """Assemble the agent input: system prompt, recent history, question."""
        from code_chatbot.core.prompts import get_prompt_for_provider

        sys_content = get_prompt_for_provider("system_agent", self.provider).format(
            repo_name=self.repo_name
        )
        # Token optimization: only the last 4 messages (2 turns) are passed.
        recent_history = self.chat_history[-4:]
        return {
            "messages": [SystemMessage(content=sys_content)]
            + recent_history
            + [HumanMessage(content=question)]
        }

    def _run_agent(self, inputs: dict, question: str) -> Tuple[str, List[dict]]:
        """Invoke the agent graph, clean its final message and record the turn."""
        final_state = self.agent_executor.invoke(inputs, config={"recursion_limit": 20})
        raw_content = final_state["messages"][-1].content

        answer = self._content_to_text(raw_content)
        if isinstance(raw_content, list) and not answer.strip():
            # No text blocks at all: fall back to the raw representation.
            answer = str(raw_content)

        answer = self._clean_response(answer)
        self._record_turn(question, answer)
        return answer, []

    def _clean_response(self, text: str) -> str:
        """Clean response from hallucinated HTML/CSS artifacts.

        Removes ``source-chip`` / ``source-container`` ``<div>`` fragments and
        strips surrounding whitespace.  Falsy input yields an empty string.
        """
        if not text:
            return ""

        import re

        # Chip followed by a second closing div (chip nested in a wrapper).
        clean_text = re.sub(r'<div class="source-chip">.*?</div>\s*</div>', '', text, flags=re.DOTALL)
        # Standalone chips that remain.
        clean_text = re.sub(r'<div class="source-chip">.*?</div>', '', clean_text, flags=re.DOTALL)
        # Source-container divs.
        clean_text = re.sub(r'<div class="source-container">.*?</div>', '', clean_text, flags=re.DOTALL)
        return clean_text.strip()

    def _linear_chat(self, question: str) -> Tuple[str, List[dict]]:
        """Linear RAG fallback.

        Retrieves context once, then calls the LLM.  On a rate-limit error
        with Gemini the next fallback model is selected and only the LLM call
        is repeated.  Other LLM errors are returned as
        ``("Error consuming LLM: ...", [])``.
        """
        messages, sources, _ = self._prepare_chat_context(question)

        if not messages:
            return "I don't have any information about this codebase. Please make sure the codebase has been indexed properly.", []

        while True:
            try:
                response_msg = self.llm.invoke(messages)
                answer = self._clean_response(self._content_to_text(response_msg.content))
                break
            except Exception as e:
                error_str = str(e)
                if self._is_rate_limit_error(error_str):
                    if self.provider == "gemini" and self._try_next_gemini_model():
                        logger.info(f"Linear RAG: Switched to {self.model_name} due to rate limit, retrying...")
                        continue  # retry with the new model

                logger.error(f"Error in linear chat invoke: {e}")
                return f"Error consuming LLM: {e}", []

        self._record_turn(question, answer)
        return answer, sources

    def _generate_file_tree_str(self) -> str:
        """Generate a string representation of the file tree.

        Returns a header line followed by the sorted, de-duplicated file
        paths.  Paths located under ``repo_dir`` are made relative to it;
        others are listed unchanged.  Returns "" when no files are known.
        """
        if not self.repo_files:
            return ""

        root = self.repo_dir
        prefixes = ()
        if root:
            # Match on a directory boundary so "repo2/x" is not treated as
            # living under "repo".
            trimmed = root.rstrip("/" + os.sep)
            prefixes = tuple({trimmed + "/", trimmed + os.sep})

        def to_relative(path: str) -> str:
            if root and (path == root or path.startswith(prefixes)):
                return os.path.relpath(path, root)
            return path

        paths = {to_relative(f) for f in self.repo_files}
        return "Project Structure (File Tree):\n" + "\n".join(sorted(paths))

    def _prepare_chat_context(self, question: str):
        """Prepare messages and sources for chat/stream.

        Returns:
            ``(messages, sources, context_text)``.  When retrieval returns no
            documents the result is ``(None, [], "")``.
        """
        # Very short follow-ups ("why?") are enriched with the last message
        # in the history to give the retriever something to match on.
        query_for_retrieval = question
        if len(question) < 5 and self.chat_history:
            query_for_retrieval = f"{self.chat_history[-1].content} {question}"

        docs = self.retriever.invoke(query_for_retrieval)
        if not docs:
            return None, [], ""

        # Full (untruncated) document content goes into the prompt.
        context_parts = []
        sources = []
        for doc in docs[: self._MAX_CONTEXT_DOCS]:
            file_path = doc.metadata.get("file_path") or doc.metadata.get("source", "unknown")
            context_parts.append(f"File: {file_path}\nContent:\n{doc.page_content}\n---")
            sources.append({
                "file_path": file_path,
                "url": doc.metadata.get("url", f"file://{file_path}"),
            })

        context_text = "\n\n".join(context_parts)

        file_tree = self._generate_file_tree_str()
        full_context = f"{file_tree}\n\nRETRIEVED CONTEXT:\n{context_text}"

        # Provider-specific system prompt.
        from code_chatbot.core.prompts import get_prompt_for_provider

        base_prompt = get_prompt_for_provider("linear_rag", self.provider)
        qa_system_prompt = base_prompt.format(
            repo_name=self.repo_name,
            context=full_context,
        )

        # System prompt, then the last 10 history messages, then the question.
        messages = [SystemMessage(content=qa_system_prompt)]
        messages.extend(self.chat_history[-10:])
        messages.append(HumanMessage(content=question))

        return messages, sources, context_text

    def stream_chat(self, question: str):
        """Streaming chat method returning (generator, sources).

        The user message is added to the history immediately.  The assistant
        message is added only once the returned generator has been fully
        consumed.
        """
        messages, sources, _ = self._prepare_chat_context(question)

        if not messages:
            def empty_gen():
                yield "I don't have any information about this codebase."
            return empty_gen(), []

        self.chat_history.append(HumanMessage(content=question))
        self._trim_history()

        def response_generator():
            pieces = []
            for chunk in self.llm.stream(messages):
                text = self._content_to_text(chunk.content)
                pieces.append(text)
                yield text

            # Store the cleaned full response once streaming has finished.
            clean_full_response = self._clean_response("".join(pieces))
            self.chat_history.append(AIMessage(content=clean_full_response))
            self._trim_history()

        return response_generator(), sources

    def clear_memory(self):
        """Clear the conversation history."""
        self.chat_history.clear()