vedaco commited on
Commit
3fc35d3
·
verified ·
1 Parent(s): f70ee56

Create data_collector.py

Browse files
Files changed (1) hide show
  1. data_collector.py +129 -0
data_collector.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data collector for continuous learning"""
2
+
3
+ import json
4
+ import os
5
+ from datetime import datetime
6
+ from typing import Optional, Dict, List
7
+ import hashlib
8
+
9
+ from database import db
10
+ from config import (
11
+ DATA_DIR, LEARNING_FROM_FEEDBACK,
12
+ SAVE_ALL_INTERACTIONS, REQUIRE_APPROVAL
13
+ )
14
+
15
+ class DataCollector:
16
+ """Collects and manages user interaction data for continuous learning"""
17
+
18
+ def __init__(self):
19
+ self.current_session_id = self._generate_session_id()
20
+ self.session_interactions = []
21
+
22
+ def _generate_session_id(self) -> str:
23
+ """Generate unique session ID"""
24
+ timestamp = datetime.now().isoformat()
25
+ return hashlib.md5(timestamp.encode()).hexdigest()[:12]
26
+
27
+ def collect_interaction(
28
+ self,
29
+ prompt: str,
30
+ generated_code: str,
31
+ temperature: float = 0.7,
32
+ max_tokens: int = 100
33
+ ) -> int:
34
+ """Collect a user interaction"""
35
+
36
+ if not SAVE_ALL_INTERACTIONS:
37
+ return -1
38
+
39
+ # Save to database
40
+ interaction_id = db.save_interaction(
41
+ prompt=prompt,
42
+ generated_code=generated_code,
43
+ temperature=temperature,
44
+ max_tokens=max_tokens,
45
+ session_id=self.current_session_id
46
+ )
47
+
48
+ # Track in session
49
+ self.session_interactions.append({
50
+ 'id': interaction_id,
51
+ 'prompt': prompt,
52
+ 'code': generated_code,
53
+ 'timestamp': datetime.now().isoformat()
54
+ })
55
+
56
+ return interaction_id
57
+
58
+ def record_feedback(
59
+ self,
60
+ interaction_id: int,
61
+ is_positive: bool,
62
+ edited_code: str = None
63
+ ):
64
+ """Record user feedback for an interaction"""
65
+
66
+ if not LEARNING_FROM_FEEDBACK:
67
+ return
68
+
69
+ feedback = 1 if is_positive else -1
70
+ db.update_feedback(interaction_id, feedback, edited_code)
71
+
72
+ print(f"Feedback recorded: {'👍' if is_positive else '👎'} for interaction {interaction_id}")
73
+
74
+ def add_training_sample(self, code: str, category: str = "user_contributed"):
75
+ """Add a code sample directly to training data"""
76
+ return db.add_code_sample(code, source="user", category=category)
77
+
78
+ def get_training_data(self, include_base: bool = True) -> List[str]:
79
+ """Get all available training data"""
80
+ samples = []
81
+
82
+ # Get approved user interactions
83
+ approved = db.get_approved_samples()
84
+ for item in approved:
85
+ # Combine prompt and code for training
86
+ sample = f"# Prompt: {item['prompt']}\n{item['code']}"
87
+ samples.append(sample)
88
+
89
+ # Get curated code samples
90
+ code_samples = db.get_all_code_samples()
91
+ for item in code_samples:
92
+ samples.append(item['code'])
93
+
94
+ # Include base training data
95
+ if include_base:
96
+ base_path = os.path.join(DATA_DIR, "..", "programming.txt")
97
+ if os.path.exists(base_path):
98
+ with open(base_path, 'r', encoding='utf-8') as f:
99
+ base_code = f.read()
100
+ samples.append(base_code)
101
+
102
+ return samples
103
+
104
+ def get_new_training_data(self) -> List[Dict]:
105
+ """Get new approved samples not yet used for training"""
106
+ return db.get_approved_samples(not_used=True)
107
+
108
+ def get_pending_count(self) -> int:
109
+ """Get count of samples pending training"""
110
+ return db.get_pending_samples_count()
111
+
112
+ def get_statistics(self) -> Dict:
113
+ """Get collection statistics"""
114
+ stats = db.get_statistics()
115
+ stats['session_interactions'] = len(self.session_interactions)
116
+ return stats
117
+
118
+ def export_training_data(self, filepath: str):
119
+ """Export all training data to a file"""
120
+ samples = self.get_training_data()
121
+
122
+ with open(filepath, 'w', encoding='utf-8') as f:
123
+ f.write('\n\n'.join(samples))
124
+
125
+ print(f"Exported {len(samples)} samples to {filepath}")
126
+
127
+
128
+ # Global collector instance
129
+ collector = DataCollector()