Spaces:
Paused
Paused
| import asyncio | |
| import math | |
| from typing import Any, Dict, List, Optional, Union | |
| from starfish import StructuredLLM | |
async def generate_topics(
    user_instruction: str,
    num_topics: int,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
    existing_topics: Optional[List[str]] = None,
) -> List[str]:
    """Generate up to ``num_topics`` new, unique topics using a StructuredLLM model.

    Topics are requested in batches of at most five. Each request lists every
    topic known so far (``existing_topics`` plus those already generated) and
    asks the model not to repeat them. Entries the model returns anyway that
    are empty, duplicated, or already known are discarded here.

    Args:
        user_instruction: Subject of the topics; rendered into the prompt.
        num_topics: How many topics to return. Values ``<= 0`` return ``[]``
            without contacting the model.
        model_name: Model identifier passed to ``StructuredLLM``.
        model_kwargs: Extra model options. The caller's dict is not modified;
            ``temperature`` defaults to 1 when the key is absent.
        existing_topics: Topics that must not appear in the result.

    Returns:
        At most ``num_topics`` new topics, in the order they were generated.
        The list can be shorter than requested if the model keeps producing
        duplicates or empty values within the allotted batches.
    """
    if num_topics <= 0:
        return []

    # Work on a copy: writing the default temperature into the caller's dict
    # would leak into any later call that reuses that dict.
    llm_kwargs = dict(model_kwargs or {})
    llm_kwargs.setdefault("temperature", 1)

    known_topics = list(existing_topics or [])
    seen = set(known_topics)
    generated_topics: List[str] = []

    llm_batch_size = 5
    num_batches = math.ceil(num_topics / llm_batch_size)

    # The prompt and model settings are the same for every batch, so one
    # generator is built and reused (assumes StructuredLLM.run may be awaited
    # repeatedly on a single instance).
    topic_generator = StructuredLLM(
        model_name=model_name,
        prompt="""Can you generate a list of topics about {{user_instruction}}
{% if existing_topics_str %}
Please do not generate topics that are already in the list: {{existing_topics_str}}
Make sure the topics are unique and vary from each other
{% endif %}
""",
        output_schema=[{"name": "topic", "type": "str"}],
        model_kwargs=llm_kwargs,
    )

    for _ in range(num_batches):
        input_params: Dict[str, Any] = {
            "user_instruction": user_instruction,
            # Passed through to StructuredLLM.run; presumably the number of
            # items requested from the model -- confirm against starfish docs.
            "num_records": min(llm_batch_size, num_topics - len(generated_topics)),
        }
        all_existing = known_topics + generated_topics
        if all_existing:
            input_params["existing_topics_str"] = ",".join(all_existing)

        topic_response = await topic_generator.run(**input_params)
        for item in topic_response.data:
            topic = item.get("topic")
            # The model may ignore the "do not repeat" instruction or return
            # more items than requested; keep only new, non-empty topics.
            if not isinstance(topic, str) or not topic or topic in seen:
                continue
            seen.add(topic)
            generated_topics.append(topic)
            if len(generated_topics) >= num_topics:
                return generated_topics

    return generated_topics
