Spaces:
Sleeping
Sleeping
| import csv | |
| import json | |
| import time | |
| import os | |
| import pandas as pd | |
| import io | |
| import google.generativeai as genai | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import logging | |
| import sys | |
| import traceback | |
| from .api_manager import ApiKeyManager | |
# --- Logging: emit everything to stdout so container/platform log collectors see it ---
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[_stdout_handler],
)
logger = logging.getLogger(__name__)

# --- Tunables for batch processing against the Gemini API ---
BATCH_SIZE = 10    # Number of items to process in each batch
BATCH_DELAY = 30   # Delay between batches in seconds
MAX_RETRIES = 3    # Maximum number of retries for API calls
RETRY_DELAY = 5    # Delay between retries in seconds

# Closed vocabulary of top-level niches the model must choose from.
NICHES_LIST = [
    "Animals", "Plants", "Nature & Landscapes", "People", "Occupations",
    "Fashion & Style", "Dance", "Toys", "Health & Fitness",
    "Mindfulness & Spirituality", "Art", "Coloring", "Patterns & Textures",
    "Crafts & DIY", "Decor", "Cartoons", "Anime & Manga",
    "Fantasy & Whimsical", "Mythology & Legends", "Characters",
    "Music", "Movies & TV", "Games & Puzzles", "Sports", "Travel",
    "Countries", "Cities & Architecture", "History", "Culture",
    "Religion", "Science", "Education", "Space", "Food & Drink",
    "Holidays & Seasonal", "Horror", "Humor & Funny",
    "Novelty & Gag Gifts", "Horoscope & Astrology", "Vehicles",
    "Quotes & Affirmations", "Adult Themes",
]
def configure_genai(api_key: str) -> bool:
    """Install *api_key* on the global Gemini client.

    Args:
        api_key: Gemini API key to configure the client with.

    Returns:
        bool: True if configuration succeeded, False if `genai.configure`
        raised. (The original annotation said ``-> None`` although the
        function always returned a bool — fixed here.)
    """
    try:
        genai.configure(api_key=api_key)
        logger.debug("Gemini API configured with key")
        return True
    except Exception as e:
        logger.error(f"Error configuring Gemini API: {str(e)}")
        return False
