Spaces:
Sleeping
Sleeping
| import requests | |
| from langchain_core.messages import SystemMessage , HumanMessage , FunctionMessage | |
| from .state import State | |
| from .tools import RetrieverBackup | |
| from .schemas import ParameterFormatter, EndpointFormatter | |
| from .prompts import query_check_prompt, fetch_last_message_prompt , fetch_parameters_prompt, fetch_endpoint_prompt | |
| from .utils import process_query, get_endpoint_info | |
| from src.genai.utils.models_loader import llm_gpt | |
| import numpy as np | |
| from src.genai.utils.data_loader import api_knowledge_df, api_index, caption_df , caption_index | |
| from src.genai.utils.models_loader import embedding_model | |
| from ..handlers import ( | |
| compare, | |
| get_posting_time, | |
| get_peak_comment_hour, | |
| get_emoji_count, | |
| get_comment_quality, | |
| get_bot_and_diversity, | |
| ) | |
class FetchLastMessage:
    """Ask the LLM to derive the latest user request from the conversation.

    The system prompt ``fetch_last_message_prompt`` is followed by the whole
    ``state['messages']`` history. The reply is cleaned with ``process_query``
    and stored under ``latest_message``. (The prompt text is not visible here.
    Presumably it asks for a standalone restatement of the last user turn —
    confirm against ``.prompts``.)
    """

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        """Return ``{'latest_message': <cleaned LLM reply>}``."""
        print('Message:', state['messages'])
        prompt = [SystemMessage(content=fetch_last_message_prompt)] + state['messages']
        result = self.llm.invoke(prompt)
        # Clean the reply once, so the logged value is exactly the value returned.
        latest = process_query(result.content)
        print('Latest Message:', latest)
        # NOTE(review): this only rebinds the key on the dict passed in. It is not
        # part of the returned update. If this runs as a LangGraph node (as the
        # State/messages usage suggests), the trim probably never reaches the
        # stored history — confirm. It also runs after the LLM call, so the
        # prompt above always carries the full history.
        if len(state['messages']) > 11:
            state["messages"] = state["messages"][-9:]
        return {'latest_message': latest}
class RetrievePossibleEndpoints:
    """Shortlist candidate API endpoints by vector similarity to the latest message.

    The steps are:

    1. Embed ``state['latest_message']``.
    2. Look up the ``top_k`` nearest entries in the API knowledge index.
    3. Return the ``endpoint`` column of the matching knowledge-base rows as
       ``possible_endpoints``.
    """

    def __init__(self, top_k: int = 10):
        self.df = api_knowledge_df
        self.index = api_index
        self.embedder = embedding_model
        # Number of nearest neighbours requested from the index per query.
        self.top_k = top_k
        # Endpoints found by the most recent run. Each run replaces this list
        # (no accumulation across queries).
        self.results = []

    def run(self, state: "State"):
        """Return ``{'possible_endpoints': [endpoint, ...]}`` for the latest message."""
        print('Gone to retrieve possible endpoints')
        vector = self.embedder.embed_query(state['latest_message'])
        query_embedding = np.array(vector).reshape(1, -1).astype('float32')
        _, indices = self.index.search(query_embedding, self.top_k)
        results = []
        for idx in indices[0]:
            # If api_index is a FAISS index (its search(x, k) -> (D, I) shape
            # matches), missing neighbours are reported as -1. iloc[-1] would
            # silently return the last row, so skip them.
            if idx < 0:
                continue
            row = self.df.iloc[int(idx)]
            print('Endpoint:', row['endpoint'])
            results.append(row['endpoint'])
        self.results = results
        print('The possible endpoints are:', results)
        return {
            "possible_endpoints": results,
        }
class RetrieveExactEndpoint:
    """Have the LLM pick one endpoint from the shortlist, then look up its metadata.

    The LLM answers through structured output (``EndpointFormatter``). The chosen
    endpoint is passed to ``get_endpoint_info``, whose ``'method'`` and
    ``'parameters'`` entries are copied into the state update.
    """

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        """Return the chosen endpoint, its HTTP method, and its needed parameters."""
        print('Gone to retrieve exact endpoint')
        candidates = state['possible_endpoints']
        query = state['latest_message']
        conversation = [
            SystemMessage(content=fetch_endpoint_prompt),
            FunctionMessage(
                name='possible_endpoints',
                content=f"The possible endpoints are: {candidates}",
            ),
            HumanMessage(content=f"The user query is: {query}"),
        ]
        chooser = self.llm.with_structured_output(EndpointFormatter, method='function_calling')
        chosen = chooser.invoke(conversation).endpoint
        print('The exact endpoint is:', chosen)
        info = get_endpoint_info(chosen)
        print('The endpoint info is:', info)
        assistant_note = {"role": "assistant", "content": f"The endpoint is: {chosen}"}
        return {
            "messages": [assistant_note],
            "endpoint": chosen,
            "method": info['method'],
            "needed_parameters": info["parameters"],
        }
class QueryCheckNode:
    """Classify the latest user message with ``query_check_prompt``.

    The raw LLM reply text is stored as ``query_type``. Any exception is reported
    as ``error_message`` instead of being raised.
    """

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        """Return ``{'query_type': ...}`` on success, ``{'error_message': ...}`` on failure."""
        try:
            print('Entered to query checking')
            conversation = [
                SystemMessage(content=query_check_prompt),
                HumanMessage(content=f"The user query is: {state['latest_message']}"),
            ]
            label = self.llm.invoke(conversation).content
            print(label)
        except Exception as err:
            print('Error occoured:', err)
            return {'error_message': str(err)}
        return {'query_type': label}
