# AI-Agent-Book — utils/categorizer.py
# Author: Cuong2004 — commit ded29b0 ("init project")
import csv
import json
import time
import os
import pandas as pd
import io
import google.generativeai as genai
from typing import List, Dict, Any, Optional, Tuple
import logging
import sys
import traceback
from .api_manager import ApiKeyManager
# Configure logging to stdout only (no file handlers), so logs appear in the
# hosting platform's console stream.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Processing constants used by categorize_titles() and the Gemini retry loop.
BATCH_SIZE = 10   # Number of items to process in each batch
BATCH_DELAY = 30  # Delay between batches in seconds (rate-limit friendliness)
MAX_RETRIES = 3   # Maximum number of retries for API calls
RETRY_DELAY = 5   # Base delay between retries in seconds (scaled per attempt)

# Predefined list of niches the model must choose from; interpolated verbatim
# into the Gemini prompt in analyze_title_with_gemini().
NICHES_LIST = ["Animals", "Plants", "Nature & Landscapes", "People", "Occupations",
               "Fashion & Style", "Dance", "Toys", "Health & Fitness",
               "Mindfulness & Spirituality", "Art", "Coloring", "Patterns & Textures",
               "Crafts & DIY", "Decor", "Cartoons", "Anime & Manga",
               "Fantasy & Whimsical", "Mythology & Legends", "Characters",
               "Music", "Movies & TV", "Games & Puzzles", "Sports", "Travel",
               "Countries", "Cities & Architecture", "History", "Culture",
               "Religion", "Science", "Education", "Space", "Food & Drink",
               "Holidays & Seasonal", "Horror", "Humor & Funny",
               "Novelty & Gag Gifts", "Horoscope & Astrology", "Vehicles",
               "Quotes & Affirmations", "Adult Themes"]
def configure_genai(api_key: str) -> bool:
    """Install *api_key* as the global Gemini API key.

    Args:
        api_key: The Gemini API key to pass to ``genai.configure``.

    Returns:
        bool: True on success, False if configuration raised an exception.
        (The return annotation was previously ``None``, contradicting the
        actual ``True``/``False`` returns callers may rely on.)
    """
    try:
        genai.configure(api_key=api_key)
        logger.debug("Gemini API configured with key")
        return True
    except Exception as e:
        # Best-effort: log and report failure instead of propagating, so the
        # caller's retry loop can rotate to another key.
        logger.error(f"Error configuring Gemini API: {str(e)}")
        return False
def _validate_analysis_result(result: Any) -> Dict[str, Any]:
    """Validate and normalize a parsed Gemini response in place.

    Ensures *result* is a dict with "niches", "subniches" and "trademark"
    keys, coerces "niches"/"subniches" to lists, and coerces a non-string
    "trademark" to a lowercase string.

    Raises:
        ValueError: If the result is not a dict or is missing required keys.
    """
    if not isinstance(result, dict):
        logger.error(f"Response is not a valid JSON object: {type(result)}")
        raise ValueError("Response is not a valid JSON object")
    if "niches" not in result or "subniches" not in result or "trademark" not in result:
        logger.error(f"Response is missing required fields: {result}")
        raise ValueError(f"Response is missing required fields: {result}")
    # Ensure niches and subniches are lists
    if not isinstance(result["niches"], list):
        logger.warning(f"Niches is not a list, converting: {result['niches']}")
        result["niches"] = [result["niches"]]
    if not isinstance(result["subniches"], list):
        logger.warning(f"Subniches is not a list, converting: {result['subniches']}")
        result["subniches"] = [result["subniches"]]
    # Ensure trademark is a string (only lowercased when it was not a string,
    # matching the original behavior).
    if not isinstance(result["trademark"], str):
        logger.warning(f"Trademark is not a string, converting: {result['trademark']}")
        result["trademark"] = str(result["trademark"]).lower()
    return result


