# NOTE: removed non-code residue from a Hugging Face Spaces page capture
# ("Spaces:" / "Sleeping" / "Sleeping") that preceded the script.
| #!/usr/bin/env python3 | |
| """ | |
| NFL Rulebook Training Data Generator | |
| This script processes the 2024 NFL rulebook CSV file and generates | |
| training data for fine-tuning using our Hugging Face model. | |
| For each rule, it generates 3 user/assistant prompt pairs using | |
| the deployed model, then formats them into JSONL for fine-tuning. | |
| """ | |
import argparse
import csv
import json
import logging
import random
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.FileHandler('nfl_training_data.log'), | |
| logging.StreamHandler() | |
| ] | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # Configuration | |
| HUGGINGFACE_SPACE_URL = "https://david167-question-generation-api.hf.space" | |
| SYSTEM_MESSAGE = "You are a football broadcaster with years of experience and inside knowledge of the game from playing and coaching. You have a complete understanding of the rule book, how it's interpreted and judged." | |
class NFLTrainingDataGenerator:
    """Generate chat-format fine-tuning examples from an NFL rulebook CSV.

    For each rule, the generator asks a deployed Hugging Face model (or a
    local mock when ``mock_mode`` is on) for three user/assistant Q&A pairs,
    parses them, and formats them as OpenAI-style ``messages`` records
    suitable for JSONL fine-tuning.
    """

    def __init__(self, csv_file_path: str, output_dir: str = "output",
                 mock_mode: bool = True):
        """Set up paths, the HTTP session, and run statistics.

        Args:
            csv_file_path: Path to the rulebook CSV file.
            output_dir: Directory for generated files (created if missing).
            mock_mode: When True (the default, matching the previous
                hard-coded behavior), skip the remote API and fabricate
                responses locally. Set False once the HF space is working.
        """
        self.csv_file_path = Path(csv_file_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        # Toggle between the real HF space and the local mock generator.
        self.mock_mode = mock_mode
        # API client setup: one keep-alive session for all requests.
        self.api_base_url = HUGGINGFACE_SPACE_URL
        self.session = requests.Session()
        self.session.headers.update({
            'Content-Type': 'application/json',
            'User-Agent': 'NFL-Training-Data-Generator/1.0'
        })
        # Stats tracking accumulated over the whole run.
        self.stats = {
            'rules_processed': 0,
            'prompts_generated': 0,
            'api_calls_made': 0,
            'errors': 0
        }

    def load_rulebook_csv(self) -> List[Dict[str, str]]:
        """Load the NFL rulebook CSV file.

        Returns:
            One dict per CSV row (keys are the header columns).

        Raises:
            FileNotFoundError: If the CSV path does not exist.
            Exception: Re-raised for any other read/parse failure.
        """
        try:
            with open(self.csv_file_path, 'r', encoding='utf-8') as file:
                rules = list(csv.DictReader(file))
            logger.info(f"Loaded {len(rules)} rules from {self.csv_file_path}")
            return rules
        except FileNotFoundError:
            logger.error(f"CSV file not found: {self.csv_file_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading CSV: {str(e)}")
            raise

    def generate_prompts_for_rule(self, rule_text: str,
                                  rule_number: Optional[str] = None) -> List[Dict[str, Any]]:
        """Generate 3 user/assistant prompts for a single rule using our HF model.

        Args:
            rule_text: The rule's full text from the CSV.
            rule_number: Identifier used only for logging.

        Returns:
            A list of training examples; empty on failure (errors are
            counted in ``self.stats`` rather than raised).
        """
        # Create the prompt for the model to generate training examples
        generation_prompt = f"""Based on this NFL rule, create 3 different realistic user questions that a football fan, coach, or player might ask, along with expert broadcaster responses.
NFL Rule: {rule_text}
For each of the 3 examples, provide:
1. A realistic user question about this rule
2. A detailed, authoritative response as an experienced football broadcaster
Make the questions varied - some should be basic understanding, others about specific scenarios or edge cases.
Make the responses detailed, authoritative, and include practical examples when helpful.
Format as:
Q1: [user question 1]
A1: [detailed broadcaster response 1]
Q2: [user question 2]
A2: [detailed broadcaster response 2]
Q3: [user question 3]
A3: [detailed broadcaster response 3]"""
        try:
            # Call our HF model API (or the mock, depending on mock_mode).
            response = self.call_hf_model(generation_prompt)
            self.stats['api_calls_made'] += 1
            if not response:
                logger.warning(f"Empty response for rule {rule_number}")
                return []
            # Parse the response to extract Q&A pairs
            prompts = self.parse_qa_response(response, rule_text)
            self.stats['prompts_generated'] += len(prompts)
            logger.info(f"Generated {len(prompts)} prompts for rule {rule_number}")
            return prompts
        except Exception as e:
            logger.error(f"Error generating prompts for rule {rule_number}: {str(e)}")
            self.stats['errors'] += 1
            return []

    def generate_mock_response(self, prompt: str) -> str:
        """Generate a mock response for testing when HF space is unavailable.

        Extracts the rule text embedded in *prompt* and returns one of two
        canned Q1/A1..Q3/A3 transcripts built around it.
        """
        # Extract rule text from the prompt (it follows the "NFL Rule:" marker).
        rule_text = ""
        if "NFL Rule:" in prompt:
            lines = prompt.split('\n')
            for line in lines:
                if line.startswith("NFL Rule:"):
                    rule_text = line.replace("NFL Rule:", "").strip()
                    break
        # Generate realistic mock Q&A based on the rule
        mock_responses = [
            f"""Q1: What does this rule mean in simple terms?
A1: This rule explains that {rule_text[:50]}... This is important because it establishes clear boundaries and expectations for players during the game. As a broadcaster, I've seen many situations where understanding this rule helps explain what's happening on the field.
Q2: When would this rule typically come into play during a game?
A2: You'll most commonly see this rule applied during crucial moments of the game. For example, {rule_text[:30]}... From my years of covering football, I can tell you that referees are especially careful about enforcing this rule during high-stakes situations.
Q3: What are some common misconceptions about this rule?
A3: Many fans think this rule is more complicated than it actually is. The key thing to remember is that {rule_text[:40]}... Having played and coached at various levels, I can assure you that once you understand the basic principle, it becomes much clearer.""",
            f"""Q1: How do referees typically enforce this rule?
A1: Referees are trained to look for specific indicators when applying this rule. Since {rule_text[:50]}..., they need to make quick decisions based on what they observe. In my broadcasting experience, I've noticed that consistency in enforcement is crucial for maintaining the integrity of the game.
Q2: Has this rule changed over the years?
A2: Like many NFL rules, this one has evolved to improve player safety and game flow. The current version states that {rule_text[:40]}... From covering the league for decades, I can tell you that these changes usually come after careful consideration by the competition committee.
Q3: What should coaches teach players about this rule?
A3: Coaches need to emphasize the practical implications of this rule during practice. Since {rule_text[:35]}..., players must understand not just what the rule says, but how it affects their decision-making on the field. This is fundamental knowledge that every player should master."""
        ]
        # Add some delay to simulate API call
        time.sleep(0.5)
        # Return a random mock response
        return random.choice(mock_responses)

    def call_hf_model(self, prompt: str, max_retries: int = 3) -> str:
        """Call our Hugging Face Gradio interface with retry logic.

        Args:
            prompt: The full generation prompt to send.
            max_retries: Attempts before giving up (exponential backoff).

        Returns:
            The assistant's text, ``str(data)`` as a fallback when the
            payload shape is unrecognized, or ``""`` after exhausting
            retries on non-200 responses.

        Raises:
            requests.exceptions.RequestException: If the final attempt
                fails at the transport level.
        """
        # Mock mode bypasses the network entirely (see __init__).
        if self.mock_mode:
            return self.generate_mock_response(prompt)
        # Use the Gradio interface endpoint
        gradio_url = f"{self.api_base_url}/api/predict"
        # Gradio payload format for our chat interface
        payload = {
            "data": [
                prompt,  # message
                [],  # history (empty for new conversation)
                0.8,  # temperature
                False,  # json_mode
                "general"  # json_template
            ],
            "fn_index": 0  # Function index for the respond function
        }
        for attempt in range(max_retries):
            try:
                # Add delay between requests to be respectful
                if attempt > 0:
                    time.sleep(2 ** attempt)  # Exponential backoff
                response = self.session.post(
                    gradio_url,
                    json=payload,
                    timeout=60
                )
                if response.status_code == 200:
                    data = response.json()
                    # Gradio returns data in format: {"data": [history, ""]}
                    if 'data' in data and len(data['data']) > 0:
                        history = data['data'][0]
                        if history and len(history) > 0:
                            # Get the last assistant response
                            last_response = history[-1]
                            if isinstance(last_response, dict) and 'content' in last_response:
                                return last_response['content']
                            elif isinstance(last_response, list) and len(last_response) > 1:
                                return last_response[1]  # [user_msg, assistant_msg] format
                    # Fallback: return raw data as string
                    return str(data)
                else:
                    logger.warning(f"Gradio API call failed with status {response.status_code}")
            except requests.exceptions.RequestException as e:
                logger.warning(f"Request failed (attempt {attempt + 1}): {str(e)}")
                if attempt == max_retries - 1:
                    raise
        return ""

    def parse_qa_response(self, response: str, original_rule: str) -> List[Dict[str, Any]]:
        """Parse the model response to extract Q&A pairs.

        Scans line by line for ``Q1:``/``A1:``-style markers, treating
        unmarked lines as continuations. On a parse exception, falls back
        to a single generic example built from *original_rule*.
        """
        prompts = []
        try:
            lines = response.strip().split('\n')
            current_q = None
            current_a = None
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                # Look for question patterns
                if line.startswith(('Q1:', 'Q2:', 'Q3:', '1.', '2.', '3.')):
                    if current_q and current_a:
                        # Save previous Q&A pair
                        prompts.append(self.create_training_example(current_q, current_a))
                    # Extract question
                    current_q = line.split(':', 1)[1].strip() if ':' in line else line
                    current_a = None
                # Look for answer patterns
                elif line.startswith(('A1:', 'A2:', 'A3:')):
                    current_a = line.split(':', 1)[1].strip() if ':' in line else line
                # Continue building the answer if we're in answer mode
                elif current_q and current_a is not None:
                    current_a += ' ' + line
                elif current_q and not current_a:
                    # This might be a continuation of the question or start of answer
                    if len(line) > 50:  # Likely an answer
                        current_a = line
                    else:
                        current_q += ' ' + line
            # Don't forget the last Q&A pair
            if current_q and current_a:
                prompts.append(self.create_training_example(current_q, current_a))
        except Exception as e:
            logger.error(f"Error parsing response: {str(e)}")
            # Fallback: create a generic example
            prompts.append(self.create_training_example(
                "Can you explain this NFL rule?",
                f"This rule states: {original_rule[:200]}..."
            ))
        return prompts

    def create_training_example(self, user_question: str, assistant_response: str) -> Dict[str, Any]:
        """Create a properly formatted training example.

        Returns an OpenAI chat-format record: system / user / assistant
        messages with whitespace-stripped content.
        """
        return {
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_MESSAGE
                },
                {
                    "role": "user",
                    "content": user_question.strip()
                },
                {
                    "role": "assistant",
                    "content": assistant_response.strip()
                }
            ]
        }

    def process_rules(self, rules: List[Dict[str, str]],
                      sample_size: Optional[int] = None) -> List[Dict[str, Any]]:
        """Process all rules or a random sample to generate training data.

        Args:
            rules: Rows loaded from the rulebook CSV.
            sample_size: If given, process at most this many randomly
                chosen rules; otherwise process all of them.

        Returns:
            All generated training examples across the processed rules.
        """
        if sample_size:
            rules = random.sample(rules, min(sample_size, len(rules)))
            logger.info(f"Processing random sample of {len(rules)} rules")
        else:
            logger.info(f"Processing all {len(rules)} rules")
        all_training_examples = []
        for i, rule in enumerate(rules, 1):
            # Get rule text from CSV (adjust column name as needed)
            rule_text = rule.get('rule_text', rule.get('description', rule.get('text', str(rule))))
            rule_number = rule.get('rule_number', rule.get('number', f"Rule_{i}"))
            logger.info(f"Processing rule {i}/{len(rules)}: {rule_number}")
            # Generate prompts for this rule
            prompts = self.generate_prompts_for_rule(rule_text, rule_number)
            all_training_examples.extend(prompts)
            self.stats['rules_processed'] += 1
            # Add a small delay to be respectful to the API
            time.sleep(1)
            # Progress update every 10 rules
            if i % 10 == 0:
                logger.info(f"Progress: {i}/{len(rules)} rules processed, {len(all_training_examples)} examples generated")
        return all_training_examples

    def save_jsonl(self, training_examples: List[Dict[str, Any]],
                   filename: Optional[str] = None) -> Path:
        """Save training examples to a JSONL file (one JSON object per line).

        Args:
            training_examples: Records to write.
            filename: Output filename; auto-generated with a timestamp
                when omitted.

        Returns:
            The path of the written file.

        Raises:
            Exception: Re-raised for any write failure, after logging.
        """
        if not filename:
            timestamp = int(time.time())
            filename = f"nfl_training_data_{timestamp}.jsonl"
        output_path = self.output_dir / filename
        try:
            with open(output_path, 'w', encoding='utf-8') as f:
                for example in training_examples:
                    f.write(json.dumps(example, ensure_ascii=False) + '\n')
            logger.info(f"Saved {len(training_examples)} training examples to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving JSONL file: {str(e)}")
            raise

    def print_stats(self):
        """Print generation statistics to stdout."""
        print("\n" + "="*50)
        print("GENERATION STATISTICS")
        print("="*50)
        print(f"Rules processed: {self.stats['rules_processed']}")
        print(f"Total prompts generated: {self.stats['prompts_generated']}")
        print(f"API calls made: {self.stats['api_calls_made']}")
        print(f"Errors encountered: {self.stats['errors']}")
        # max(1, ...) guards against division by zero when nothing ran.
        print(f"Average prompts per rule: {self.stats['prompts_generated'] / max(1, self.stats['rules_processed']):.1f}")
        print("="*50)