async def prepare_topic(
    topics: Optional[List[Union[str, Dict[str, int]]]] = None,
    num_records: Optional[int] = None,
    records_per_topic: int = 20,
    user_instruction: Optional[str] = None,
    model_name: str = "openai/gpt-4o-mini",
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
    """Split records into topics, generating topics if none are provided or if needed.

    Supported input formats:
    1. String list: ['topic1', 'topic2'] - Topics with equal or calculated distribution
    2. Dict list: [{'topic1': 20}, {'topic2': 30}] - Topics with specific counts
    3. Mixed: ['topic1', {'topic2': 30}] - Combination of both formats
    4. None: No topics provided, will generate based on user_instruction

    Allocation rules:
    - A topic name repeated in ``topics`` is kept only at its first occurrence.
    - String topics without ``num_records`` get ``records_per_topic`` records each.
    - String topics with ``num_records`` share the records not claimed by dict
      topics as evenly as possible (earlier topics receive the remainder).
    - Dict topics keep their explicit counts.
    - Topics that end up with 0 records are omitted from the output.
    - Records still unassigned go to new topics of at most ``records_per_topic``
      records each. The first source is ``generate_topics``, used only when
      ``user_instruction`` is given. After that come placeholder names
      ``auto_topic1``, ``auto_topic2``, ... that clash with no provided or
      generated topic.

    Output order: string topics, dict topics, generated topics, then auto topics.

    Args:
        topics: Optional list of topics, either strings or {topic: count} dicts
        num_records: Total number of records to split (required for dict topics or None topics)
        records_per_topic: Number of records per topic (default: 20)
        user_instruction: Topic generation instructions (required if topics is None)
        model_name: Model name for topic generation
        model_kwargs: Model kwargs for topic generation. The caller's dict is
            not modified; ``temperature`` defaults to 1 when absent.

    Returns:
        List of {'topic': topic_name} dictionaries, one entry per record. Every
        entry is a distinct dict, so callers may add fields to individual records.

    Raises:
        ValueError: ``topics`` is None and ``num_records`` is missing/non-positive
            or ``user_instruction`` is missing; ``topics`` is not a non-empty list
            or contains something other than strings / single-key dicts; a dict
            count is not a non-negative int; ``num_records`` is negative; dict
            topics are given without ``num_records``; dict counts exceed
            ``num_records``; or records remain to be generated while
            ``records_per_topic`` is not positive.
    """
    # Copy so the caller's dict is never modified by the temperature default.
    llm_kwargs = dict(model_kwargs or {})
    llm_kwargs.setdefault("temperature", 1)

    # --- STEP 1: Input validation and normalization ---
    # Ordered (topic_name, count) pairs; count is None for string topics, whose
    # share is computed in step 2.
    topic_assignments = []
    # Every topic name in use; generated and auto topics must avoid these.
    taken_names = set()
    if topics is None:
        # Everything will be generated, so both a size and instructions are needed.
        if not num_records or num_records <= 0:
            raise ValueError("num_records must be positive when topics are not provided")
        if not user_instruction:
            raise ValueError("user_instruction required when topics are not provided")
    else:
        if not isinstance(topics, list) or not topics:
            raise ValueError("topics must be a non-empty list")
        if num_records is not None and num_records < 0:
            raise ValueError("num_records must not be negative")
        for topic in topics:
            if isinstance(topic, str):
                name, count = topic, None
            elif isinstance(topic, dict) and len(topic) == 1:
                name, count = next(iter(topic.items()))
                # bool is a subclass of int, so {"x": True} would otherwise pass as 1.
                if isinstance(count, bool) or not isinstance(count, int) or count < 0:
                    raise ValueError(f"Topic '{name}' has invalid count {count}")
            else:
                raise ValueError("Topics must be strings or single-key dictionaries")
            if name not in taken_names:
                taken_names.add(name)
                topic_assignments.append((name, count))

    # --- STEP 2: Calculate or validate counts for provided topics ---
    string_topics = [name for name, count in topic_assignments if count is None]
    dict_topics = [(name, count) for name, count in topic_assignments if count is not None]
    if dict_topics and num_records is None:
        raise ValueError("num_records required when using dictionary topics")

    # Ordered (topic_name, record_count) pairs; only positive counts are kept.
    allocations = []
    if string_topics:
        if num_records is None:
            string_counts = [records_per_topic] * len(string_topics)
        else:
            remaining = num_records - sum(count for _, count in dict_topics)
            if remaining < 0:
                raise ValueError("Dict topic counts exceed num_records")
            base, extra = divmod(remaining, len(string_topics))
            string_counts = [base + (1 if i < extra else 0) for i in range(len(string_topics))]
        allocations.extend(
            (name, count) for name, count in zip(string_topics, string_counts) if count > 0
        )
    allocations.extend((name, count) for name, count in dict_topics if count > 0)

    assigned_count = sum(count for _, count in allocations)
    if num_records is not None and assigned_count > num_records:
        raise ValueError(f"Total assigned count ({assigned_count}) exceeds num_records ({num_records})")

    # --- STEP 3: Generate topics for remaining records if needed ---
    remaining_records = 0 if num_records is None else num_records - assigned_count
    if remaining_records > 0:
        if records_per_topic <= 0:
            raise ValueError("records_per_topic must be positive when generating topics")

        if user_instruction:
            topics_needed = math.ceil(remaining_records / records_per_topic)
            generated = await generate_topics(
                user_instruction=user_instruction,
                num_topics=topics_needed,
                model_name=model_name,
                model_kwargs=llm_kwargs,
                # All provided names, including ones allocated 0 records, so the
                # model is told not to re-introduce them.
                existing_topics=[name for name, _ in topic_assignments],
            )
            for topic in generated:
                if remaining_records <= 0:
                    break
                if not isinstance(topic, str) or not topic:
                    continue  # nothing usable to name a topic with
                if topic in taken_names:
                    print(f"Skipping duplicate generated topic: {topic}")
                    continue
                count = min(records_per_topic, remaining_records)
                allocations.append((topic, count))
                taken_names.add(topic)
                remaining_records -= count

        # Placeholder topics absorb whatever generation did not cover.
        auto_index = 1
        while remaining_records > 0:
            auto_name = f"auto_topic{auto_index}"
            auto_index += 1
            if auto_name in taken_names:
                continue
            count = min(records_per_topic, remaining_records)
            allocations.append((auto_name, count))
            taken_names.add(auto_name)
            remaining_records -= count

    # Final validation
    assigned_count = sum(count for _, count in allocations)
    if num_records is not None and assigned_count != num_records:
        print(f"Warning: Assigned {assigned_count} records, expected {num_records}")

    # One fresh dict per record: `[{...}] * count` would repeat a single shared
    # dict, so editing one record would silently edit all of its siblings.
    return [{"topic": name} for name, count in allocations for _ in range(count)]
if __name__ == "__main__":
    print("--- Running Examples ---")

    # Demo table: (header printed before the run, keyword arguments for
    # prepare_topic). Examples 1 and 5 leave records unassigned with a
    # user_instruction, so they go through generate_topics and the model.
    demo_cases = [
        # Dictionary topics with additional generation
        (
            "Example 1: Dictionary topics + generation",
            {
                "topics": [{"topic1": 20}, {"topic2": 30}],
                "num_records": 100,
                "records_per_topic": 25,
                "user_instruction": "some context",
            },
        ),
        # String topics with even distribution
        (
            "Example 2: String topics with distribution",
            {"topics": ["topicA", "topicB", "topicC"], "num_records": 10},
        ),
        # Mixed string and dict topics
        (
            "Example 3: Mixed string/dict topics",
            {"topics": ["topicX", {"topicY": 10}], "num_records": 30, "user_instruction": "mixed topics"},
        ),
        # String topics with fixed count
        (
            "Example 4: String topics with fixed count",
            {"topics": ["apple", "banana", "cherry"], "records_per_topic": 15},
        ),
        # No topics, generate all
        (
            "Example 5: No topics, generate all",
            {"topics": None, "num_records": 10, "records_per_topic": 5, "user_instruction": "cloud computing"},
        ),
    ]

    for header, call_kwargs in demo_cases:
        print(f"\n{header}")
        records = asyncio.run(prepare_topic(**call_kwargs))
        print(f"Result: {records}")
        print(f"Total: {len(records)}")

    print("\n--- Examples Finished ---")