Spaces:
Sleeping
Sleeping
File size: 13,901 Bytes
ded29b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 | import csv
import json
import time
import os
import pandas as pd
import io
import google.generativeai as genai
from typing import List, Dict, Any, Optional, Tuple
import logging
import sys
import traceback
from .api_manager import ApiKeyManager
# Configure logging to stdout only. logging.basicConfig() installs the handler
# on the root logger and is a no-op if the root logger already has handlers.
_STDOUT_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=_STDOUT_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Processing constants. BATCH_SIZE / BATCH_DELAY are the default arguments of
# categorize_titles(); MAX_RETRIES / RETRY_DELAY drive the per-title retry loop.
BATCH_SIZE = 10   # Number of items to process in each batch
BATCH_DELAY = 30  # Delay between batches in seconds
MAX_RETRIES = 3   # Maximum number of retries for API calls
RETRY_DELAY = 5   # Base delay between retries in seconds

# Predefined list of niches. It is joined with ", " into the Gemini prompt,
# so the order below is the order the model sees.
NICHES_LIST = [
    "Animals",
    "Plants",
    "Nature & Landscapes",
    "People",
    "Occupations",
    "Fashion & Style",
    "Dance",
    "Toys",
    "Health & Fitness",
    "Mindfulness & Spirituality",
    "Art",
    "Coloring",
    "Patterns & Textures",
    "Crafts & DIY",
    "Decor",
    "Cartoons",
    "Anime & Manga",
    "Fantasy & Whimsical",
    "Mythology & Legends",
    "Characters",
    "Music",
    "Movies & TV",
    "Games & Puzzles",
    "Sports",
    "Travel",
    "Countries",
    "Cities & Architecture",
    "History",
    "Culture",
    "Religion",
    "Science",
    "Education",
    "Space",
    "Food & Drink",
    "Holidays & Seasonal",
    "Horror",
    "Humor & Funny",
    "Novelty & Gag Gifts",
    "Horoscope & Astrology",
    "Vehicles",
    "Quotes & Affirmations",
    "Adult Themes",
]
def configure_genai(api_key: str) -> bool:
    """Configure the ``google.generativeai`` module with the given API key.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.

    Returns:
        True if ``genai.configure`` succeeded; False if it raised. The error is
        logged rather than propagated, so callers must check the result.
    """
    try:
        genai.configure(api_key=api_key)
    except Exception as e:
        logger.error(f"Error configuring Gemini API: {str(e)}")
        return False
    logger.debug("Gemini API configured with key")
    return True
def _build_gemini_prompt(title: str, subtitle: str) -> str:
    """Return the categorization prompt for one title/subtitle pair.

    The model is asked for one niche from ``NICHES_LIST``, 2-3 free-form
    subniches and a yes/no trademark flag, returned as a bare JSON object.
    """
    niches_text = ", ".join(NICHES_LIST)
    prompt_lines = [
        "Analyze the following coloring book title and subtitle:",
        f"- Title: {title}",
        f"- Subtitle: {subtitle}",
        "Please determine:",
        f'1. "niches": 1 most appropriate niche based on the following list of niches: {niches_text}',
        '2. "subniches": 2-3 more specific subniches (for example: if the niche is "Animals" then subniche could be "Dogs", "Wildlife", "Pets", etc.)',
        '3. "trademark": "yes" if the title or subtitle contains copyrighted characters/brands (such as Disney, Marvel, DC, Pokemon, etc.), "no" if not',
        "Return the result as JSON in the following format:",
        "{",
        '"niches": ["niche1"],',
        '"subniches": ["subniche1", "subniche2"],',
        '"trademark": "yes/no"',
        "}",
        "Return only the JSON, no additional explanation needed. Do not include any text before or after the JSON.",
    ]
    return "\n".join(prompt_lines)


def _is_quota_error(error: Exception) -> bool:
    """Return True if *error*'s message looks like a quota / rate-limit rejection.

    Matches specific phrases instead of the bare substrings "rate" and "limit".
    Those also matched unrelated words such as "generate" or "token limit",
    which caused healthy API keys to be marked as failed.
    """
    message = str(error).lower()
    markers = (
        "quota",
        "rate limit",
        "rate-limit",
        "rate_limit",
        "ratelimit",
        "resource exhausted",
        "resource_exhausted",
        "resource has been exhausted",
        "too many requests",
    )
    return any(marker in message for marker in markers)