def analyze_title_with_gemini(
    title: str,
    subtitle: str,
    key_manager: ApiKeyManager
) -> Tuple[Dict[str, Any], bool]:
    """
    Analyze the title and subtitle using Gemini API to determine niches,
    subniches, and trademark status.

    Args:
        title: Title text
        subtitle: Subtitle text
        key_manager: API key manager instance

    Returns:
        Tuple[Dict, bool]: The parsed result dict (keys: "niches",
        "subniches", "trademark") and a success flag. On total failure the
        dict holds "Unknown"/"unknown" placeholders and the flag is False.
    """
    # Truncate long titles/subtitles to prevent API errors
    title = title[:1000] if title else ""
    subtitle = subtitle[:1000] if subtitle else ""

    # Format the niches list as a comma-separated string
    niches_text = ", ".join(NICHES_LIST)

    # Create the prompt for Gemini
    prompt = f"""Analyze the following coloring book title and subtitle:
- Title: {title}
- Subtitle: {subtitle}
Please determine:
1. "niches": 1 most appropriate niche based on the following list of niches: {niches_text}
2. "subniches": 2-3 more specific subniches (for example: if the niche is "Animals" then subniche could be "Dogs", "Wildlife", "Pets", etc.)
3. "trademark": "yes" if the title or subtitle contains copyrighted characters/brands (such as Disney, Marvel, DC, Pokemon, etc.), "no" if not
Return the result as JSON in the following format:
{{
"niches": ["niche1"],
"subniches": ["subniche1", "subniche2"],
"trademark": "yes/no"
}}
Return only the JSON, no additional explanation needed. Do not include any text before or after the JSON."""

    logger.info(f"Analyzing title: '{title}', subtitle: '{subtitle}'")

    for attempt in range(MAX_RETRIES):
        # BUGFIX: api_key must be bound before the try block. Previously, if
        # get_next_api_key() raised, the except handler referenced an unbound
        # `api_key` and raised NameError, masking the real error.
        api_key = None
        try:
            # Get the next API key
            api_key = key_manager.get_next_api_key()
            configure_genai(api_key)

            # Create a generative model with specific parameters
            logger.info(f"Creating Gemini model, attempt {attempt + 1}/{MAX_RETRIES}")
            model = genai.GenerativeModel('gemini-2.0-flash')

            # Set generation config (low temperature: we want deterministic JSON)
            generation_config = {
                "temperature": 0.1,
                "top_p": 0.95,
                "top_k": 40,
                "max_output_tokens": 1024,
            }

            # Generate content
            logger.info(f"Sending request to Gemini API, attempt {attempt + 1}/{MAX_RETRIES}")
            start_time = time.time()
            response = model.generate_content(
                prompt,
                generation_config=generation_config
            )
            elapsed_time = time.time() - start_time
            logger.info(f"Received response from Gemini API in {elapsed_time:.2f} seconds")

            # Check if response has text
            if not hasattr(response, 'text') or not response.text:
                logger.warning("Empty response received from API")
                raise ValueError("Empty response received from API")

            response_text = response.text.strip()
            logger.debug(f"Raw response text: {response_text[:100]}...")

            # Clean the response if needed: models often wrap JSON in prose
            # or code fences, so extract the outermost {...} span.
            if not response_text.startswith('{'):
                logger.warning("Response does not start with '{', attempting to extract JSON")
                start_idx = response_text.find('{')
                end_idx = response_text.rfind('}')
                if start_idx >= 0 and end_idx > start_idx:
                    response_text = response_text[start_idx:end_idx+1]
                    logger.info(f"JSON extracted from response: {response_text[:100]}...")
                else:
                    logger.error(f"Could not find valid JSON in response: {response_text[:100]}")
                    raise ValueError(f"Could not find valid JSON in response: {response_text[:100]}")

            # Parse the response as JSON
            logger.info("Parsing response as JSON")
            result = json.loads(response_text)

            # Validate the result
            if not isinstance(result, dict):
                logger.error(f"Response is not a valid JSON object: {type(result)}")
                raise ValueError("Response is not a valid JSON object")
            if "niches" not in result or "subniches" not in result or "trademark" not in result:
                logger.error(f"Response is missing required fields: {result}")
                raise ValueError(f"Response is missing required fields: {result}")

            # Coerce niches/subniches to lists (model sometimes returns scalars)
            if not isinstance(result["niches"], list):
                logger.warning(f"Niches is not a list, converting: {result['niches']}")
                result["niches"] = [result["niches"]]
            if not isinstance(result["subniches"], list):
                logger.warning(f"Subniches is not a list, converting: {result['subniches']}")
                result["subniches"] = [result["subniches"]]

            # Ensure trademark is a string
            if not isinstance(result["trademark"], str):
                logger.warning(f"Trademark is not a string, converting: {result['trademark']}")
                result["trademark"] = str(result["trademark"]).lower()

            logger.info(f"Successfully analyzed title. Niches: {result['niches']}, Subniches: {result['subniches']}, Trademark: {result['trademark']}")
            return result, True

        except Exception as e:
            # Heuristic quota detection on the error text; only mark the key
            # as failed if we actually obtained one this attempt.
            if api_key is not None and (
                "quota" in str(e).lower() or "rate" in str(e).lower() or "limit" in str(e).lower()
            ):
                logger.warning(f"API key quota exceeded or rate limited: {e}")
                key_manager.mark_key_as_failed(api_key)
            else:
                error_details = traceback.format_exc()
                logger.error(f"Error on attempt {attempt + 1}/{MAX_RETRIES}: {str(e)}\n{error_details}")
                # More debug info for JSON parsing errors
                if "Expecting value" in str(e) and 'response_text' in locals():
                    logger.error(f"Failed to parse JSON from: {response_text[:100]}...")

            if attempt < MAX_RETRIES - 1:
                retry_delay = RETRY_DELAY * (attempt + 1)  # Progressive backoff
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)

    # If all attempts failed, return default values
    logger.warning(f"All attempts failed for title: {title}, subtitle: {subtitle}")
    return {
        "niches": ["Unknown"],
        "subniches": ["Unknown"],
        "trademark": "unknown"
    }, False