def analyze_title_with_gemini(
    title: str,
    subtitle: str,
    key_manager: ApiKeyManager
) -> Tuple[Dict[str, Any], bool]:
    """
    Analyze the title and subtitle using Gemini API to determine niches,
    subniches, and trademark status.

    Args:
        title: Title text (truncated to 1000 characters before sending)
        subtitle: Subtitle text (truncated to 1000 characters before sending)
        key_manager: API key manager instance; supplies keys round-robin and
            records keys that hit quota/rate limits

    Returns:
        Tuple[Dict, bool]: The result dict (keys "niches", "subniches",
        "trademark") and a success flag. After MAX_RETRIES failures, a
        placeholder result with "Unknown" values and False is returned.
    """
    # Truncate long titles/subtitles to prevent API errors
    title = title[:1000] if title else ""
    subtitle = subtitle[:1000] if subtitle else ""

    # Format the niches list as a comma-separated string for the prompt
    niches_text = ", ".join(NICHES_LIST)

    # Create the prompt for Gemini
    prompt = f"""Analyze the following coloring book title and subtitle:
- Title: {title}
- Subtitle: {subtitle}
Please determine:
1. "niches": 1 most appropriate niche based on the following list of niches: {niches_text}
2. "subniches": 2-3 more specific subniches (for example: if the niche is "Animals" then subniche could be "Dogs", "Wildlife", "Pets", etc.)
3. "trademark": "yes" if the title or subtitle contains copyrighted characters/brands (such as Disney, Marvel, DC, Pokemon, etc.), "no" if not
Return the result as JSON in the following format:
{{
    "niches": ["niche1"],
    "subniches": ["subniche1", "subniche2"],
    "trademark": "yes/no"
}}
Return only the JSON, no additional explanation needed. Do not include any text before or after the JSON."""

    logger.info(f"Analyzing title: '{title}', subtitle: '{subtitle}'")

    # FIX: api_key starts as None so the except-branch cannot raise a
    # NameError (masking the real error) when get_next_api_key() itself
    # raises before api_key was ever assigned.
    api_key = None
    for attempt in range(MAX_RETRIES):
        try:
            # Get the next API key and install it globally
            api_key = key_manager.get_next_api_key()
            configure_genai(api_key)

            logger.info(f"Creating Gemini model, attempt {attempt + 1}/{MAX_RETRIES}")
            model = genai.GenerativeModel('gemini-2.0-flash')

            # Low temperature: we want deterministic, structured JSON output
            generation_config = {
                "temperature": 0.1,
                "top_p": 0.95,
                "top_k": 40,
                "max_output_tokens": 1024,
            }

            logger.info(f"Sending request to Gemini API, attempt {attempt + 1}/{MAX_RETRIES}")
            start_time = time.time()
            response = model.generate_content(
                prompt,
                generation_config=generation_config
            )
            elapsed_time = time.time() - start_time
            logger.info(f"Received response from Gemini API in {elapsed_time:.2f} seconds")

            # An empty/blocked response has no usable text
            if not hasattr(response, 'text') or not response.text:
                logger.warning("Empty response received from API")
                raise ValueError("Empty response received from API")

            response_text = response.text.strip()
            logger.debug(f"Raw response text: {response_text[:100]}...")

            # Strip any prose/markdown fences around the JSON payload; models
            # sometimes add them despite the prompt instructions.
            if not response_text.startswith('{'):
                logger.warning("Response does not start with '{', attempting to extract JSON")
                start_idx = response_text.find('{')
                end_idx = response_text.rfind('}')
                if start_idx >= 0 and end_idx > start_idx:
                    response_text = response_text[start_idx:end_idx+1]
                    logger.info(f"JSON extracted from response: {response_text[:100]}...")
                else:
                    logger.error(f"Could not find valid JSON in response: {response_text[:100]}")
                    raise ValueError(f"Could not find valid JSON in response: {response_text[:100]}")

            # Parse and normalize the response
            logger.info("Parsing response as JSON")
            result = json.loads(response_text)
            result = _validate_analysis_result(result)

            logger.info(f"Successfully analyzed title. Niches: {result['niches']}, Subniches: {result['subniches']}, Trademark: {result['trademark']}")
            return result, True

        except Exception as e:
            error_text = str(e).lower()
            if "quota" in error_text or "rate" in error_text or "limit" in error_text:
                logger.warning(f"API key quota exceeded or rate limited: {e}")
                # api_key is None when get_next_api_key() itself failed;
                # there is no key to mark in that case.
                if api_key is not None:
                    key_manager.mark_key_as_failed(api_key)
            else:
                error_details = traceback.format_exc()
                logger.error(f"Error on attempt {attempt + 1}/{MAX_RETRIES}: {str(e)}\n{error_details}")
                # More debug info for JSON parsing errors
                if "Expecting value" in str(e) and 'response_text' in locals():
                    logger.error(f"Failed to parse JSON from: {response_text[:100]}...")

            if attempt < MAX_RETRIES - 1:
                retry_delay = RETRY_DELAY * (attempt + 1)  # Progressive backoff
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)

    # If all attempts failed, return default values
    logger.warning(f"All attempts failed for title: {title}, subtitle: {subtitle}")
    return {
        "niches": ["Unknown"],
        "subniches": ["Unknown"],
        "trademark": "unknown"
    }, False