def _parse_gemini_response(response_text: str) -> Dict[str, Any]:
    """Extract, validate and normalize the JSON object in a Gemini reply.

    Only the outermost ``{...}`` span is parsed, so Markdown fences or chatter
    before or after the object are tolerated.

    Returns:
        Dict whose ``niches`` and ``subniches`` are lists of strings and whose
        ``trademark`` is a lower-cased string (booleans map to "yes"/"no").

    Raises:
        ValueError: no ``{...}`` span is found, the JSON is not an object, or a
            required key is missing. ``json.JSONDecodeError`` (a ValueError
            subclass) is raised when the extracted span is not valid JSON.
    """
    text = response_text.strip()
    if not text.startswith('{'):
        logger.warning("Response does not start with '{', attempting to extract JSON")
    start_idx = text.find('{')
    end_idx = text.rfind('}')
    if start_idx < 0 or end_idx <= start_idx:
        logger.error(f"Could not find valid JSON in response: {text[:100]}")
        raise ValueError(f"Could not find valid JSON in response: {text[:100]}")
    json_text = text[start_idx:end_idx + 1]
    if json_text != text:
        logger.info(f"JSON extracted from response: {json_text[:100]}...")

    logger.info("Parsing response as JSON")
    result = json.loads(json_text)

    if not isinstance(result, dict):
        logger.error(f"Response is not a valid JSON object: {type(result)}")
        raise ValueError("Response is not a valid JSON object")
    if "niches" not in result or "subniches" not in result or "trademark" not in result:
        logger.error(f"Response is missing required fields: {result}")
        raise ValueError(f"Response is missing required fields: {result}")

    # Ensure niches and subniches are lists of strings. The caller joins
    # subniches with ", ", which fails on non-string items.
    for key, label in (("niches", "Niches"), ("subniches", "Subniches")):
        value = result[key]
        if value is None:
            value = []
        elif not isinstance(value, list):
            logger.warning(f"{label} is not a list, converting: {value}")
            value = [value]
        result[key] = [str(item) for item in value if item is not None]

    # Ensure trademark is a lower-cased string. A JSON boolean means yes/no,
    # not the strings "true"/"false".
    trademark = result["trademark"]
    if not isinstance(trademark, str):
        logger.warning(f"Trademark is not a string, converting: {trademark}")
        if isinstance(trademark, bool):
            trademark = "yes" if trademark else "no"
        else:
            trademark = str(trademark)
    result["trademark"] = trademark.strip().lower()
    return result


