# Source scraped from Hugging Face Spaces file viewer.
# Author: subashpoudel — commit 583f6dd ("Included CI CD"), file size 4.71 kB.
from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage
from .tools import retrieve_tool
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt , final_story_prompt
import os
from langgraph.prebuilt import create_react_agent
import pandas as pd
from datasets import load_dataset
from src.genai.utils.models_loader import llm
def _run_tool_free_agent(messages):
    # Helper: run a ReAct agent that has no tools over the prepared messages
    # and return the text of its final message. (The original duplicated this
    # create/invoke/extract sequence in two branches.)
    react_agent = create_react_agent(
        model=llm,
        tools=[])
    response = react_agent.invoke({'messages': messages})
    return response['messages'][-1].content


def generate_final_story(final_state):
    """Produce the final story text for the given pipeline state.

    Behaviour depends on what is present in ``final_state``:

    * ``'preferred_topics'`` present and non-empty: build messages from the
      idea, business details and the latest retrieval, then run the agent.
    * ``'preferred_topics'`` present but empty: return the most recent
      drafted story unchanged.
    * ``'preferred_topics'`` absent: fetch influencer data via
      ``retrieve_tool`` and run the agent over the stringified state.

    Args:
        final_state: dict-like pipeline state. Keys read (branch-dependent):
            'preferred_topics', 'idea', 'business_details', 'retrievals',
            'stories'.

    Returns:
        str: the generated (or previously drafted) story text.
    """
    if 'preferred_topics' in final_state:
        if not final_state['preferred_topics']:
            # Topics key exists but is empty: fall back to the latest draft.
            return final_state['stories'][-1]
        template = final_story_prompt(final_state)
        messages = [SystemMessage(content=template),
                    HumanMessage(content=f'''The idea of the video is:\n{final_state['idea']}\n '''),
                    ToolMessage(content=f'''The business details is:\n{final_state['business_details']}\nThe data of influencers is:\n{final_state['retrievals'][-1]}''',tool_call_id='final_story_tool')]
        print('The message of final story:',messages)
        return _run_tool_free_agent(messages)
    # No 'preferred_topics' key at all: retrieve influencer data fresh and
    # feed the whole state to the agent.
    template = final_story_prompt(final_state)
    influencers_data = retrieve_tool(final_state)
    messages = [SystemMessage(content=template),
                ToolMessage(content=f'''The business details is:\n{str(final_state)}\nThe data of influencers is:\n{influencers_data}''',tool_call_id='final_story_tool')]
    return _run_tool_free_agent(messages)
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Read the upload's underlying file object and return it base64-encoded."""
    raw_bytes = uploaded_file.file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
# Convert base64 string to PIL image (optional for LangGraph processing)
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64-encoded image payload into a PIL Image."""
    raw = base64.b64decode(base64_str)
    buffer = BytesIO(raw)
    return Image.open(buffer)
def generate_prompt(final_story,business_details,refined_ideation):
    """Ask the LLM to turn the scene-by-scene story into a generation prompt."""
    print('************Entering prompt generator****************')
    # Assemble the three-part conversation for the LLM: system instructions,
    # the story itself, and the business/idea context as a tool message.
    system_msg = SystemMessage(content=story_to_prompt())
    story_msg = HumanMessage(content=f'''The scene-by-scene video story is {final_story}''')
    context_msg = ToolMessage(content=f'''The business details is:\n{business_details}\nThe idea is{refined_ideation}''',tool_call_id='prompt_generation_id')
    prompt = llm.invoke([system_msg, story_msg, context_msg])
    print('The prompt is:',prompt)
    return prompt.content
def generate_image(final_story, business_details, refined_ideation):
    """Generate an image for the story via FLUX.1-schnell and return it as a base64 PNG string."""
    prompt = generate_prompt(final_story, business_details, refined_ideation)
    print('************Finished prompt generator****************')
    hf_client = InferenceClient(
        provider="hf-inference",
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # text_to_image returns a PIL.Image object
    pil_image = hf_client.text_to_image(
        prompt,
        model="black-forest-labs/FLUX.1-schnell",
    )
    print('*****************Image Created*******************')
    # Serialize the image to PNG bytes in memory, then base64-encode them.
    png_buffer = BytesIO()
    pil_image.save(png_buffer, format="PNG")  # "JPEG" would also work if preferred
    img_base64 = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    print('*****************Image Encoded to Base64*******************')
    return img_base64
def save_to_db(business_details):
    """Filter the influencer dataset by business details and save matches to CSV.

    Loads the 'subashdvorak/tiktok-agentic-story' HF dataset (train split),
    keeps every row where ANY cell contains ANY of the stringified,
    lower-cased values found in ``business_details``, and writes the result
    to 'extracted_data.csv'.

    Args:
        business_details: mapping whose values are strings or lists of
            values; all values are flattened into a set of lower-case
            substring search terms.

    Returns:
        pandas.DataFrame: the matched rows (also written to disk). The
        original returned None; returning the frame is backward-compatible.
    """
    dataset = load_dataset("subashdvorak/tiktok-agentic-story")['train']
    df = pd.DataFrame(dataset)

    # Flatten all business-detail values to a set of lower-case search terms.
    search_terms = set()
    for value in business_details.values():
        if isinstance(value, str):
            search_terms.add(value.lower())
        elif isinstance(value, list):
            search_terms.update(str(item).lower() for item in value)

    def row_matches(row):
        # Lower-case each cell once per row, then test substring containment
        # with `in` (the original recomputed str(cell).lower() for every
        # (cell, term) pair and used `.find(...) != -1`).
        cells = [str(cell).lower() for cell in row]
        return any(term in cell for cell in cells for term in search_terms)

    # Keep rows where ANY column contains ANY of the search terms.
    matched_df = df[df.apply(row_matches, axis=1)]
    matched_df.to_csv('extracted_data.csv')
    return matched_df