def categorize_titles(
    input_data,
    batch_size: int = BATCH_SIZE,
    batch_delay: int = BATCH_DELAY,
    resume_from: int = 0
) -> Tuple[pd.DataFrame, int, int]:
    """
    Process a CSV file of titles and categorize them using the Gemini AI model.

    Args:
        input_data: File-like object (or path) containing the CSV data
        batch_size: Number of rows to process in each batch
        batch_delay: Delay between batches in seconds
        resume_from: Positional row index to resume processing from (0-based)

    Returns:
        Tuple containing:
        - DataFrame with processed data
        - Number of successfully processed rows
        - Number of failed rows

    Raises:
        Exception: re-raises any failure outside the per-row loop
        (e.g. unreadable CSV).
    """
    try:
        logger.info(f"Starting categorization process with batch size={batch_size}, batch_delay={batch_delay}")

        # Initialize API key manager
        key_manager = ApiKeyManager()
        logger.info(f"Initialized API key manager with {len(key_manager.api_keys)} keys")

        # Read input CSV file as a DataFrame
        df = pd.read_csv(input_data)
        original_rows = len(df)
        logger.info(f"Read {original_rows} rows from input")

        # Add columns for categorization results if they don't exist
        for col in ('Niche', 'Subniches', 'Trademark'):
            if col not in df.columns:
                df[col] = ""
        if 'Categorized' not in df.columns:
            df['Categorized'] = False

        # Count rows that have already been categorized
        already_processed = df['Categorized'].sum()
        logger.info(f"Found {already_processed} rows already categorized")

        successful_rows = 0
        failed_rows = 0
        total = len(df)
        total_to_process = total - already_processed

        # BUGFIX: iterrows() yields index *labels*, which the original used
        # positionally for batch boundaries ((i+1) % batch_size, i == len-1).
        # That only works for a default RangeIndex. Track position explicitly
        # with enumerate, and keep the label for .at[] writes.
        for pos, (idx, row) in enumerate(df.iloc[resume_from:].iterrows(), start=resume_from):
            # Skip already processed rows
            if row['Categorized']:
                logger.info(f"Skipping already processed row {pos+1}/{total}")
                continue
            try:
                logger.info(f"Processing row {pos+1}/{total}: {row['Title']}")

                # Get title and subtitle (NaN-safe)
                title = str(row['Title']) if not pd.isna(row['Title']) else ""
                subtitle = str(row['Subtitle']) if 'Subtitle' in row and not pd.isna(row['Subtitle']) else ""

                # Analyze title with Gemini
                result, success = analyze_title_with_gemini(title, subtitle, key_manager)

                if success:
                    # Update row with results
                    df.at[idx, 'Niche'] = result['niches'][0] if result['niches'] else "Unknown"
                    df.at[idx, 'Subniches'] = ", ".join(result['subniches']) if result['subniches'] else "Unknown"
                    df.at[idx, 'Trademark'] = result['trademark'].lower()
                    df.at[idx, 'Categorized'] = True
                    successful_rows += 1
                    logger.info(f"Successfully categorized row {pos+1}")
                else:
                    failed_rows += 1
                    logger.warning(f"Failed to categorize row {pos+1}")

                # At each batch boundary: log progress, check keys, throttle
                if (pos + 1) % batch_size == 0 or pos == total - 1:
                    progress_pct = (successful_rows / total_to_process) * 100 if total_to_process > 0 else 100
                    logger.info(f"Progress: {successful_rows}/{total_to_process} ({progress_pct:.1f}%) categorized")
                    logger.info(f"API key status: {key_manager.get_api_keys_status()['available_keys']} keys available")

                    # Check if we have any working keys left
                    if not key_manager.has_working_keys():
                        logger.error("No working API keys left. Saving progress and stopping.")
                        break

                    # Apply batch delay (except on the last item)
                    if pos < total - 1 and batch_delay > 0:
                        logger.info(f"Batch complete. Waiting for {batch_delay} seconds before next batch...")
                        time.sleep(batch_delay)
            except Exception as e:
                logger.error(f"Error processing row {pos+1}: {str(e)}")
                failed_rows += 1

        # Calculate final statistics
        logger.info("Categorization process complete")
        logger.info(f"Total rows: {original_rows}")
        logger.info(f"Previously categorized: {already_processed}")
        logger.info(f"Newly processed: {successful_rows + failed_rows}")
        logger.info(f"Successfully categorized: {successful_rows}")
        logger.info(f"Failed to categorize: {failed_rows}")

        return df, successful_rows, failed_rows
    except Exception as e:
        logger.error(f"Error in categorization process: {str(e)}")
        raise