def analyze_title_with_gemini(
    title: str,
    subtitle: str,
    key_manager: ApiKeyManager
) -> Tuple[Dict[str, Any], bool]:
    """
    Analyze the title and subtitle using Gemini API to determine niches, subniches, and trademark status

    Each attempt takes the next key from ``key_manager``. Errors whose message
    looks like a quota / rate-limit rejection mark that key as failed. Up to
    ``MAX_RETRIES`` attempts are made, sleeping ``RETRY_DELAY * attempt_number``
    seconds between them. API and parsing errors are logged and retried, not
    propagated.

    Args:
        title: Title text (truncated to 1000 characters; falsy becomes "")
        subtitle: Subtitle text (truncated to 1000 characters; falsy becomes "")
        key_manager: API key manager instance

    Returns:
        Tuple[Dict, bool]: On success, the parsed result (``niches``,
        ``subniches``, ``trademark``) and True. If every attempt fails, a
        placeholder with "Unknown"/"unknown" values and False.
    """
    # Truncate long titles/subtitles to prevent API errors
    title = title[:1000] if title else ""
    subtitle = subtitle[:1000] if subtitle else ""
    prompt = _build_gemini_prompt(title, subtitle)

    # Identical for every attempt, so it is built once.
    generation_config = {
        "temperature": 0.1,
        "top_p": 0.95,
        "top_k": 40,
        "max_output_tokens": 1024,
    }

    logger.info(f"Analyzing title: '{title}', subtitle: '{subtitle}'")
    for attempt in range(MAX_RETRIES):
        # Reset per attempt so the error handler never acts on a key or a
        # response text left over from a previous attempt.
        api_key = None
        response_text = None
        try:
            api_key = key_manager.get_next_api_key()
            # configure_genai swallows its own errors. Without this check the
            # request would silently run with the previously configured key.
            if not configure_genai(api_key):
                raise RuntimeError("Failed to configure Gemini API with the selected key")

            logger.info(f"Creating Gemini model, attempt {attempt + 1}/{MAX_RETRIES}")
            model = genai.GenerativeModel('gemini-2.0-flash')

            logger.info(f"Sending request to Gemini API, attempt {attempt + 1}/{MAX_RETRIES}")
            start_time = time.monotonic()
            response = model.generate_content(
                prompt,
                generation_config=generation_config
            )
            elapsed_time = time.monotonic() - start_time
            logger.info(f"Received response from Gemini API in {elapsed_time:.2f} seconds")

            if not hasattr(response, 'text') or not response.text:
                logger.warning("Empty response received from API")
                raise ValueError("Empty response received from API")
            response_text = response.text.strip()
            logger.debug(f"Raw response text: {response_text[:100]}...")

            result = _parse_gemini_response(response_text)
            logger.info(f"Successfully analyzed title. Niches: {result['niches']}, Subniches: {result['subniches']}, Trademark: {result['trademark']}")
            return result, True
        except Exception as e:
            if _is_quota_error(e):
                logger.warning(f"API key quota exceeded or rate limited: {e}")
                # Only blame a key that was actually obtained on this attempt.
                if api_key is not None:
                    key_manager.mark_key_as_failed(api_key)
            else:
                error_details = traceback.format_exc()
                logger.error(f"Error on attempt {attempt + 1}/{MAX_RETRIES}: {str(e)}\n{error_details}")
                # More debug info for JSON parsing errors
                if isinstance(e, json.JSONDecodeError) and response_text is not None:
                    logger.error(f"Failed to parse JSON from: {response_text[:100]}...")
            if attempt < MAX_RETRIES - 1:
                retry_delay = RETRY_DELAY * (attempt + 1)  # Progressive backoff
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)

    # If all attempts failed, return default values
    logger.warning(f"All attempts failed for title: {title}, subtitle: {subtitle}")
    return {
        "niches": ["Unknown"],
        "subniches": ["Unknown"],
        "trademark": "unknown"
    }, False
