Spaces:
Sleeping
Sleeping
| from langchain_core.messages import SystemMessage | |
| from .tools import retrieve_tool | |
| import base64 | |
| from PIL import Image | |
| from io import BytesIO | |
| from fastapi import UploadFile | |
| from huggingface_hub import InferenceClient | |
| from .prompts import story_to_prompt , final_story_prompt | |
| import os | |
| from langgraph.prebuilt import create_react_agent | |
| import pandas as pd | |
| from datasets import load_dataset | |
| import ast | |
| import faiss | |
| import re | |
| import numpy as np | |
| from utils.models_loader import ST , llm | |
def generate_final_story(final_state):
    """Produce the final story for the session.

    When the user expressed preferred topics, a ReAct agent (armed with
    the retrieval tool) generates the story from the final-story prompt;
    otherwise the most recent draft story is returned unchanged.

    Args:
        final_state: dict-like state holding at least 'preferred_topics'
            (list) and 'stories' (non-empty list of draft stories).

    Returns:
        str: the agent-generated story, or the last draft story.
    """
    # No preferred topics -> nothing to personalise; return the last draft.
    if not final_state['preferred_topics']:
        return final_state['stories'][-1]

    template = final_story_prompt(final_state)
    messages = [SystemMessage(content=template)]
    tools = [retrieve_tool]
    react_agent = create_react_agent(
        model=llm.bind_tools(tools),
        tools=tools,
    )
    response = react_agent.invoke({'messages': messages})
    # The agent's answer is the last message of the conversation.
    return response['messages'][-1].content
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Read the uploaded file's raw bytes and return them as a base64 text string."""
    raw_bytes = uploaded_file.file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# Convert base64 string to PIL image (optional for LangGraph processing)
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string back into an in-memory PIL image."""
    raw = base64.b64decode(base64_str)
    buffer = BytesIO(raw)
    return Image.open(buffer)
def generate_prompt(final_story):
    """Ask the LLM to condense the finished story into an image-generation prompt."""
    print('************Entering prompt generator****************')
    chat = [
        ("system", story_to_prompt),
        ("human", final_story),
    ]
    result = llm.invoke(chat)
    print('The prompt is:', result)
    return result.content
def generate_image(final_story, output_path='image.png'):
    """Render an illustration for *final_story* and save it to disk.

    The story is first condensed into an image prompt via
    generate_prompt(), then rendered with the FLUX.1-schnell model on the
    Hugging Face inference API (token read from the
    HUGGINGFACEHUB_ACCESS_TOKEN environment variable).

    Args:
        final_story: the story text to illustrate.
        output_path: destination file for the rendered image; defaults to
            'image.png', matching the previous hard-coded behaviour.

    Returns:
        str: the status message "Image Created".
    """
    prompt = generate_prompt(final_story)
    print('************Finished prompt generator****************')
    client = InferenceClient(
        provider="hf-inference",
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # client.text_to_image returns a PIL.Image object
    image = client.text_to_image(
        prompt,
        model="black-forest-labs/FLUX.1-schnell",
    )
    print('*****************Image Created*******************')
    image.save(output_path)
    print('*****************Image Saved*******************')
    return "Image Created"
def save_to_db(business_details):
    """Filter the influencer-story dataset by the business details and cache it.

    Downloads the 'subashdvorak/tiktok-agentic-story' dataset, keeps every
    row in which ANY cell contains ANY business-detail value as a
    case-insensitive substring, and writes the matches to
    'extracted_data.csv' for later retrieval.

    Args:
        business_details: mapping whose str values (and str-able list
            items) become search terms; values of other types are ignored.
    """
    dataset = load_dataset("subashdvorak/tiktok-agentic-story")['train']
    # dataset = load_influencer_data()
    df = pd.DataFrame(dataset)

    # Flatten all business-detail values into a set of lowercase search terms.
    search_terms = set()
    for value in business_details.values():
        if isinstance(value, str):
            search_terms.add(value.lower())
        elif isinstance(value, list):
            search_terms.update(str(item).lower() for item in value)

    # A row matches when any cell contains any term as a substring.
    def row_matches(row):
        # Lowercase each cell once instead of once per (cell, term) pair.
        cells = [str(cell).lower() for cell in row]
        return any(term in cell for cell in cells for term in search_terms)

    matched_df = df[df.apply(row_matches, axis=1)]
    # index=False keeps a stray unnamed index column out of the cache file.
    matched_df.to_csv('extracted_data.csv', index=False)
def manual_retrieval(messages, business_details):
    """Retrieve the top-3 cached influencer stories most similar to the query.

    Reads 'extracted_data.csv' (produced by save_to_db), rebuilds a FAISS
    L2 index over the stored embeddings, embeds the query text (messages
    concatenated with business details) using the shared
    SentenceTransformer, and returns the three nearest stories with
    engagement stats as a stringified list.

    Args:
        messages: conversation content folded into the query text via str().
        business_details: business info appended to the query text via str().

    Returns:
        str: repr of a list of per-result string lists (rank/username/
            likes/comments line, extracted story sections, L2 distance).
    """
    # === Load CSV (the cache written by save_to_db) ===
    csv_path = 'extracted_data.csv'
    df = pd.read_csv(csv_path)
    # === Parse stored embeddings ===
    # CSV round-trips each embedding as its repr string; literal_eval turns
    # it back into a list of floats (already-parsed values pass through).
    df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    embeddings = np.vstack(df['embeddings'].values).astype('float32')
    # === Build FAISS index (exact L2 search, rebuilt on every call) ===
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    # === Encode the query and search for the 3 nearest neighbours ===
    query_embedding = ST.encode(str(messages)+str(business_details)).reshape(1, -1).astype('float32')
    top_k=3
    distances, indices = index.search(query_embedding, top_k)
    # === Helper: pull sections 1 (Story) and 6 (Brandings) out of a story ===
    def extract_story_and_branding(full_story):
        # Normalise headers that sometimes lack the trailing colon so a
        # single regex form matches both variants.
        full_story = full_story.replace('**6. Visible Texts or Brandings**', '**6. Visible Texts or Brandings:**')
        full_story = full_story.replace('**1. Story**', '**1. Story:**')
        # Section 1 runs until the next '**<n>. ' header; section 6 runs
        # until the next header or end of text. Both must be present.
        pattern = (
            r"\*\*1\. Story:\*\*(.*?)(?=\*\*\d+\.\s)"
            r".*?"
            r"\*\*6\. Visible Texts or Brandings:\*\*(.*?)(?=\*\*\d+\.\s|$)"
        )
        match = re.search(pattern, full_story, re.DOTALL)
        if match:
            story_section = match.group(1).strip()
            branding_section = match.group(2).strip()
            return f"Story:\n{story_section}\n\nVisible Texts or Brandings:\n{branding_section}"
        else:
            # Fallback text returned verbatim when the story does not
            # follow the expected numbered-section format.
            return "Requested sections not found."
    # === Format results: one list of display strings per neighbour ===
    outer_list = []
    for i, idx in enumerate(indices[0]):
        res = {
            'rank': i + 1,
            'username': df.iloc[idx]['username'],
            'agentic_story': df.iloc[idx]['agentic_story'],
            'likesCount': df.iloc[idx]['likesCount'],
            'commentCount': df.iloc[idx]['commentCount'],
            'distance': distances[0][i]
        }
        inner_list = []
        inner_list.append(f"[{res['rank']}]. The influencer name is: **{res['username']}** — Likes: **{res['likesCount']}**, Comments: **{res['commentCount']}**")
        inner_list.append(f"The story of that particular video is:\n{extract_story_and_branding(res['agentic_story'])}")
        inner_list.append(f"Distance: {res['distance']:.4f}")
        outer_list.append(inner_list)
    # Stringified so the result can be dropped straight into an LLM prompt.
    return str(outer_list)