File size: 6,910 Bytes
1ea875f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# -*- coding: utf-8 -*-
"""
Session 工具模块

提供基于仓库 URL 的 Session ID 生成和管理
"""

import hashlib
import re
from typing import Optional, Tuple, Dict
from urllib.parse import urlparse

from app.core.config import conversation_config


def normalize_repo_url(url: str) -> str:
    """
    Normalize a GitHub repository URL to a canonical form.

    Supported inputs:
    - https://github.com/owner/repo
    - https://github.com/owner/repo.git
    - https://github.com/owner/repo/tree/main
    - git@github.com:owner/repo.git
    - github.com/owner/repo (no scheme; ``www.github.com`` is accepted too)

    Args:
        url: Repository URL in any of the forms above. Surrounding
            whitespace is ignored and the whole value is lowercased.

    Returns:
        The normalized URL ``https://github.com/owner/repo`` (all lowercase).
        Input with fewer than two path segments is returned as
        ``https://github.com/<path>`` without further validation.
    """
    url = url.strip().lower()  # lowercase so equivalent URLs compare equal

    # SSH form: git@github.com:owner/repo.git -> https://github.com/owner/repo
    if url.startswith('git@'):
        match = re.match(r'git@github\.com:(.+?)(?:\.git)?$', url)
        if match:
            return f"https://github.com/{match.group(1)}"

    # HTTPS (or scheme-less) form
    parsed = urlparse(url)
    parts = parsed.path.strip('/').split('/')

    # Without a scheme, urlparse leaves the host inside ``path``; drop it so
    # it is not mistaken for the owner segment.
    if (
        not parsed.netloc
        and len(parts) > 1
        and parts[0] in ('github.com', 'www.github.com')
    ):
        parts = parts[1:]

    # Keep only the owner/repo segments (drops /tree/<branch> and the like).
    if len(parts) >= 2:
        parts = parts[:2]

    path = '/'.join(parts)

    # Strip the .git suffix after truncation so that
    # owner/repo.git/tree/main is handled as well.
    if path.endswith('.git'):
        path = path[:-4]

    return f"https://github.com/{path}"


def extract_repo_info(url: str) -> Tuple[str, str]:
    """
    Extract the owner and repository name from a repository URL.

    The URL is normalized with ``normalize_repo_url`` first, so the returned
    values follow that function's normalization.

    Returns:
        An ``(owner, repo)`` tuple, or ``("", "")`` when the normalized path
        has fewer than two segments.
    """
    segments = urlparse(normalize_repo_url(url)).path.strip('/').split('/')

    # Guard clause: not enough path segments to identify a repository.
    if len(segments) < 2:
        return "", ""

    owner, repo = segments[0], segments[1]
    return owner, repo


def generate_repo_session_id(repo_url: str) -> str:
    """
    Build a deterministic session ID for a repository URL.

    URLs that normalize to the same value yield the same session ID.

    Format: ``repo_{short_hash}_{owner}_{repo}`` where
    - ``short_hash`` is the first 8 hex characters of the SHA-256 digest of
      the normalized URL,
    - ``owner`` is reduced to ASCII letters/digits and cut to 10 characters,
    - ``repo`` is reduced to ASCII letters/digits and cut to 15 characters.
    """
    canonical_url = normalize_repo_url(repo_url)
    owner, repo = extract_repo_info(repo_url)

    digest = hashlib.sha256(canonical_url.encode()).hexdigest()

    # Anything outside [a-zA-Z0-9] is removed before truncating.
    non_alnum = r'[^a-zA-Z0-9]'
    owner_slug = re.sub(non_alnum, '', owner)[:10]
    repo_slug = re.sub(non_alnum, '', repo)[:15]

    return "_".join(("repo", digest[:8], owner_slug, repo_slug))


def is_repo_session_id(session_id: str) -> bool:
    """Return True when ``session_id`` carries the repository-level prefix."""
    repo_prefix = "repo_"
    return session_id.startswith(repo_prefix)


# === Conversation history management ===