def main():
    """Command-line entry point.

    Parses arguments, resolves the processing mode (interactively when no
    mode flag is given), generates training data from the rulebook CSV,
    and writes a JSONL file.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description='Generate NFL training data from rulebook CSV')
    parser.add_argument('csv_file', help='Path to the 2024 NFL rulebook CSV file')
    # Add mutually exclusive group for processing options
    processing_group = parser.add_mutually_exclusive_group()
    processing_group.add_argument('--sample', type=int, default=None,
                                  help='Process only a random sample of N rules')
    processing_group.add_argument('--random-10', action='store_true',
                                  help='Process 10 random rules (quick test)')
    processing_group.add_argument('--full', action='store_true',
                                  help='Process all rules in the file')
    parser.add_argument('--output-dir', default='output',
                        help='Output directory for generated files')
    parser.add_argument('--output-file', default=None,
                        help='Output JSONL filename (default: auto-generated)')
    args = parser.parse_args()
    # Fail fast: validate the CSV path before any interactive prompting.
    if not Path(args.csv_file).exists():
        print(f"Error: CSV file not found: {args.csv_file}")
        return 1
    # Handle the processing options
    sample_size = None
    if args.random_10:
        sample_size = 10
        print("π― Running with 10 random rules for testing")
    elif args.sample is not None:
        # `is not None` (not truthiness): `--sample 0` must not fall
        # through to the interactive prompt.
        sample_size = args.sample
        print(f"π― Running with {sample_size} random rules")
    elif args.full:
        sample_size = None
        print("π― Running with ALL rules in the file")
    else:
        # Default behavior - ask user
        print("\nπ NFL Training Data Generator")
        print("Choose processing mode:")
        print("1. Test with 10 random rules (recommended for first run)")
        print("2. Process ALL rules in the file")
        while True:
            choice = input("\nEnter your choice (1 or 2): ").strip()
            if choice == "1":
                sample_size = 10
                print("π― Processing 10 random rules...")
                break
            elif choice == "2":
                sample_size = None
                print("π― Processing ALL rules...")
                break
            else:
                print("β Please enter 1 or 2")
    # Update args with the determined sample size
    args.sample = sample_size
    # Create generator
    generator = NFLTrainingDataGenerator(args.csv_file, args.output_dir)
    try:
        # Load rules
        rules = generator.load_rulebook_csv()
        # Process rules
        training_examples = generator.process_rules(rules, args.sample)
        if not training_examples:
            print("No training examples generated!")
            return 1
        # Save to JSONL
        output_file = generator.save_jsonl(training_examples, args.output_file)
        # Print statistics
        generator.print_stats()
        print("\nβ Successfully generated training data!")
        print(f"π Output file: {output_file}")
        print(f"π Total examples: {len(training_examples)}")
        # Show a sample example
        if training_examples:
            print("\nπ Sample training example:")
            print(json.dumps(training_examples[0], indent=2, ensure_ascii=False))
        return 0
    except Exception as e:
        # logger.exception keeps the traceback in the log, not just the message.
        logger.exception(f"Fatal error: {str(e)}")
        return 1
| if __name__ == "__main__": | |
| exit(main()) |