"""Graph node implementations for the influencer-analytics query agent.

Every class exposes a ``run(state)`` method. All nodes except
``BackupRoutingNode`` return a dict of state fields to update (presumably
merged into ``State`` by the graph runner -- confirm against the graph
definition); ``BackupRoutingNode.run`` returns a routing label instead.
"""

import numpy as np
import requests
from langchain_core.messages import SystemMessage, HumanMessage, FunctionMessage

from .state import State
from .tools import RetrieverBackup
from .schemas import ParameterFormatter, EndpointFormatter
from .prompts import (
    query_check_prompt,
    fetch_last_message_prompt,
    fetch_parameters_prompt,
    fetch_endpoint_prompt,
)
from .utils import process_query, get_endpoint_info
from src.genai.utils.models_loader import llm_gpt
from src.genai.utils.data_loader import api_knowledge_df, api_index, caption_df, caption_index
from src.genai.utils.models_loader import embedding_model
from ..handlers.compare import compare
from ..handlers.posting_time import get_posting_time
from ..handlers.peak_comment_hour import get_peak_comment_hour
from ..handlers.emoji_count import get_emoji_count
from ..handlers.comment_quality import get_comment_quality


class FetchLastMessage:
    """Derive ``latest_message`` from the chat history using the LLM."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> dict:
        """Invoke the LLM on ``fetch_last_message_prompt`` plus ``state['messages']``.

        Returns:
            ``{'latest_message': ...}`` -- the LLM reply passed through
            ``process_query``.
        """
        print('Message:', state['messages'])
        messages = [SystemMessage(content=fetch_last_message_prompt)] + state['messages']
        result = self.llm.invoke(messages)
        latest_message = process_query(result.content)
        print('Latest Message:', latest_message)
        # NOTE(review): this rebinds the key on the ``state`` object passed in
        # and is not part of the returned update, so whether the trimmed
        # history persists depends on the caller -- confirm against how the
        # graph handles 'messages' (a reducer may need explicit removals).
        if len(state['messages']) > 11:
            state["messages"] = state["messages"][-9:]
        return {'latest_message': latest_message}


class RetrievePossibleEndpoints:
    """Supply the candidate endpoints the LLM may choose from.

    Vector-search retrieval over ``api_knowledge_df`` / ``api_index`` is
    currently disabled; a fixed endpoint list is returned instead.
    """

    # Candidate endpoints offered while vector retrieval is disabled.
    DEFAULT_ENDPOINTS = (
        '/api/v1/compare/',
        '/api/v1/engagement/basic-metrics',
        '/api/v1/content/hashtags-analysis',
        '/api/v1/audience/emoji-count',
        '/api/v1/engagement/temporal_analysis',
    )

    def __init__(self):
        self.df = api_knowledge_df
        self.index = api_index
        self.results = list(self.DEFAULT_ENDPOINTS)

    def run(self, state: State) -> dict:
        """Return ``{'possible_endpoints': [...]}`` for the current query."""
        print('Gone to retrieve possible endpoints')
        # Disabled retrieval path. If re-enabled, collect into a per-call list
        # (not ``self.results``) so endpoints do not accumulate across calls:
        # query_embedding = np.array(
        #     embedding_model.embed_query(state['latest_message'])
        # ).reshape(1, -1).astype('float32')
        # distances, indices = self.index.search(query_embedding, 5)
        # results = []
        # for idx in indices[0]:
        #     row = self.df.iloc[idx]
        #     print('Endpoint:', row['endpoint'])
        #     results.append(row['endpoint'])
        print('The possible endpoints are:', self.results)
        # Hand out a copy so downstream code cannot mutate this node's list.
        return {"possible_endpoints": list(self.results)}


class RetrieveExactEndpoint:
    """Pick one endpoint from ``possible_endpoints`` with a structured LLM call."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> dict:
        """Choose the endpoint and look up its HTTP method and parameters.

        Returns:
            A dict with an assistant message naming the endpoint under
            'messages', plus 'endpoint', 'method' and 'needed_parameters'
            (the latter two taken from ``get_endpoint_info``).
        """
        print('Gone to retrieve exact endpoint')
        messages = [
            SystemMessage(content=fetch_endpoint_prompt),
            FunctionMessage(
                name='possible_endpoints',
                content=f"The possible endpoints are: {state['possible_endpoints']}",
            ),
            HumanMessage(content=f"The user query is: {state['latest_message']}"),
        ]
        result = self.llm.with_structured_output(
            EndpointFormatter, method='function_calling'
        ).invoke(messages)
        print('The exact endpoint is:', result.endpoint)
        endpoint_info = get_endpoint_info(result.endpoint)
        print('The endpoint info is:', endpoint_info)
        return {
            "messages": [{"role": "assistant", "content": f"The endpoint is: {result.endpoint}"}],
            "endpoint": result.endpoint,
            "method": endpoint_info['method'],
            "needed_parameters": endpoint_info["parameters"],
        }


