File size: 2,329 Bytes
9c062cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pandas as pd
import re
from synthesis_qa_backend import ResearchSynthesizer
from config import API_KEY, INDEX_PATH, METADATA_PATH, SPECIFIC_COUNTRIES

class DataHandler:
    """Load the research synthesizer and document metadata for the UI.

    ``__init__`` immediately calls ``load_data``. If loading fails, every
    attribute is left in its empty default state (``None`` / empty
    DataFrame / empty lists) instead of raising, so callers can still render.
    """

    def __init__(self):
        self._reset_state()
        self.load_data()

    def _reset_state(self):
        """Reset every loaded attribute to its empty default."""
        self.synthesizer = None
        self.docs_df = pd.DataFrame()
        self.countries_list = []
        self.sectors_list = []

    def load_data(self):
        """Initialize the research system and load data.

        Builds a ``ResearchSynthesizer`` from ``INDEX_PATH``, ``METADATA_PATH``
        and ``API_KEY``. It then reads the metadata CSV, de-duplicates it on
        ``record_id`` and derives the dropdown option lists.

        Any exception is reported with ``print`` and the handler is reset to
        its empty state; nothing is raised to the caller.
        """
        try:
            self.synthesizer = ResearchSynthesizer(INDEX_PATH, METADATA_PATH, API_KEY)
            metadata_df = pd.read_csv(METADATA_PATH)
            self.docs_df = metadata_df.drop_duplicates(subset=['record_id'])
            print(f"✅ Loaded {len(self.docs_df)} unique documents")

            # Get unique values for dropdowns
            self.countries_list, self.sectors_list = self._get_unique_values()

        except Exception as e:
            print(f"❌ Error loading system: {e}")
            # Clear *all* state, including dropdown lists left over from a
            # previous successful load, so nothing stale survives a failure.
            self._reset_state()

    def _get_unique_values(self, allowed_countries=None):
        """Get unique values for dropdowns.

        Args:
            allowed_countries: Collection of country names to keep. Defaults to
                ``SPECIFIC_COUNTRIES`` from config when ``None``.

        Returns:
            A ``(countries_list, sectors_list)`` tuple of sorted lists.

            - Countries come from the ``study_countries`` column. Each cell is
              split on ``,`` or ``;``, tokens are stripped, and only tokens in
              ``allowed_countries`` longer than one character are kept.
            - Sectors are the non-null unique values of ``world_bank_sector``.
            - A missing column, or an empty ``docs_df``, yields an empty list.
        """
        if self.docs_df.empty:
            return [], []

        if allowed_countries is None:
            allowed_countries = SPECIFIC_COUNTRIES
        # Build the lookup once: O(1) membership per token.
        allowed = set(allowed_countries)

        countries = set()
        if 'study_countries' in self.docs_df.columns:
            for raw in self.docs_df['study_countries'].dropna():
                text = str(raw)
                # Skip placeholder strings that survive dropna() as literal text.
                if text.lower() in ('nan', 'none', ''):
                    continue
                for token in text.replace(';', ',').split(','):
                    name = token.strip()
                    if name in allowed and len(name) > 1:
                        countries.add(name)

        sectors_list = []
        if 'world_bank_sector' in self.docs_df.columns:
            sectors_list = sorted(self.docs_df['world_bank_sector'].dropna().unique().tolist())

        return sorted(countries), sectors_list

    def get_data(self):
        """Return all data objects.

        Returns:
            A dict with keys ``synthesizer``, ``docs_df``, ``countries_list``
            and ``sectors_list``, mapping to the current attribute values
            (not copies).
        """
        return {
            'synthesizer': self.synthesizer,
            'docs_df': self.docs_df,
            'countries_list': self.countries_list,
            'sectors_list': self.sectors_list
        }