import os
import ast

import pandas as pd
from groq import Groq
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel, Field

from .state import State, BrainstromTopicFormatter
from .tools import retrieve_tool
# Note: the 'brainstroming'/'Brainstrom' spellings are kept as-is to match
# the identifiers defined in .prompts and .state.
from .prompts import (
    image_captioning_prompt,
    initial_story_prompt,
    refined_story_prompt,
    brainstroming_prompt,
    final_story_prompt,
)
from utils.models_loader import llm, ST
from utils.data_loader import load_influencer_data
def caption_image(state: State) -> State:
    # Caption the most recent image via Groq's vision model. When no usable
    # image is present, append None placeholders so the images and
    # image_captions lists stay aligned, and always return the state.
    if len(state.images) > 0 and state.images[-1] is not None:
        print('Captioning image')
        client = Groq(api_key=os.environ.get('GROQ_API_KEY'))
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": image_captioning_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{state.images[-1]}",
                            },
                        },
                    ],
                }
            ],
            model="meta-llama/llama-4-scout-17b-16e-instruct",
        )
        response = chat_completion.choices[0].message.content
        state.image_captions.append(response)
    else:
        state.images.append(None)
        state.image_captions.append(None)
    return state
def retrieve(state: State) -> State:
    print('Moving to retrieval process')
    retrievals = []
    data = load_influencer_data()  # load the indexed dataset once, not per topic
    if len(state.latest_preferred_topics) == 0:
        # First pass: retrieve the nearest influencer story for each original topic.
        for topic in state.topic:
            embedded_query = ST.encode(topic)  # embed the topic text
            scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=1)
            # One single-entry {username: story} dict per retrieved example
            result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
            retrievals.append(result)
        print('Retrieval process completed......')
    else:
        # Refinement pass: retrieve for the topics the user just selected.
        print('The preferred_topics are:', state.latest_preferred_topics)
        state.preferred_topics.append(state.latest_preferred_topics)
        for topic in state.preferred_topics[-1]:
            embedded_query = ST.encode(topic)
            scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=1)
            result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
            retrievals.append(result)
        print('Retrieval process completed for preferred_topics......')
        state.latest_preferred_topics = []
    state.retrievals.append(retrievals)
    return state
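
# Shape sketch (an assumption inferred from the loops above, not part of the
# original file): after one retrieve() pass over two topics,
#
#   state.retrievals[-1] == [
#       [{"<username>": "<agentic_story>"}],  # nearest example for topic 1
#       [{"<username>": "<agentic_story>"}],  # nearest example for topic 2
#   ]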
def generate_story(state: State) -> State:
    tools = [retrieve_tool]
    # create_react_agent binds the tools to the model itself, so pass the bare
    # llm rather than llm.bind_tools(tools), which would bind them twice.
    react_agent = create_react_agent(
        model=llm,
        tools=tools,
    )
    # Use the initial prompt on the first pass, the refinement prompt once the
    # user has selected preferred topics.
    if len(state.preferred_topics) == 0:
        template = initial_story_prompt(state)
    else:
        template = refined_story_prompt(state)
    messages = [SystemMessage(content=template)]
    response = react_agent.invoke({'messages': messages})
    response = response['messages'][-1].content
    state.stories.append(response)
    return state
def generate_brainstroming(state: State) -> State:
    template = brainstroming_prompt(state)
    messages = [SystemMessage(content=template)]
    # Force the LLM output into the BrainstromTopicFormatter schema.
    response = llm.with_structured_output(BrainstromTopicFormatter).invoke(messages)
    response = response.model_dump()
    state.brainstroming_topics.append(response)
    print('The brainstroming topics are:', state.brainstroming_topics)
    return state
def select_preferred_topics(state: State) -> State:
    print("---human_feedback---")
    topic_values = list(state.brainstroming_topics[-1].values())
    print("Available topics:")
    for idx, topic in enumerate(topic_values, 1):
        print(f"{idx}. {topic}")
    raw_input_str = input("Enter the numbers of your preferred topics (comma-separated), or press Enter to skip: ").strip()
    if not raw_input_str:
        state.carry_on = False
        print("No topics selected. Ending process.")
        return state
    try:
        preferred_indices = [int(i.strip()) for i in raw_input_str.split(",")]
        # Keep only in-range selections (1-based indices).
        preferred_topics = [topic_values[i - 1] for i in preferred_indices if 0 < i <= len(topic_values)]
        state.preferred_topics.append(preferred_topics)
    except ValueError:
        state.carry_on = False
        print("Invalid input. Ending process.")
        return state
    if not preferred_topics:
        state.carry_on = False
        print("No valid topics selected. Ending process.")
        return state
    print("You selected:")
    print(preferred_topics)
    state.carry_on = True
    return state
def route_after_selection(state: State):
    # Route back to retrieval only when new preferred topics are pending.
    return len(state.latest_preferred_topics) > 0
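
# Minimal wiring sketch (an assumption, not part of the original file: node
# names, edge order, and the loop back to "retrieve" are illustrative only).
# route_after_selection returns a bool, so it maps onto conditional edges:
#
#   from langgraph.graph import StateGraph, START, END
#
#   builder = StateGraph(State)
#   builder.add_node("caption_image", caption_image)
#   builder.add_node("retrieve", retrieve)
#   builder.add_node("generate_story", generate_story)
#   builder.add_node("generate_brainstroming", generate_brainstroming)
#   builder.add_node("select_preferred_topics", select_preferred_topics)
#   builder.add_edge(START, "caption_image")
#   builder.add_edge("caption_image", "retrieve")
#   builder.add_edge("retrieve", "generate_story")
#   builder.add_edge("generate_story", "generate_brainstroming")
#   builder.add_edge("generate_brainstroming", "select_preferred_topics")
#   builder.add_conditional_edges(
#       "select_preferred_topics",
#       route_after_selection,
#       {True: "retrieve", False: END},
#   )
#   graph = builder.compile()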