# Hugging Face Space file (author: subashpoudel, commit 93a5bf9,
# "Changed the entire project structure", 5.59 kB) — web-page header
# residue from the scrape, preserved here as a comment so the module parses.
from langchain_core.messages import SystemMessage
from .tools import retrieve_tool
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt , final_story_prompt
import os
from langgraph.prebuilt import create_react_agent
import pandas as pd
from datasets import load_dataset
import ast
import faiss
import re
import numpy as np
from utils.models_loader import ST , llm
def generate_final_story(final_state):
    """Produce the final story text for the given pipeline state.

    If the user selected any preferred topics, a ReAct agent (the shared
    `llm` bound to `retrieve_tool`) is invoked with the final-story prompt
    and its last message content is returned. Otherwise the most recent
    story already present in the state is reused as-is.
    """
    if not final_state['preferred_topics']:
        # No topics to refine with — fall back to the latest stored story.
        return final_state['stories'][-1]

    system_prompt = final_story_prompt(final_state)
    agent_tools = [retrieve_tool]
    agent = create_react_agent(
        model=llm.bind_tools(agent_tools),
        tools=agent_tools,
    )
    result = agent.invoke({'messages': [SystemMessage(content=system_prompt)]})
    return result['messages'][-1].content
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Read the whole uploaded file and return its base64 text encoding."""
    raw_bytes = uploaded_file.file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
# Convert base64 string to PIL image (optional for LangGraph processing)
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an in-memory PIL image."""
    raw = base64.b64decode(base64_str)
    buffer = BytesIO(raw)
    return Image.open(buffer)
def generate_prompt(final_story):
    """Ask the shared LLM to turn a finished story into an image prompt.

    Sends the `story_to_prompt` instructions as the system message and the
    story as the human message, and returns the model's text content.
    """
    print('************Entering prompt generator****************')
    chat = [
        ("system", story_to_prompt),
        ("human", final_story),
    ]
    result = llm.invoke(chat)
    print('The prompt is:', result)
    return result.content
def generate_image(final_story):
    """Generate an illustration for the story and save it to 'image.png'.

    The story is first condensed into an image prompt via `generate_prompt`,
    then rendered with the FLUX.1-schnell model through the HF Inference API.
    Returns a fixed status string; the image itself is written to disk.
    """
    prompt = generate_prompt(final_story)
    print('************Finished prompt generator****************')
    hf_client = InferenceClient(
        provider="hf-inference",
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # text_to_image returns a PIL.Image object
    generated = hf_client.text_to_image(
        prompt,
        model="black-forest-labs/FLUX.1-schnell",
    )
    print('*****************Image Created*******************')
    generated.save('image.png')
    print('*****************Image Saved*******************')
    return "Image Created"
def save_to_db(business_details):
    """Filter the influencer dataset by business details and cache it to CSV.

    Every row of the 'subashdvorak/tiktok-agentic-story' train split whose
    cells contain (case-insensitively) ANY string value found in
    `business_details` is written to 'extracted_data.csv', which
    `manual_retrieval` later reads back.

    Parameters:
        business_details: mapping whose str and list values are used as
            case-insensitive substring filters; other value types are ignored.
    """
    dataset = load_dataset("subashdvorak/tiktok-agentic-story")['train']
    df = pd.DataFrame(dataset)

    # Flatten all business-detail values into one set of lowercase strings.
    all_values = set()
    for value in business_details.values():
        if isinstance(value, str):
            all_values.add(value.lower())
        elif isinstance(value, list):
            all_values.update(str(item).lower() for item in value)

    # A row matches when ANY cell contains ANY of the collected values.
    # (With an empty value set, nothing matches and an empty CSV is written.)
    def row_matches(row):
        return any(
            val in str(cell).lower()
            for cell in row
            for val in all_values
        )

    matched_df = df[df.apply(row_matches, axis=1)]
    # index=False keeps the pandas index out of the cached CSV so the reader
    # does not pick up a spurious 'Unnamed: 0' column.
    matched_df.to_csv('extracted_data.csv', index=False)
def manual_retrieval(messages, business_details):
    """Retrieve the top-3 most similar cached stories for this conversation.

    Reads 'extracted_data.csv' (produced by `save_to_db`), builds an
    in-memory FAISS L2 index over the stored embeddings, embeds
    `str(messages) + str(business_details)` with the shared
    SentenceTransformer `ST`, and returns the formatted results as the
    string form of a list of per-hit string lists.

    Assumes the CSV has columns 'embeddings', 'username', 'agentic_story',
    'likesCount' and 'commentCount' — TODO confirm against save_to_db's
    dataset schema.
    """
    # === Load CSV ===
    csv_path = 'extracted_data.csv'
    df = pd.read_csv(csv_path)
    # === Parse stored embeddings ===
    # Embeddings are stored as stringified Python lists in the CSV; parse
    # them back into lists (values that are already non-strings pass through).
    df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    embeddings = np.vstack(df['embeddings'].values).astype('float32')
    # === Build FAISS index ===
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    # === Load SentenceTransformer model ===
    # === Encode the query and search ===
    # Query text is the raw concatenation of the conversation and the
    # business details; FAISS returns L2 distances (lower = closer).
    query_embedding = ST.encode(str(messages)+str(business_details)).reshape(1, -1).astype('float32')
    top_k=3
    distances, indices = index.search(query_embedding, top_k)
    # === Function to extract sections 1 and 6 ===
    def extract_story_and_branding(full_story):
        # Normalise both heading variants (with/without trailing colon) so a
        # single regex can match either format.
        full_story = full_story.replace('**6. Visible Texts or Brandings**', '**6. Visible Texts or Brandings:**')
        full_story = full_story.replace('**1. Story**', '**1. Story:**')
        # Capture section 1 and section 6 bodies, each capture stopping at
        # the next numbered '**N. ' heading (or end of string for the last).
        pattern = (
            r"\*\*1\. Story:\*\*(.*?)(?=\*\*\d+\.\s)"
            r".*?"
            r"\*\*6\. Visible Texts or Brandings:\*\*(.*?)(?=\*\*\d+\.\s|$)"
        )
        match = re.search(pattern, full_story, re.DOTALL)
        if match:
            story_section = match.group(1).strip()
            branding_section = match.group(2).strip()
            return f"Story:\n{story_section}\n\nVisible Texts or Brandings:\n{branding_section}"
        else:
            return "Requested sections not found."
    # === Format results ===
    # Each hit becomes a 3-element list: header line, extracted story text,
    # and the formatted L2 distance.
    outer_list = []
    for i, idx in enumerate(indices[0]):
        res = {
            'rank': i + 1,
            'username': df.iloc[idx]['username'],
            'agentic_story': df.iloc[idx]['agentic_story'],
            'likesCount': df.iloc[idx]['likesCount'],
            'commentCount': df.iloc[idx]['commentCount'],
            'distance': distances[0][i]
        }
        inner_list = []
        inner_list.append(f"[{res['rank']}]. The influencer name is: **{res['username']}** — Likes: **{res['likesCount']}**, Comments: **{res['commentCount']}**")
        inner_list.append(f"The story of that particular video is:\n{extract_story_and_branding(res['agentic_story'])}")
        inner_list.append(f"Distance: {res['distance']:.4f}")
        outer_list.append(inner_list)
    return str(outer_list)