|
|
import pandas as pd |
|
|
from src.genai.utils.data_loader import caption_df |
|
|
from src.genai.utils.models_loader import llm_gpt |
|
|
from .prompts import details_extract_prompt |
|
|
from langchain_core.messages import SystemMessage, HumanMessage |
|
|
from .state import DetailsFormatter |
|
|
from langsmith import traceable |
|
|
|
|
|
class DetailsExtractorNode:
    """Extract structured details from interactions with the chat model."""

    def __init__(self, interactions):
        """Store the chat model and the interactions to analyse.

        Args:
            interactions: Interaction data to send to the model. It is passed
                through ``str()`` when the request is built, so any object
                with a meaningful string form works.
        """
        self.llm = llm_gpt
        self.interactions = interactions

    @traceable(name="details extraction")
    def run(self):
        """Query the model and return the extracted details.

        The prompt from ``details_extract_prompt()`` is sent as the system
        message and the stringified interactions as the human message.

        Returns:
            The result of ``model_dump()`` on the structured response
            (presumably a dict, assuming ``DetailsFormatter`` is a Pydantic
            model -- confirm against ``.state``).
        """
        messages = [
            SystemMessage(content=details_extract_prompt()),
            HumanMessage(content=str(self.interactions)),
        ]
        # Go through the instance attribute rather than the module global so
        # the model assigned in __init__ is the one actually used (and can be
        # swapped on an instance if needed).
        structured_llm = self.llm.with_structured_output(DetailsFormatter)
        response = structured_llm.invoke(messages)
        return response.model_dump()
|
|
|
|
|
|
|
|
class SaveToDB:
    """Filter a captions dataframe by extracted details and write it to CSV."""

    def __init__(self, caption_df):
        """Keep a copy of the dataframe without its ``embeddings`` column.

        Args:
            caption_df: DataFrame to search. The ``embeddings`` column is
                dropped when present; a frame without it is accepted as-is.
        """
        self.df = caption_df.drop(columns=['embeddings'], errors='ignore')

    def _prepare_values(self, business_details):
        """Collect the lowercase search terms held in ``business_details``.

        String values are used directly and list values contribute one term
        per item (each item converted with ``str()``). Values of any other
        type are ignored. Blank terms are skipped: an empty string is a
        substring of every cell and would make every row match.

        Args:
            business_details: Mapping whose values are inspected.

        Returns:
            A set of non-blank, lowercase strings.
        """
        all_values = set()
        for value in business_details.values():
            if isinstance(value, str):
                candidates = [value]
            elif isinstance(value, list):
                candidates = [str(item) for item in value]
            else:
                continue
            all_values.update(
                text.lower() for text in candidates if text.strip()
            )
        return all_values

    def _row_matches(self, row, all_values):
        """Return True if any search term is a substring of any cell in ``row``.

        Args:
            row: Iterable of cell values; each is compared by its lowercase
                ``str()`` form.
            all_values: Iterable of lowercase search terms.
        """
        for cell in row:
            cell_text = str(cell).lower()  # lower-case once per cell
            if any(val in cell_text for val in all_values):
                return True
        return False

    def save_to_csv(self, business_details, output_file='extracted_data.csv'):
        """Write the rows matching ``business_details`` to a CSV file.

        Args:
            business_details: Mapping of extracted details; see
                ``_prepare_values`` for how search terms are derived.
            output_file: Destination path for the CSV (written without the
                index column).
        """
        all_values = self._prepare_values(business_details)
        if self.df.empty:
            # Nothing can match. Skip ``apply`` here: on an empty frame it is
            # not guaranteed to return a boolean Series usable as a row mask.
            matched_df = self.df.iloc[0:0]
        else:
            mask = self.df.apply(self._row_matches, axis=1, args=(all_values,))
            matched_df = self.df[mask]
        matched_df.to_csv(output_file, index=False)
|
|
|
|
|
|
|
|
|
|
|
|