class QueryCheckNode:
    """Classify the latest query with ``query_check_prompt``."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> dict:
        """Return ``{'query_type': <raw LLM reply>}``, or ``{'error_message': ...}`` on failure."""
        try:
            print('Entered to query checking')
            messages = [
                SystemMessage(content=query_check_prompt),
                HumanMessage(content=f"The user query is: {state['latest_message']}"),
            ]
            result = self.llm.invoke(messages)
            print(result.content)
            return {'query_type': result.content}
        except Exception as e:
            # Reported via 'error_message' so BackupRoutingNode can divert to the backup path.
            print('Error occoured:', e)
            return {'error_message': str(e)}


class FetchParametersNode:
    """Extract API query parameters for the chosen endpoint from the user query."""

    # Influencer usernames fanned out over for aggregate queries by default.
    DEFAULT_AGGREGATE_INFLUENCERS = (
        'divyadhakal_',
        'munachiya',
        'mydarlingfood',
        '_its.me.muskan_',
    )

    def __init__(self, aggregate_influencers=None):
        """
        Args:
            aggregate_influencers: usernames to query for ``aggregate_query``
                requests; defaults to ``DEFAULT_AGGREGATE_INFLUENCERS``.
        """
        self.llm = llm_gpt
        # Endpoints served by dedicated handlers in FetchDataNode; those
        # handlers receive the whole state plus the LLM, so no parameter
        # extraction happens here for them.
        self.complex_endpoints = [
            '/api/v1/compare/',
            '/api/v1/engagement/posting-time-analysis',
            '/api/v1/audience/peak-comment-hour',
            '/api/v1/audience/emoji-count',
            '/api/v1/audience/comment-quality',
        ]
        if aggregate_influencers is None:
            aggregate_influencers = self.DEFAULT_AGGREGATE_INFLUENCERS
        self.aggregate_influencers = list(aggregate_influencers)

    def run(self, state: State) -> dict:
        """Return ``{'parameters_values': {...}}`` for the current query.

        Returns an empty dict (no update) for complex endpoints or an
        unrecognised ``query_type``, and ``{'error_message': ...}`` on failure.
        """
        try:
            print('Entered to fetch parameters')
            if state['endpoint'] in self.complex_endpoints:
                return {}
            content = (
                f"The query is: {state['latest_message']}\n.\n"
                f"The needed parameters: {state['needed_parameters']!s}"
            )
            messages = [
                SystemMessage(content=fetch_parameters_prompt),
                HumanMessage(content=content),
            ]
            result = self.llm.with_structured_output(
                ParameterFormatter, method='function_calling'
            ).invoke(messages)
            # Normalise string values; leave non-strings untouched.
            parameters_values = {
                k: (process_query(v) if isinstance(v, str) else v)
                for k, v in result.parameters_values.items()
            }
            query_type = state['query_type']
            if 'single_influencer_query' in query_type:
                print('The parameter values:', parameters_values)
                return {'parameters_values': parameters_values}
            if 'aggregate_query' in query_type:
                # Fresh list per call so state never aliases this node's list.
                parameters_values['influencer_username'] = list(self.aggregate_influencers)
                print('The parameter values:', parameters_values)
                return {'parameters_values': parameters_values}
            # Unrecognised query type: nothing to update.
            return {}
        except Exception as e:
            print('Error occoured:', e)
            return {'error_message': str(e)}


class FetchDataNode:
    """Call the analytics API for the selected endpoint and return its JSON."""

    def __init__(self, timeout=60):
        """
        Args:
            timeout: seconds passed to ``requests.get`` for direct API calls;
                ``None`` waits indefinitely.
        """
        self.llm = llm_gpt
        self.base_url = 'https://reveltrends.vercel.app'
        self.timeout = timeout
        # NOTE(review): placeholder token -- supply a real key (e.g. from
        # configuration) if the API requires authentication.
        self.headers = {
            "Authorization": "Bearer YOUR_API_KEY",  # replace with your API key if needed
            "Content-Type": "application/json",
        }
        # Endpoints with dedicated handlers, each called as handler(state, llm, url)
        # and expected to return a response object exposing ``.json()``.
        self.handlers = {
            '/api/v1/compare/': compare,
            '/api/v1/engagement/posting-time-analysis': get_posting_time,
            '/api/v1/audience/peak-comment-hour': get_peak_comment_hour,
            '/api/v1/audience/emoji-count': get_emoji_count,
            '/api/v1/audience/comment-quality': get_comment_quality,
        }

    def run(self, state: State) -> dict:
        """Fetch data for ``state['endpoint']``.

        Returns:
            ``{'response': <json>}`` for handler-backed and single-influencer
            queries; ``{'response': {username: <json>, ...}}`` for aggregate
            queries; ``{}`` for an unrecognised query type; and
            ``{'error_message': ..., 'response': 'No response'}`` on failure.
        """
        try:
            print('Entered to fetch data')
            endpoint = state['endpoint']
            url = f'{self.base_url}{endpoint}'

            handler = self.handlers.get(endpoint)
            if handler is not None:
                response = handler(state, self.llm, url)
                return {'response': response.json()}

            query_type = state['query_type']
            if 'single_influencer_query' in query_type:
                response = requests.get(
                    url,
                    params=state['parameters_values'],
                    headers=self.headers,
                    timeout=self.timeout,
                )
                print('Data from api:', response)
                return {'response': response.json()}

            if 'aggregate_query' in query_type:
                print('Entered to aggregrated query execution')
                params = state["parameters_values"]
                print(params)
                usernames = params.get("influencer_username")
                if isinstance(usernames, list):
                    results = {}
                    # One request per influencer, keyed by username.
                    for username in usernames:
                        current_params = {**params, "influencer_username": username}
                        response = requests.get(
                            url,
                            params=current_params,
                            headers=self.headers,
                            timeout=self.timeout,
                        )
                        print('Data from api:', response)
                        results[username] = response.json()
                    return {"response": results}

            # Unrecognised query type (or aggregate query without a username list).
            return {}
        except Exception as e:
            # Reported via 'error_message' so BackupRoutingNode can divert to the backup path.
            print('Error occoured:', e)
            return {'error_message': str(e), 'response': 'No response'}


class BackupRetrievalNode:
    """Fallback: retrieve supporting data with ``RetrieverBackup`` for the latest query."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> dict:
        """Return ``{'backup_data': <RetrieverBackup().retrieve(latest_message)>}``."""
        retrieval = RetrieverBackup().retrieve(state['latest_message'])
        return {'backup_data': retrieval}


class BackupRoutingNode:
    """Route to the backup path whenever an earlier node reported an error."""

    def __init__(self):
        pass

    def run(self, state: State) -> str:
        """Return ``'execute_backup'`` if ``error_message`` is set, else ``'go_on'``."""
        if state.get('error_message') is not None:
            return 'execute_backup'
        return 'go_on'