"""Step functions for the story-generation workflow.

Each step takes a ``State``, mutates it in place and returns it;
``route_after_selection`` instead returns a bool. They are presumably wired
up as LangGraph nodes / conditional edges by other code.
"""
# NOTE(review): ast, pd, BaseModel, Field, tool and final_story_prompt appear
# unused in the code shown; they are kept in case other code relies on them.
import ast
import os

import pandas as pd
from groq import Groq
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel, Field

from utils.data_loader import load_influencer_data
from utils.models_loader import ST, llm

from .prompts import (
    brainstroming_prompt,
    final_story_prompt,
    image_captioning_prompt,
    initial_story_prompt,
    refined_story_prompt,
)
from .state import BrainstromTopicFormatter, State
from .tools import retrieve_tool


def caption_image(state: State) -> State:
    """Caption the most recent image in ``state.images`` and record it.

    ``state.image_captions`` is kept index-aligned with ``state.images``:

    * no images at all -> ``None`` is appended to both lists;
    * latest image is ``None`` -> a ``None`` caption is appended;
    * otherwise the latest image is sent to Groq's vision model and the
      returned text is appended as its caption. The image is expected to be
      base64-encoded JPEG data, as the data URL built below assumes.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable.
    Returns the same (mutated) ``state``.
    """
    if not state.images:
        state.images.append(None)
        state.image_captions.append(None)
        return state

    latest_image = state.images[-1]
    if latest_image is None:
        # This path used to fall through and implicitly return None even
        # though the function is annotated to return State.
        # NOTE(review): a None caption keeps image_captions aligned with
        # images, mirroring the empty-list case above — confirm intended.
        state.image_captions.append(None)
        return state

    print('Captioning image')
    client = Groq(api_key=os.environ.get('GROQ_API_KEY'))
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": image_captioning_prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{latest_image}",
                        },
                    },
                ],
            }
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        max_completion_tokens=50,  # caps the caption length
        temperature=1,
    )
    state.image_captions.append(chat_completion.choices[0].message.content)
    return state


def _retrieve_for_topics(topics) -> list:
    """Return one retrieval result per topic, in input order.

    Each result is a list of ``{username: story}`` dicts for the nearest
    influencer example (``k=1``) to the topic's embedding.
    """
    if not topics:
        return []
    # Load the dataset once per call rather than once per topic.
    data = load_influencer_data()
    retrievals = []
    for topic in topics:
        embedded_query = ST.encode(topic)
        _scores, retrieved_examples = data.get_nearest_examples(
            "embeddings", embedded_query, k=1
        )
        retrievals.append([
            {user: story}
            for user, story in zip(retrieved_examples['username'], retrieved_examples['story'])
        ])
    return retrievals


def retrieve(state: State) -> State:
    """Retrieve example influencer stories and append them to ``state.retrievals``.

    With no pending selection (``state.latest_preferred_topics`` empty), one
    retrieval is made per entry of ``state.topic``. Otherwise the pending
    selection is moved into ``state.preferred_topics``, retrieved per topic,
    and ``state.latest_preferred_topics`` is cleared. Exactly one list of
    results is appended to ``state.retrievals`` per call.
    """
    print('Moving to retrieval process')
    if not state.latest_preferred_topics:
        # NOTE(review): if state.topic is a plain string this iterates over
        # its characters — confirm it is a list of topics.
        retrievals = _retrieve_for_topics(state.topic)
        print('Retrieval process completed......')
    else:
        print('The preferred_topics are:', state.latest_preferred_topics)
        state.preferred_topics.append(state.latest_preferred_topics)
        retrievals = _retrieve_for_topics(state.preferred_topics[-1])
        print('Retrieval process completed for preferred_topics......')
        state.latest_preferred_topics = []
    state.retrievals.append(retrievals)
    return state


def generate_story(state: State) -> State:
    """Generate a story with a ReAct agent and append it to ``state.stories``.

    The initial prompt is used until any preferred topics exist; after that
    the refined prompt is used. The agent may call ``retrieve_tool``.
    """
    tools = [retrieve_tool]
    react_agent = create_react_agent(model=llm.bind_tools(tools), tools=tools)

    if not state.preferred_topics:
        template = initial_story_prompt(state)
    else:
        template = refined_story_prompt(state)

    # NOTE: image captions are not currently included in the messages.
    messages = [
        SystemMessage(content=template),
        HumanMessage(content=f'''The topic of the video is:\n{state.topic}\n'''),
    ]
    result = react_agent.invoke({'messages': messages})
    # The agent's final message carries the generated story.
    state.stories.append(result['messages'][-1].content)
    return state


def generate_brainstroming(state: State) -> State:
    """Ask the LLM for brainstorming topics and append them to ``state.brainstroming_topics``.

    The response is parsed into ``BrainstromTopicFormatter`` and stored as a
    plain dict (via ``model_dump``).
    """
    template = brainstroming_prompt(state)
    messages = [SystemMessage(content=template)]
    response = llm.with_structured_output(BrainstromTopicFormatter).invoke(messages)
    state.brainstroming_topics.append(response.model_dump())
    print('The brainstroming topics are:', state.brainstroming_topics)
    return state


def select_preferred_topics(state: State) -> State:
    """Ask the user on stdin which brainstormed topics to pursue.

    Lists the values of the most recent ``state.brainstroming_topics`` entry
    and reads comma-separated 1-based indices. A non-empty valid selection is
    stored in ``state.latest_preferred_topics``, which ``route_after_selection``
    checks and ``retrieve`` consumes, and ``state.carry_on`` becomes True.
    Empty, malformed or fully out-of-range input sets ``state.carry_on`` to
    False and records no selection.
    """
    print("---human_feedback---")
    if not state.brainstroming_topics:
        state.carry_on = False
        print("No brainstorming topics available. Ending process.")
        return state

    topic_values = list(state.brainstroming_topics[-1].values())
    print("Available topics:")
    for idx, topic in enumerate(topic_values, 1):
        print(f"{idx}. {topic}")

    raw_input_str = input(
        "Enter the numbers of your preferred topics (comma-separated), or press Enter to skip: "
    ).strip()
    if not raw_input_str:
        state.carry_on = False
        print("No topics selected. Ending process.")
        return state

    try:
        # Skip empty pieces so inputs such as "1, 3," are still accepted.
        preferred_indices = [int(part) for part in raw_input_str.split(",") if part.strip()]
    except ValueError:
        # The function ends here rather than re-prompting, so say so.
        state.carry_on = False
        print("Invalid input. Ending process.")
        return state

    # Keep only in-range 1-based indices; drop repeats while preserving order.
    preferred_topics = [
        topic_values[i - 1]
        for i in dict.fromkeys(preferred_indices)
        if 0 < i <= len(topic_values)
    ]
    if not preferred_topics:
        state.carry_on = False
        print("No valid topics selected. Ending process.")
        return state

    # Previously the selection was appended to preferred_topics, before this
    # validation. That recorded empty selections and left
    # latest_preferred_topics empty, so route_after_selection and retrieve
    # never saw the choice. retrieve moves it into preferred_topics itself.
    state.latest_preferred_topics = preferred_topics
    print("You selected:")
    print(preferred_topics)
    state.carry_on = True
    return state


def route_after_selection(state: State) -> bool:
    """Return True when a topic selection is pending in ``state.latest_preferred_topics``."""
    return bool(state.latest_preferred_topics)