File size: 5,140 Bytes
71a764a
 
 
 
 
36d5e94
 
 
71a764a
57be184
 
71a764a
57be184
36d5e94
 
 
 
57be184
36d5e94
 
 
 
 
 
71a764a
 
 
 
 
36d5e94
 
57be184
 
 
 
 
71a764a
 
 
 
 
 
 
36d5e94
 
 
57be184
36d5e94
 
 
 
 
 
 
71a764a
57be184
 
 
36d5e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71a764a
 
 
 
 
 
 
 
 
 
36d5e94
71a764a
 
 
 
 
 
 
 
 
 
36d5e94
71a764a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Data loading and saving functionality"""

import json
import random
import datetime
import uuid
from pathlib import Path
from huggingface_hub import CommitScheduler

from src.config.settings import HF_RESULTS_REPO, HF_PROMPTS_REPO
from src.utils.hf_data_manager import HFDataManager

# Local staging directory for result records awaiting upload; created at
# import time so writes below never hit a missing directory.
JSON_DATASET_DIR = Path("testing/data/results")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
# Per-process JSONL file (despite the .json suffix, records are appended one
# JSON object per line); the uuid keeps concurrent app instances from clashing.
JSON_DATASET_PATH = JSON_DATASET_DIR / f"results_{uuid.uuid4()}.json"

# Background scheduler that periodically commits the staging folder to the
# HF results dataset repo under "data/" (every=10 — minutes, per the
# huggingface_hub CommitScheduler docs). File writes must hold scheduler.lock
# to avoid racing the background push (see DataManager.save_interaction_to_hf).
scheduler = CommitScheduler(
    repo_id=HF_RESULTS_REPO,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR.as_posix(),
    path_in_repo="data",
    every=10
)

class DataManager:
    """Manages loading and saving of prompts and results data.

    Prompts are fetched from the Hugging Face prompts repo; results are
    cached in ``self.results``, and each new interaction is appended both
    to that in-memory cache and to a local JSONL file that the
    module-level CommitScheduler periodically pushes to the results repo.
    """

    def __init__(self):
        # Prompts loaded from HF_PROMPTS_REPO; empty until load_prompts_data().
        self.prompts_data = []
        # Lazily loaded results cache; None means "not loaded yet".
        self.results = None

    def load_prompts_data(self):
        """Load prompts data from the Hugging Face prompts repository.

        Raises:
            RuntimeError: If nothing could be loaded.
        """
        self.prompts_data = self.load_from_hf(HF_PROMPTS_REPO)
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded from Hugging Face.")

    def get_random_prompt(self):
        """Return a random prompt record from the loaded data.

        Raises:
            RuntimeError: If load_prompts_data() has not been called first.
        """
        if not self.prompts_data:
            raise RuntimeError("No prompts data loaded. Call load_prompts_data() first.")
        return random.choice(self.prompts_data)

    def get_results(self):
        """Get all results data, loading from Hugging Face if not already loaded."""
        if self.results is None:
            self.results = self.load_from_hf(HF_RESULTS_REPO)
        return self.results

    def add_results(self, new_results):
        """Append new result records to the in-memory results list.

        Raises:
            RuntimeError: If results have not been loaded via get_results().
        """
        if self.results is None:
            raise RuntimeError("Results not loaded. Call get_results() first.")
        self.results.extend(new_results)

    def load_from_hf(self, hf_repo):
        """Load data from a Hugging Face dataset repository."""
        return HFDataManager.load_from_hf(hf_repo)

    def save_interaction_to_hf(self, prompt_data, user_continuation, generated_response,
                               cosine_distance, session_id, num_user_tokens):
        """Record one user interaction in memory and in the upload JSONL file.

        Args:
            prompt_data: Prompt record with "prompt", "model",
                "llm_partial_response" and "llm_full_response_original" keys.
            user_continuation: Text the user typed to continue the response.
            generated_response: Full response assembled from the user input.
            cosine_distance: Creativity score for the continuation.
            session_id: Session identifier (stored as "continuation_source").
            num_user_tokens: Token count of the user continuation.

        Raises:
            RuntimeError: If results have not been loaded via get_results().
        """
        interaction = {
            "prompt": prompt_data["prompt"],
            "model": prompt_data["model"],
            "llm_partial_response": prompt_data["llm_partial_response"],
            "llm_full_response_original": prompt_data["llm_full_response_original"],
            "user_continuation": user_continuation,
            "full_response_from_user": generated_response,
            "cosine_distance": cosine_distance,
            # NOTE(review): naive local timestamp; consider a timezone-aware
            # datetime if consumers compare records across hosts — confirm.
            "timestamp": datetime.datetime.now().isoformat(),
            "continuation_source": session_id,
            "num_user_tokens": num_user_tokens,
            "continuation_prompt": "",
            "full_continuation_prompt": ""
        }

        self.add_results([interaction])

        # Serialize file writes against the CommitScheduler's background push.
        with scheduler.lock:
            # Explicit encoding avoids platform-default-encoding surprises.
            with open(JSON_DATASET_PATH, "a", encoding="utf-8") as f:
                f.write(json.dumps(interaction) + "\n")

    def filter_results_by_partial_response(self, results, prompt, partial_response):
        """Filter results to only include entries for the current prompt.

        Uses .get() so records missing a key are skipped instead of raising
        KeyError (consistent with filter_results_by_session).
        """
        return [r for r in results
                if r.get("prompt") == prompt
                and r.get("llm_partial_response") == partial_response]

    def filter_results_by_session(self, results, session_id):
        """Filter results to only include entries from the specified session."""
        return [r for r in results if r.get("continuation_source") == session_id]

    def get_gallery_responses(self, min_score=0.3, limit=20):
        """Return up to `limit` results scoring >= min_score, best first.

        Records missing "cosine_distance" are skipped rather than crashing.
        """
        all_results = self.get_results()

        # -inf default means score-less records never pass the threshold.
        filtered_results = [r for r in all_results
                            if r.get("cosine_distance", float("-inf")) >= min_score]
        filtered_results.sort(key=lambda x: x["cosine_distance"], reverse=True)

        # Return top results
        return filtered_results[:limit]

    def get_inspire_me_examples(self, prompt, partial_response, limit=5):
        """Return up to `limit` inspiring examples for the current prompt.

        Keeps the top-scoring few and mixes in a random draw from the rest
        for diversity; the returned order is shuffled.
        """
        all_results = self.get_results()

        # Only examples for this exact prompt/partial response with a decent
        # creativity score (>= 0.2); .get() skips malformed records.
        examples = [r for r in all_results
                    if r.get("prompt") == prompt
                    and r.get("llm_partial_response") == partial_response
                    and r.get("cosine_distance", float("-inf")) >= 0.2]

        examples.sort(key=lambda x: x["cosine_distance"], reverse=True)

        # Get diverse examples (not just the top ones).
        if len(examples) > limit:
            # Cap the "top" slice at the requested limit: the previous
            # hard-coded 3 made the remainder sample size negative for
            # limit < 3, crashing random.sample with ValueError.
            top_examples = examples[:min(3, limit)]
            remaining = examples[len(top_examples):]
            if remaining:
                additional = random.sample(remaining, min(limit - len(top_examples), len(remaining)))
                examples = top_examples + additional
            else:
                examples = top_examples

        # len(examples) <= limit here, so this effectively shuffles the picks.
        return random.sample(examples, min(limit, len(examples)))