class FetchParametersNode:
    """Extract request parameter values for the selected endpoint using the LLM.

    Endpoints in ``complex_endpoints`` are skipped: ``run`` returns ``None`` for
    them (no update). For every other endpoint, the LLM fills
    ``ParameterFormatter`` from the query and ``state['needed_parameters']``.
    String values are then cleaned with ``process_query``.
    """

    def __init__(self):
        self.llm = llm_gpt
        # These paths also have dedicated handlers in FetchDataNode.endpoint_handlers.
        # NOTE(review): '/api/v1/audience/bot-and-diversity' has a handler there
        # but is not listed here, so it still goes through LLM parameter
        # extraction — confirm this is intentional.
        self.complex_endpoints = [
            '/api/v1/compare/',
            '/api/v1/engagement/posting-time-analysis',
            '/api/v1/audience/peak-comment-hour',
            '/api/v1/audience/emoji-count',
            '/api/v1/audience/comment-quality',
        ]

    def run(self, state: State):
        """Return ``{'parameters_values': ...}``, ``None`` for complex endpoints, or ``{'error_message': ...}``."""
        try:
            print('Entered to fetch parameters')
            if state['endpoint'] in self.complex_endpoints:
                return None
            conversation = [
                SystemMessage(content=fetch_parameters_prompt),
                HumanMessage(
                    content=(
                        f"The query is: {state['latest_message']}\n. "
                        f"The needed parameters: {str(state['needed_parameters'])}"
                    )
                ),
            ]
            extractor = self.llm.with_structured_output(ParameterFormatter, method='function_calling')
            extracted = extractor.invoke(conversation)
            cleaned = {}
            for name, value in extracted.parameters_values.items():
                cleaned[name] = process_query(value) if isinstance(value, str) else value
            print('The parameter values:', cleaned)
            return {'parameters_values': cleaned}
        except Exception as err:
            print('Error occoured:', err)
            return {'error_message': str(err)}
class FetchDataNode:
    """Call the selected API endpoint and return its JSON body as ``response``.

    Endpoints that appear in ``endpoint_handlers`` are delegated to their
    handler, called as ``handler(state, llm, url)``. Every other endpoint gets a
    plain GET, using ``state['parameters_values']`` as query parameters. Any
    exception becomes ``{'error_message': ..., 'response': 'No response'}``.
    """

    def __init__(self, timeout: float = 30.0):
        self.llm = llm_gpt
        self.base_url = 'https://reveltrends.vercel.app'
        # Seconds to wait for the plain GET in run(). Without a timeout, requests
        # can block indefinitely on an unresponsive server.
        self.timeout = timeout
        self.headers = {
            # NOTE(review): placeholder token, sent as-is on every plain GET.
            # Replace with a real key (or drop the header) if the API needs it.
            "Authorization": "Bearer YOUR_API_KEY",
            "Content-Type": "application/json",
        }
        self.endpoint_handlers = {
            '/api/v1/compare/': compare,
            '/api/v1/engagement/posting-time-analysis': get_posting_time,
            '/api/v1/audience/peak-comment-hour': get_peak_comment_hour,
            '/api/v1/audience/emoji-count': get_emoji_count,
            '/api/v1/audience/comment-quality': get_comment_quality,
            '/api/v1/audience/bot-and-diversity': get_bot_and_diversity,
        }

    def run(self, state: "State"):
        """Return ``{'response': <json>}`` or an error update on any failure."""
        try:
            # NOTE(review): forces the single-influencer path. It overwrites,
            # in the passed-in dict, whatever QueryCheckNode stored as query_type.
            # The aggregate branch that would have read it has been removed.
            state['query_type'] = 'single_influencer_query'
            print('Entered to fetch data')
            endpoint = state['endpoint']
            url = f"{self.base_url}{endpoint}"
            handler = self.endpoint_handlers.get(endpoint)
            if handler is not None:
                print('Entered to handler.')
                # Handlers return an object with .json() — presumably a
                # requests.Response; confirm in ..handlers.
                response = handler(state, self.llm, url)
                print('Returned by handler.')
            else:
                response = requests.get(
                    url,
                    params=state['parameters_values'],
                    headers=self.headers,
                    timeout=self.timeout,
                )
                print('Data from api:', response)
            return {'response': response.json()}
        except Exception as e:
            print('Error occoured:', e)
            return {'error_message': str(e), 'response': 'No response'}
class BackupRetrievalNode:
    """Fallback retrieval: query ``RetrieverBackup`` with the latest user message."""

    def __init__(self):
        # Not read by run().
        self.llm = llm_gpt

    def run(self, state: State):
        """Return ``{'backup_data': <retriever result>}``."""
        # A fresh RetrieverBackup is constructed on every call.
        retriever = RetrieverBackup()
        backup = retriever.retrieve(state['latest_message'])
        return {'backup_data': backup}
class BackupRoutingNode:
    """Routing step: choose the backup branch when an earlier node recorded an error."""

    def __init__(self):
        pass

    def run(self, state: State):
        """Return ``'execute_backup'`` if ``error_message`` is set (not None), else ``'go_on'``."""
        failed = state.get('error_message') is not None
        return 'execute_backup' if failed else 'go_on'