Spaces:
Sleeping
Sleeping
| from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage | |
| from .tools import retrieve_tool | |
| import base64 | |
| from PIL import Image | |
| from io import BytesIO | |
| from fastapi import UploadFile | |
| from huggingface_hub import InferenceClient | |
| from .prompts import story_to_prompt , final_story_prompt | |
| import os | |
| from langgraph.prebuilt import create_react_agent | |
| import pandas as pd | |
| from datasets import load_dataset | |
| from src.genai.utils.models_loader import llm | |
def generate_final_story(final_state):
    """Produce the final video story for the given workflow state.

    Three outcomes are possible, depending on ``final_state``:

    * ``'preferred_topics'`` is present but empty: the last entry of
      ``final_state['stories']`` is returned and no model call is made.
    * ``'preferred_topics'`` is present and non-empty: the agent is given the
      idea, the business details and the last entry of ``'retrievals'``.
    * ``'preferred_topics'`` is absent: influencer data is obtained from
      ``retrieve_tool`` and the agent is given the stringified state.

    Returns:
        The content of the last message returned by the agent, or the stored
        story in the first case.
    """
    topics_known = 'preferred_topics' in final_state

    # Nothing to regenerate when the topic list exists but is empty.
    if topics_known and len(final_state['preferred_topics']) == 0:
        return final_state['stories'][-1]

    system_text = final_story_prompt(final_state)

    if topics_known:
        system_msg = SystemMessage(content=system_text)
        idea_msg = HumanMessage(
            content=f'''The idea of the video is:\n{final_state['idea']}\n '''
        )
        context_msg = ToolMessage(
            content=f'''The business details is:\n{final_state['business_details']}\nThe data of influencers is:\n{final_state['retrievals'][-1]}''',
            tool_call_id='final_story_tool',
        )
        conversation = [system_msg, idea_msg, context_msg]
        print('The message of final story:', conversation)
    else:
        influencers_data = retrieve_tool(final_state)
        system_msg = SystemMessage(content=system_text)
        context_msg = ToolMessage(
            content=f'''The business details is:\n{str(final_state)}\nThe data of influencers is:\n{influencers_data}''',
            tool_call_id='final_story_tool',
        )
        conversation = [system_msg, context_msg]

    # The agent is created without tools, so it only generates the reply.
    agent = create_react_agent(model=llm, tools=[])
    outcome = agent.invoke({'messages': conversation})
    return outcome['messages'][-1].content
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Return the bytes read from ``uploaded_file.file`` as base64 text.

    ``read()`` is called once, without a size argument, on the underlying
    file object; the result is base64-encoded and decoded as UTF-8.
    """
    raw_bytes = uploaded_file.file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# Convert base64 string to PIL image (optional for LangGraph processing)
def process_image(base64_str: str) -> Image.Image:
    """Decode ``base64_str`` and open the resulting bytes as a PIL image."""
    decoded_bytes = base64.b64decode(base64_str)
    byte_stream = BytesIO(decoded_bytes)
    return Image.open(byte_stream)
def generate_prompt(final_story,business_details,refined_ideation):
    """Ask the LLM to turn a video story into a prompt string.

    The model receives the ``story_to_prompt()`` system instructions, the
    story as a human message, and the business details plus the idea as a
    tool message.

    Returns:
        The ``content`` attribute of the model's reply.
    """
    print('************Entering prompt generator****************')

    system_msg = SystemMessage(content=story_to_prompt())
    story_msg = HumanMessage(
        content=f'''The scene-by-scene video story is {final_story}'''
    )
    context_msg = ToolMessage(
        content=f'''The business details is:\n{business_details}\nThe idea is{refined_ideation}''',
        tool_call_id='prompt_generation_id',
    )

    reply = llm.invoke([system_msg, story_msg, context_msg])
    print('The prompt is:', reply)
    return reply.content
def generate_image(final_story, business_details, refined_ideation):
    """Generate an image for the story and return it as a base64 PNG string.

    A prompt is first derived with ``generate_prompt``; it is then sent to
    the ``black-forest-labs/FLUX.1-schnell`` model through a Hugging Face
    ``InferenceClient`` authenticated with the
    ``HUGGINGFACEHUB_ACCESS_TOKEN`` environment variable.

    Returns:
        The PNG-serialised image, base64-encoded and decoded as UTF-8 text.
    """
    image_prompt = generate_prompt(final_story, business_details, refined_ideation)
    print('************Finished prompt generator****************')

    hf_token = os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN')
    hf_client = InferenceClient(provider="hf-inference", api_key=hf_token)
    print('************Finished calling generator****************')

    # output is a PIL.Image object
    generated = hf_client.text_to_image(
        image_prompt,
        model="black-forest-labs/FLUX.1-schnell",
    )
    print('*****************Image Created*******************')

    # Serialise to an in-memory PNG, then base64-encode the whole buffer.
    png_buffer = BytesIO()
    generated.save(png_buffer, format="PNG")  # you can also use "JPEG" if preferred
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    print('*****************Image Encoded to Base64*******************')
    return encoded_png
def _collect_search_terms(business_details):
    """Return the lowercase, non-blank search terms found in ``business_details``.

    String values contribute themselves and list values contribute each item
    converted with ``str``; values of any other type are ignored.
    """
    terms = set()
    for value in business_details.values():
        if isinstance(value, str):
            candidates = [value]
        elif isinstance(value, list):
            candidates = [str(item) for item in value]
        else:
            continue
        # A blank term is a substring of every cell and would match all rows.
        terms.update(text.lower() for text in candidates if text.strip())
    return terms


def _row_contains_any(row, terms):
    """Tell whether any cell of ``row`` contains any of ``terms``, ignoring case."""
    cells = [str(cell).lower() for cell in row]
    return any(term in cell for cell in cells for term in terms)


def save_to_db(business_details):
    """Export the dataset rows that mention any business detail to a CSV file.

    Loads the ``train`` split of ``subashdvorak/tiktok-agentic-story``, keeps
    the rows in which at least one cell contains (case-insensitively) one of
    the string or list values of ``business_details``, and writes them to
    ``extracted_data.csv`` in the current working directory.

    Blank values (empty or whitespace-only) are not used as search terms;
    otherwise they would select every row. When there are no usable terms,
    or the dataset has no data, an empty CSV with the dataset's columns is
    written.

    Args:
        business_details: Mapping whose string and list values are used as
            search terms.
    """
    dataset = load_dataset("subashdvorak/tiktok-agentic-story")['train']
    df = pd.DataFrame(dataset)

    terms = _collect_search_terms(business_details)

    if terms and not df.empty:
        row_mask = df.apply(_row_contains_any, axis=1, args=(terms,))
        matched_df = df[row_mask]
    else:
        # Nothing to search for (or in): keep the columns, drop all rows.
        matched_df = df.iloc[0:0]

    matched_df.to_csv('extracted_data.csv')