def _coerce_categorized_flag(value: Any) -> bool:
    """Interpret one cell of the ``Categorized`` column as a boolean.

    A CSV round-trip may leave blank cells (NaN) or string values in this
    column. NaN is truthy in Python and any non-empty string (including
    "False") is truthy too, so without this coercion such rows would be
    mistaken for "already categorized" and skipped.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes", "y")
    if pd.isna(value):
        return False
    return bool(value)


def _prepare_result_columns(df: pd.DataFrame) -> None:
    """Ensure the result columns exist and can hold the values written back (in place)."""
    for column in ('Niche', 'Subniches', 'Trademark'):
        if column not in df.columns:
            df[column] = ""
        elif df[column].dtype != object:
            # An all-blank column read from CSV comes back as float64 NaN.
            # Writing strings into it is deprecated in pandas 2.x.
            df[column] = df[column].astype(object)
    if 'Categorized' not in df.columns:
        df['Categorized'] = False
    else:
        df['Categorized'] = df['Categorized'].map(_coerce_categorized_flag).astype(bool)


def categorize_titles(
    input_data,
    batch_size: int = BATCH_SIZE,
    batch_delay: int = BATCH_DELAY,
    resume_from: int = 0
) -> Tuple[pd.DataFrame, int, int]:
    """
    Process a CSV file of titles and categorize them using the Gemini AI model

    Rows whose ``Categorized`` flag is already true are skipped. Each
    remaining row from position ``resume_from`` onward is sent to
    ``analyze_title_with_gemini``, and its result is written into the
    ``Niche``, ``Subniches``, ``Trademark`` and ``Categorized`` columns.
    Processing stops early once the key manager reports no working keys.

    Args:
        input_data: File-like object containing the CSV data
        batch_size: Number of rows to attempt per batch (must be >= 1); progress
            is logged and ``batch_delay`` applied after each full batch
        batch_delay: Delay between batches in seconds
        resume_from: Row position to resume processing from (0-based; sliced
            like ``df.iloc[resume_from:]``)

    Returns:
        Tuple containing:
        - DataFrame with processed data
        - Number of successfully processed rows
        - Number of failed rows

    Raises:
        ValueError: if ``batch_size`` is less than 1. Any setup error (e.g.
            reading the CSV) is logged and re-raised.
    """
    try:
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        logger.info(f"Starting categorization process with batch size={batch_size}, batch_delay={batch_delay}")

        # Initialize API key manager
        key_manager = ApiKeyManager()
        logger.info(f"Initialized API key manager with {len(key_manager.api_keys)} keys")

        # Read input CSV file as a DataFrame
        df = pd.read_csv(input_data)
        original_rows = len(df)
        logger.info(f"Read {original_rows} rows from input")

        _prepare_result_columns(df)

        # Count rows that have already been categorized
        already_processed = int(df['Categorized'].sum())
        logger.info(f"Found {already_processed} rows already categorized")

        # Slice positions exactly like df.iloc[resume_from:]. Only rows in that
        # window are candidates, so the progress total counts those alone.
        positions = range(original_rows)[resume_from:]
        total_to_process = int((~df['Categorized'].iloc[resume_from:]).sum())
        has_subtitle = 'Subtitle' in df.columns
        last_pos = original_rows - 1

        successful_rows = 0
        failed_rows = 0
        attempted_rows = 0  # rows actually sent for analysis; defines batches

        for pos in positions:
            label = df.index[pos]
            row = df.iloc[pos]

            # Skip already processed rows
            if row['Categorized']:
                logger.info(f"Skipping already processed row {pos+1}/{len(df)}")
                continue

            attempted_rows += 1
            try:
                logger.info(f"Processing row {pos+1}/{len(df)}: {row['Title']}")
                title = str(row['Title']) if not pd.isna(row['Title']) else ""
                subtitle = (
                    str(row['Subtitle'])
                    if has_subtitle and not pd.isna(row['Subtitle'])
                    else ""
                )

                result, success = analyze_title_with_gemini(title, subtitle, key_manager)
                if success:
                    df.at[label, 'Niche'] = result['niches'][0] if result['niches'] else "Unknown"
                    df.at[label, 'Subniches'] = (
                        ", ".join(str(s) for s in result['subniches'])
                        if result['subniches'] else "Unknown"
                    )
                    df.at[label, 'Trademark'] = str(result['trademark']).lower()
                    df.at[label, 'Categorized'] = True
                    successful_rows += 1
                    logger.info(f"Successfully categorized row {pos+1}")
                else:
                    failed_rows += 1
                    logger.warning(f"Failed to categorize row {pos+1}")
            except Exception as e:
                logger.error(f"Error processing row {pos+1}: {str(e)}")
                failed_rows += 1

            # Bookkeeping stays outside the per-row try. An error here must not
            # re-count a row that was already tallied above.
            batch_complete = attempted_rows % batch_size == 0
            if batch_complete or pos == last_pos:
                progress_pct = (successful_rows / total_to_process) * 100 if total_to_process > 0 else 100
                logger.info(f"Progress: {successful_rows}/{total_to_process} ({progress_pct:.1f}%) categorized")
                logger.info(f"API key status: {key_manager.get_api_keys_status()['available_keys']} keys available")

            # Checked after every attempt so an exhausted key pool does not keep
            # burning retries until the next batch boundary.
            if not key_manager.has_working_keys():
                logger.error("No working API keys left. Saving progress and stopping.")
                break

            # Apply batch delay (except on the last item)
            if batch_complete and pos < last_pos and batch_delay > 0:
                logger.info(f"Batch complete. Waiting for {batch_delay} seconds before next batch...")
                time.sleep(batch_delay)

        # Calculate final statistics
        logger.info("Categorization process complete")
        logger.info(f"Total rows: {original_rows}")
        logger.info(f"Previously categorized: {already_processed}")
        logger.info(f"Newly processed: {successful_rows + failed_rows}")
        logger.info(f"Successfully categorized: {successful_rows}")
        logger.info(f"Failed to categorize: {failed_rows}")
        return df, successful_rows, failed_rows
    except Exception as e:
        logger.exception(f"Error in categorization process: {str(e)}")
        raise