File size: 6,223 Bytes
e189a31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""OpenAI-compatible LLM summarizer for news items."""

import json
import logging
import os
import time
from typing import Dict, List, Optional, Tuple

import requests

logger = logging.getLogger(__name__)


class OpenAICompatSummarizer:
    """
    Summarize news items using an OpenAI-compatible chat completions API.

    Constructor arguments fall back to environment variables when omitted
    (LLM_API_BASE, LLM_API_KEY, LLM_MODEL, LLM_TIMEOUT, LLM_SUMMARY_*).
    Summarization can be disabled entirely via ENABLE_AI_SUMMARIZATION.
    """

    def __init__(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        timeout: Optional[int] = None,
        max_items_per_request: Optional[int] = None,
        max_chars_per_item: Optional[int] = None,
        max_chars_total: Optional[int] = None,
    ):
        # Normalize the base URL so joining "/v1/..." never doubles a slash.
        self.api_base = (api_base or os.getenv("LLM_API_BASE") or "https://researchengineering-agi.hf.space").rstrip("/")
        # Empty string (anonymous access) is a valid key, so only fall back on None.
        self.api_key = api_key if api_key is not None else os.getenv("LLM_API_KEY", "")
        self.model = model or os.getenv("LLM_MODEL", "gpt-4o-mini")
        self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "600"))
        # Conservative defaults to avoid large token bursts on slow servers.
        self.max_items_per_request = max_items_per_request or int(os.getenv("LLM_SUMMARY_BATCH", "2"))
        self.max_chars_per_item = max_chars_per_item or int(os.getenv("LLM_SUMMARY_MAX_CHARS", "600"))
        self.max_chars_total = max_chars_total or int(os.getenv("LLM_SUMMARY_MAX_CHARS_TOTAL", "1200"))
        self.enabled = os.getenv("ENABLE_AI_SUMMARIZATION", "true").lower() in {"1", "true", "yes"}
        # Optional inter-batch pause to be gentle on rate-limited endpoints.
        self.sleep_seconds = float(os.getenv("LLM_SUMMARY_SLEEP_SECONDS", "0"))

        self._chat_url = f"{self.api_base}/v1/chat/completions"

    def summarize_items(self, items: List[Dict], source: Optional[str] = None) -> List[Dict]:
        """
        Fill in ``summary_ai`` (and ``summary``) for each item lacking one.

        Items are mutated in place and the same list is returned. Items that
        already have a non-empty ``summary_ai``, or yield no usable input
        text, are left untouched. Failures are logged and skipped so a bad
        batch never breaks the pipeline.
        """
        if not self.enabled or not items:
            return items

        # Pair each item with its prompt text; skip already-summarized items.
        candidates: List[Tuple[Dict, str]] = []
        for item in items:
            if str(item.get("summary_ai", "")).strip():
                continue
            text = self._build_input_text(item)
            if text:
                candidates.append((item, text))

        if not candidates:
            return items

        chunks = self._chunked(candidates, self.max_items_per_request)
        for idx, chunk in enumerate(chunks, start=1):
            texts = [text for _, text in chunk]
            if self.max_chars_total > 0:
                # May drop a suffix of texts, but never reorders, so the
                # zip below still aligns summaries with the right items.
                texts = self._truncate_to_total(texts, self.max_chars_total)
            summaries = self._summarize_chunk(texts, source=source)
            if not summaries:
                continue
            for (item, _), summary in zip(chunk, summaries):
                if summary:
                    item["summary_ai"] = summary
                    item["summary"] = summary
            # Sleep between batches only — never after the final one.
            if self.sleep_seconds > 0 and idx < len(chunks):
                time.sleep(self.sleep_seconds)

        return items

    def _build_input_text(self, item: Dict) -> str:
        """
        Build the prompt text for one item from its title and source.

        Returns "" when the item has no usable title, which excludes it
        from summarization.
        """
        title = str(item.get("title", "")).strip()
        if title:
            source = str(item.get("source", "")).strip()
            if len(title) > self.max_chars_per_item:
                title = title[: self.max_chars_per_item].rstrip()
            if source:
                return f"Source: {source}\nTitle: {title}"
            return f"Title: {title}"
        return ""

    def _summarize_chunk(self, texts: List[str], source: Optional[str] = None) -> List[str]:
        """
        Call the chat completions endpoint for one batch of texts.

        Returns one summary per input text, in order, or [] on any failure
        (HTTP error, timeout, or a response that is not a JSON array of the
        expected length).
        """
        system_prompt = (
            "You are a financial news summarizer. "
            "Return concise, factual summaries in 1-2 sentences, <=240 characters each. "
            "Do not add speculation or new facts."
        )
        source_line = f"Source: {source}" if source else ""

        items_text = []
        for idx, text in enumerate(texts, start=1):
            items_text.append(f"{idx}. {text}")

        user_prompt = (
            "Summarize each item below. "
            "Return a JSON array of strings in the same order. "
            "No extra text.\n"
            f"{source_line}\n\n" + "\n\n".join(items_text)
        )

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            # Low temperature keeps summaries factual and deterministic-ish.
            "temperature": 0.2,
        }

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        try:
            response = requests.post(self._chat_url, json=payload, headers=headers, timeout=self.timeout)
            response.raise_for_status()
            data = response.json()
            content = (
                data.get("choices", [{}])[0]
                .get("message", {})
                .get("content", "")
                .strip()
            )
            summaries = self._parse_json_array(content)
            # Length must match exactly or we cannot attribute summaries safely.
            if summaries and len(summaries) == len(texts):
                return summaries
            logger.warning("LLM summarizer returned unexpected format or length")
            return []
        except Exception as exc:
            # Best-effort: a failed batch is skipped, not fatal.
            logger.warning("LLM summarization failed: %s", exc)
            return []

    def _parse_json_array(self, content: str) -> List[str]:
        """
        Parse the model's reply as a JSON array of strings.

        Tolerates the common case where the model wraps the array in
        markdown code fences (``` or ```json). Returns [] on empty input,
        invalid JSON, or a non-list payload.
        """
        if not content:
            return []
        text = content.strip()
        # Many models ignore "No extra text" and fence the JSON; unwrap it.
        if text.startswith("```"):
            text = text[3:]
            if text.lower().startswith("json"):
                text = text[4:]
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
        try:
            parsed = json.loads(text)
        except Exception:
            return []
        if isinstance(parsed, list):
            return [str(x).strip() for x in parsed]
        return []

    def _chunked(self, items: List[Tuple[Dict, str]], size: int) -> List[List[Tuple[Dict, str]]]:
        """Split ``items`` into consecutive chunks of at most ``size``.

        A non-positive ``size`` means "no chunking": one chunk with everything.
        """
        if size <= 0:
            return [items]
        return [items[i : i + size] for i in range(0, len(items), size)]

    def _truncate_to_total(self, texts: List[str], max_total: int) -> List[str]:
        """
        Cap the combined length of ``texts`` at ``max_total`` characters.

        Keeps texts in order, truncating the first text that crosses the
        budget and dropping any that follow. ``max_total <= 0`` disables
        the cap entirely.
        """
        if max_total <= 0:
            return texts
        truncated = []
        total = 0
        for text in texts:
            if total >= max_total:
                break
            remaining = max_total - total
            if len(text) > remaining:
                text = text[:remaining].rstrip()
            truncated.append(text)
            total += len(text)
        return truncated