class ConversationMemory:
    """
    Conversation memory: a sliding window of recent messages plus a summary.

    Behavior:
    1. The most recent ``max_recent_turns`` turns are kept verbatim in the
       context (one turn is counted as two messages: user + assistant).
    2. Older messages can be replaced by a summary. The summary text itself
       is produced by the caller and stored via ``set_summary``.
    3. ``max_context_tokens`` is stored on the instance; no method of this
       class reads it.
    """

    def __init__(
        self,
        max_recent_turns: Optional[int] = None,
        max_context_tokens: Optional[int] = None,
        summary_threshold: Optional[int] = None,
    ):
        """
        Args:
            max_recent_turns: Number of recent turns returned verbatim.
            max_context_tokens: Token budget; stored only.
            summary_threshold: Number of unsummarized turns that must be
                exceeded before summarization is requested.

        Any falsy argument (``None`` or ``0``) falls back to the matching
        value on the shared ``conversation_config``.
        """
        self.max_recent_turns = max_recent_turns or conversation_config.max_recent_turns
        self.max_context_tokens = max_context_tokens or conversation_config.max_context_tokens
        self.summary_threshold = summary_threshold or conversation_config.summary_threshold

        self._messages: list = []            # full message history
        self._summary: Optional[str] = None  # summary of the early conversation
        self._summary_up_to: int = 0         # summary covers messages[:_summary_up_to]

    def add_message(self, role: str, content: str) -> None:
        """Append a message with the given role and content to the history."""
        self._messages.append({
            "role": role,
            "content": content
        })

    def add_user_message(self, content: str) -> None:
        """Append a message with role ``"user"``."""
        self.add_message("user", content)

    def add_assistant_message(self, content: str) -> None:
        """Append a message with role ``"assistant"``."""
        self.add_message("assistant", content)

    def get_context_messages(self) -> list:
        """
        Return the messages to send to the LLM as context.

        Strategy:
        1. If the history holds at most ``max_recent_turns * 2`` messages,
           return a copy of the whole history.
        2. Otherwise return the most recent ``max_recent_turns * 2`` messages,
           preceded by a ``system`` message carrying the summary if one is set.
        """
        total_messages = len(self._messages)
        max_messages = self.max_recent_turns * 2  # user + assistant = 1 turn

        if total_messages <= max_messages:
            return list(self._messages)

        # History is longer than the window: keep only the tail.
        recent_messages = self._messages[-max_messages:]

        # Put the summary, if any, in front of the recent messages.
        if self._summary:
            return [
                {"role": "system", "content": f"[Earlier conversation summary]\n{self._summary}"}
            ] + recent_messages

        return recent_messages

    def _summary_end_index(self) -> int:
        """Return the exclusive end index of the messages eligible for summarization."""
        end_idx = len(self._messages) - self.max_recent_turns * 2
        # Clamp: a negative end would be read by slicing as an offset from the
        # tail and would leak messages from the recent window into the result.
        return max(end_idx, self._summary_up_to, 0)

    def needs_summarization(self) -> bool:
        """
        Report whether a new summary should be generated.

        True when more than ``summary_threshold * 2`` messages are not yet
        covered by the summary AND at least one of them lies outside the
        recent window (so there is actually something to summarize).
        """
        unsummarized = len(self._messages) - self._summary_up_to
        if unsummarized <= self.summary_threshold * 2:
            return False
        return self._summary_end_index() > self._summary_up_to

    def get_messages_to_summarize(self) -> list:
        """
        Return the messages that should be folded into the summary.

        These are the messages after the current summary boundary and before
        the recent window. Returns an empty list when no summarization is
        needed.
        """
        if not self.needs_summarization():
            return []

        # Keep the recent window verbatim; summarize what precedes it.
        return self._messages[self._summary_up_to:self._summary_end_index()]

    def set_summary(self, summary: str, up_to_index: int) -> None:
        """
        Store a summary and record how far into the history it reaches.

        Args:
            summary: Summary text. If a summary already exists, the new text
                is appended to it, separated by a blank line.
            up_to_index: Index recorded as the new summary boundary.
        """
        if self._summary:
            # Merge with the existing summary.
            self._summary = f"{self._summary}\n\n{summary}"
        else:
            self._summary = summary
        self._summary_up_to = up_to_index

    def clear(self) -> None:
        """Reset the history, the summary, and the summary boundary."""
        self._messages = []
        self._summary = None
        self._summary_up_to = 0

    def get_turn_count(self) -> int:
        """Return the number of turns, counted as message count // 2."""
        return len(self._messages) // 2

    def get_stats(self) -> dict:
        """Return message count, turn count, and summary state as a dict."""
        return {
            "total_messages": len(self._messages),
            "turn_count": self.get_turn_count(),
            "has_summary": self._summary is not None,
            "summary_covers": self._summary_up_to,
        }


# === Global conversation memory store ===
# key: session_id, value: ConversationMemory
# In-memory only; the contents are lost whenever the service restarts.
_conversation_memories: Dict[str, ConversationMemory] = {}


def get_conversation_memory(session_id: str) -> ConversationMemory:
    """Return the memory for ``session_id``, creating and storing one if absent."""
    try:
        return _conversation_memories[session_id]
    except KeyError:
        # First access for this session: register a memory built with defaults.
        memory = ConversationMemory()
        _conversation_memories[session_id] = memory
        return memory


def clear_conversation_memory(session_id: str) -> None:
    """Remove the stored memory for ``session_id``; do nothing if there is none."""
    _conversation_memories.pop(session_id, None)


def get_memory_stats() -> dict:
    """Return how many conversation memories are stored and their session IDs."""
    session_ids = list(_conversation_memories)
    return {
        "total_memories": len(session_ids),
        "sessions": session_ids,
    }
    }