File size: 5,792 Bytes
1dd0e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#include "DocumentRetriever.h"
#include <iostream>
#include <algorithm>
#include <cctype>
#include <locale>

namespace hhb {
namespace rag {

DocumentRetriever::DocumentRetriever() : initialized_(false) {
}

// No explicit teardown is performed; the members are destroyed as usual.
DocumentRetriever::~DocumentRetriever() = default;

// Discards all previously stored chunks and reloads the built-in document set.
//
// At present only "README.md" is loaded, and its path is also used as the
// source label. Loading is best-effort: the result of addDocument() is not
// inspected, the retriever is marked initialized regardless, and the function
// always returns true. A summary line with the chunk count goes to stdout.
bool DocumentRetriever::initialize() {
    // Start from a clean slate so repeated calls do not accumulate chunks.
    chunks_.clear();

    // Load the README document.
    addDocument("README.md", "README.md");

    // Further documents could be registered here, e.g.:
    // addDocument("docs/manual.md", "Manual");

    initialized_ = true;

    const auto loadedChunks = chunks_.size();
    std::cout << "DocumentRetriever initialized with " << loadedChunks << " chunks" << std::endl;
    return true;
}

// Reads the entire file at `filePath` and appends its chunks to the store.
//
// filePath: path passed to std::ifstream.
// source:   label forwarded to splitIntoChunks() for every chunk of this file.
//
// Returns false, after printing a warning to stdout, when the file cannot be
// opened. Returns true otherwise, including when the file turns out to be empty.
bool DocumentRetriever::addDocument(const std::string& filePath, const std::string& source) {
    std::ifstream input(filePath);

    if (input.is_open()) {
        // Pull the whole file into memory in one go, then chunk it.
        std::ostringstream contents;
        contents << input.rdbuf();
        splitIntoChunks(contents.str(), source);
        return true;  // `input` is closed by its destructor on scope exit
    }

    std::cout << "Warning: Could not open document: " << filePath << std::endl;
    return false;
}

// Splits `content` into line-based chunks and appends them to chunks_.
//
// content: full document text; lines are separated by '\n' (a trailing '\r'
//          from CRLF input is tolerated when detecting blank lines).
// source:  label copied into every chunk produced.
//
// A chunk is closed when it reaches `maxChunkLines` lines or when a blank line
// is read; the blank line itself counts as part of the chunk it closes. A chunk
// closed this way is stored only if it holds at least `minFlushedLines` lines,
// so a lone line followed by a blank line is dropped. Whatever remains at the
// end of the input is always stored.
//
// Each stored chunk records the 1-based line number of its first line.
void DocumentRetriever::splitIntoChunks(const std::string& content, const std::string& source) {
    const int maxChunkLines = 10;    // hard cap on lines per chunk
    const int minFlushedLines = 3;   // smaller chunks closed mid-document are discarded

    std::istringstream stream(content);
    std::string line;
    std::string currentChunk;
    int lineNumber = 0;  // 1-based number of the line most recently read
    int chunkSize = 0;   // number of lines accumulated in currentChunk

    // Appends the accumulated chunk to the store. The chunk spans the last
    // `chunkSize` lines read, so its first line is lineNumber - chunkSize + 1.
    // (Without the +1 the first chunk of a document would report line 0.)
    const auto storeCurrentChunk = [&]() {
        DocumentChunk chunk;
        chunk.content = currentChunk;
        chunk.source = source;
        chunk.line_number = lineNumber - chunkSize + 1;
        chunks_.push_back(chunk);
    };

    while (std::getline(stream, line)) {
        ++lineNumber;
        currentChunk += line;
        currentChunk += '\n';
        ++chunkSize;

        // With CRLF input std::getline leaves the '\r' behind, so a visually
        // blank line arrives as "\r" rather than "".
        const bool blankLine = line.empty() || line == "\r";

        if (chunkSize >= maxChunkLines || blankLine) {
            if (chunkSize >= minFlushedLines) {
                storeCurrentChunk();
            }
            currentChunk.clear();
            chunkSize = 0;
        }
    }

    // Keep the trailing chunk regardless of its size.
    if (chunkSize > 0) {
        storeCurrentChunk();
    }
}

// Breaks `text` (expected to be UTF-8) into lowercase search tokens.
//
// - A maximal run of ASCII letters/digits becomes one token, lowercased.
// - Every well-formed non-ASCII code point becomes a token of its own. Text
//   such as Chinese has no spaces between words, so per-character tokens are
//   what lets such queries match document text at all; previously every
//   non-ASCII byte was discarded and a pure-Chinese query produced no tokens.
// - ASCII punctuation/whitespace, common non-ASCII punctuation (General
//   Punctuation, CJK Symbols and Punctuation, fullwidth ASCII punctuation) and
//   malformed byte sequences only act as separators.
//
// Returns the tokens in order of appearance; duplicates are preserved.
std::vector<std::string> DocumentRetriever::tokenize(const std::string& text) const {
    std::vector<std::string> tokens;
    std::string word;  // ASCII word currently being accumulated

    const auto flushWord = [&]() {
        if (!word.empty()) {
            tokens.push_back(word);
            word.clear();
        }
    };

    // True for non-ASCII code points that should separate tokens rather than
    // become tokens themselves.
    const auto isSeparator = [](unsigned long cp) {
        if (cp >= 0x2000 && cp <= 0x206F) return true;  // General Punctuation
        if (cp >= 0x3000 && cp <= 0x303F) return true;  // CJK Symbols and Punctuation
        if (cp >= 0xFF01 && cp <= 0xFF5E) {
            // Fullwidth forms mirror ASCII 0x21-0x7E at a fixed offset of 0xFEE0.
            return std::isalnum(static_cast<int>(cp - 0xFEE0)) == 0;
        }
        return false;
    };

    const std::size_t size = text.size();
    std::size_t i = 0;
    while (i < size) {
        const unsigned char lead = static_cast<unsigned char>(text[i]);

        if (lead < 0x80) {
            // Restricting the <cctype> calls to ASCII keeps the result
            // independent of how the active locale classifies high bytes.
            if (std::isalnum(lead)) {
                word += static_cast<char>(std::tolower(lead));
            } else {
                flushWord();
            }
            ++i;
            continue;
        }

        // Any non-ASCII byte terminates the current ASCII word.
        flushWord();

        // Derive the sequence length and the payload bits from the lead byte.
        std::size_t length = 0;
        unsigned long codePoint = 0;
        if ((lead & 0xE0) == 0xC0) {
            length = 2;
            codePoint = lead & 0x1F;
        } else if ((lead & 0xF0) == 0xE0) {
            length = 3;
            codePoint = lead & 0x0F;
        } else if ((lead & 0xF8) == 0xF0) {
            length = 4;
            codePoint = lead & 0x07;
        } else {
            ++i;  // stray continuation byte or invalid lead byte
            continue;
        }

        // Every following byte of the sequence must look like 10xxxxxx.
        bool wellFormed = (i + length <= size);
        for (std::size_t k = 1; wellFormed && k < length; ++k) {
            const unsigned char next = static_cast<unsigned char>(text[i + k]);
            if ((next & 0xC0) != 0x80) {
                wellFormed = false;
            } else {
                codePoint = (codePoint << 6) | (next & 0x3F);
            }
        }
        if (!wellFormed) {
            ++i;  // resynchronise on the next byte
            continue;
        }

        if (!isSeparator(codePoint)) {
            tokens.push_back(text.substr(i, length));
        }
        i += length;
    }

    flushWord();
    return tokens;
}

// Scores how well a chunk covers a query.
//
// A query token counts as matched when at least one chunk token contains it
// as a substring, or is itself contained in it. Each query token contributes
// at most once, however many chunk tokens it matches.
//
// Returns matched / queryTokens.size(), a value in [0, 1]; returns 0.0 when
// either token list is empty.
double DocumentRetriever::calculateSimilarity(const std::vector<std::string>& queryTokens,
                                             const std::vector<std::string>& chunkTokens) const {
    if (queryTokens.empty() || chunkTokens.empty()) {
        return 0.0;
    }

    // Substring containment in either direction.
    const auto overlaps = [](const std::string& lhs, const std::string& rhs) {
        return lhs.find(rhs) != std::string::npos || rhs.find(lhs) != std::string::npos;
    };

    // Number of query tokens that have at least one overlapping chunk token.
    const auto matched = std::count_if(
        queryTokens.begin(), queryTokens.end(),
        [&](const std::string& wanted) {
            return std::any_of(chunkTokens.begin(), chunkTokens.end(),
                               [&](const std::string& candidate) {
                                   return overlaps(candidate, wanted);
                               });
        });

    return static_cast<double>(matched) / static_cast<double>(queryTokens.size());
}

// Decides whether `query` looks like a question about operating the software,
// i.e. whether document retrieval is worth running for it.
//
// The query is lowercased byte by byte with std::tolower and then searched for
// any of a fixed list of trigger phrases; the first hit is enough. Lowercasing
// is what lets the phrases with a lowercase Latin suffix match input such as
// "STL" or "OBJ".
//
// Returns true if any trigger phrase occurs in the query, false otherwise.
bool DocumentRetriever::needsDocumentRetrieval(const std::string& query) const {
    // Trigger phrases: how-to question words, import/export/open/save/load
    // verbs, usage/feature/settings/shortcut terms, UI nouns (software,
    // system, interface, menu, button) and help/tutorial/document terms.
    const std::vector<std::string> triggerPhrases = {
        "如何", "怎么", "导出", "导入", "打开", "保存", "加载",
        "操作", "使用", "功能", "设置", "配置", "快捷键",
        "导出stl", "导出obj", "导入stl", "加载模型",
        "软件", "系统", "界面", "菜单", "按钮",
        "帮助", "教程", "文档", "说明"
    };

    std::string folded;
    folded.reserve(query.size());
    for (const char ch : query) {
        folded.push_back(static_cast<char>(std::tolower(static_cast<unsigned char>(ch))));
    }

    return std::any_of(triggerPhrases.begin(), triggerPhrases.end(),
                       [&folded](const std::string& phrase) {
                           return folded.find(phrase) != std::string::npos;
                       });
}

// Builds a context string from the chunks that best match `query`.
//
// query:     free-text query; it is tokenized with tokenize().
// maxChunks: upper bound on the number of chunks included in the result.
//
// Every stored chunk is scored with calculateSimilarity(). Chunks scoring
// above `minSimilarity` are ranked by descending score; chunks with equal
// scores keep their document order, so the result is deterministic.
//
// Returns a header line followed by up to `maxChunks` chunk sections, each
// prefixed with its source label and line number. Returns an empty string
// when there is nothing to report: no chunks are stored, `maxChunks` is not
// positive, the query yields no tokens, or no chunk clears the threshold.
// (Returning the bare header in those cases would make callers that test
// empty() attach a context block with no content.)
std::string DocumentRetriever::retrieve(const std::string& query, int maxChunks) const {
    if (chunks_.empty() || maxChunks <= 0) {
        return "";
    }

    // Tokenize the query once up front.
    const std::vector<std::string> queryTokens = tokenize(query);
    if (queryTokens.empty()) {
        return "";
    }

    const double minSimilarity = 0.1;  // scores at or below this are ignored

    // Score every chunk, keeping only the candidates worth ranking.
    std::vector<std::pair<double, size_t>> scores;
    scores.reserve(chunks_.size());
    for (size_t i = 0; i < chunks_.size(); ++i) {
        const double similarity = calculateSimilarity(queryTokens, tokenize(chunks_[i].content));
        if (similarity > minSimilarity) {
            scores.push_back({similarity, i});
        }
    }

    if (scores.empty()) {
        return "";
    }

    // Candidates were collected in document order; a stable sort preserves
    // that order among equal scores.
    std::stable_sort(scores.begin(), scores.end(),
                     [](const auto& a, const auto& b) { return a.first > b.first; });

    // Assemble the result text.
    std::stringstream result;
    result << "\n\n[检索到的相关文档片段]\n";

    int count = 0;
    for (const auto& score : scores) {
        if (count >= maxChunks) {
            break;
        }

        const auto& chunk = chunks_[score.second];
        result << "\n--- 来源: " << chunk.source << " (行 " << chunk.line_number << ") ---\n";
        result << chunk.content << "\n";
        ++count;
    }

    return result.str();
}

} // namespace rag
} // namespace hhb