def categorize_titles(
    input_data,
    batch_size: int = BATCH_SIZE,
    batch_delay: int = BATCH_DELAY,
    resume_from: int = 0
) -> Tuple[pd.DataFrame, int, int]:
    """
    Process a CSV file of titles and categorize them using the Gemini AI model.

    Reads the CSV into a DataFrame, adds result columns ('Niche', 'Subniches',
    'Trademark', 'Categorized') if absent, then analyzes each uncategorized
    row via analyze_title_with_gemini(), pausing between batches. Stops early
    if the key manager reports no working API keys left.

    Args:
        input_data: File-like object (or path) containing the CSV data;
            must have a 'Title' column, 'Subtitle' is optional.
        batch_size: Number of rows to process in each batch
        batch_delay: Delay between batches in seconds
        resume_from: Row index to resume processing from (0-based)

    Returns:
        Tuple containing:
        - DataFrame with processed data (mutated copy of the input CSV)
        - Number of successfully processed rows (this run only)
        - Number of failed rows (this run only)

    Raises:
        Exception: Re-raises any error outside per-row processing
            (e.g. CSV parsing failure) after logging it.
    """
    try:
        logger.info(f"Starting categorization process with batch size={batch_size}, batch_delay={batch_delay}")
        # Initialize API key manager
        key_manager = ApiKeyManager()
        logger.info(f"Initialized API key manager with {len(key_manager.api_keys)} keys")
        # Read input CSV file as a DataFrame
        df = pd.read_csv(input_data)
        original_rows = len(df)
        logger.info(f"Read {original_rows} rows from input")
        # Add columns for categorization results if they don't exist
        # (allows re-running on a previously saved output CSV)
        if 'Niche' not in df.columns:
            df['Niche'] = ""
        if 'Subniches' not in df.columns:
            df['Subniches'] = ""
        if 'Trademark' not in df.columns:
            df['Trademark'] = ""
        if 'Categorized' not in df.columns:
            df['Categorized'] = False
        # Count rows that have already been categorized
        already_processed = df['Categorized'].sum()
        logger.info(f"Found {already_processed} rows already categorized")
        # Process rows in batches; NOTE(review): total_to_process ignores
        # resume_from, so progress % may be off when resuming mid-file.
        successful_rows = 0
        failed_rows = 0
        total_to_process = len(df) - already_processed
        # i is the DataFrame index label; with read_csv this is a RangeIndex,
        # so comparisons against len(df) below assume default integer labels.
        for i, row in df.iloc[resume_from:].iterrows():
            # Skip already processed rows
            if row['Categorized']:
                logger.info(f"Skipping already processed row {i+1}/{len(df)}")
                continue
            try:
                logger.info(f"Processing row {i+1}/{len(df)}: {row['Title']}")
                # Get title and subtitle ('Subtitle' column is optional)
                title = str(row['Title']) if not pd.isna(row['Title']) else ""
                subtitle = str(row['Subtitle']) if 'Subtitle' in row and not pd.isna(row['Subtitle']) else ""
                # Analyze title with Gemini (rotates/retries API keys internally)
                result, success = analyze_title_with_gemini(title, subtitle, key_manager)
                if success:
                    # Update row with results; only the first niche is kept,
                    # subniches are joined into one comma-separated cell
                    df.at[i, 'Niche'] = result['niches'][0] if result['niches'] else "Unknown"
                    df.at[i, 'Subniches'] = ", ".join(result['subniches']) if result['subniches'] else "Unknown"
                    df.at[i, 'Trademark'] = result['trademark'].lower()
                    df.at[i, 'Categorized'] = True
                    successful_rows += 1
                    logger.info(f"Successfully categorized row {i+1}")
                else:
                    failed_rows += 1
                    logger.warning(f"Failed to categorize row {i+1}")
                # Batch boundary: log progress, check key health, then pause
                if (i + 1) % batch_size == 0 or i == len(df) - 1:
                    progress_pct = (successful_rows / total_to_process) * 100 if total_to_process > 0 else 100
                    logger.info(f"Progress: {successful_rows}/{total_to_process} ({progress_pct:.1f}%) categorized")
                    logger.info(f"API key status: {key_manager.get_api_keys_status()['available_keys']} keys available")
                    # Check if we have any working keys left
                    if not key_manager.has_working_keys():
                        logger.error("No working API keys left. Saving progress and stopping.")
                        break
                    # Apply batch delay (except on the last item)
                    if i < len(df) - 1 and batch_delay > 0:
                        logger.info(f"Batch complete. Waiting for {batch_delay} seconds before next batch...")
                        time.sleep(batch_delay)
            except Exception as e:
                # Per-row failures are counted and logged, not fatal
                logger.error(f"Error processing row {i+1}: {str(e)}")
                failed_rows += 1
        # Calculate final statistics
        logger.info("Categorization process complete")
        logger.info(f"Total rows: {original_rows}")
        logger.info(f"Previously categorized: {already_processed}")
        logger.info(f"Newly processed: {successful_rows + failed_rows}")
        logger.info(f"Successfully categorized: {successful_rows}")
        logger.info(f"Failed to categorize: {failed_rows}")
        return df, successful_rows, failed_rows
    except Exception as e:
        logger.error(f"Error in categorization process: {str(e)}")
        raise