xianqiu commited on
Commit
1d5d14e
·
1 Parent(s): 438e046

Improve down-prediction accuracy: 44% -> 63%, overall 55% -> 62%

Browse files

Enhanced training v4:
- Weighted sampling (1.5x for down samples)
- Imbalanced dataset (55% down / 45% up)
- 8 epochs with lower LR (3e-6)
- Converged after 2 iterations

Metrics:
- Direction accuracy: 55.1% -> 62.3%
- Down-prediction accuracy: 44.2% -> 62.6%
- Up-prediction accuracy: 65.9% -> 62.0% (slight decrease, better balanced)

README.md CHANGED
@@ -1,14 +1,3 @@
1
- ---
2
- title: TSLM - Time Series Language Model
3
- emoji: 📈
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- short_description: BTC price prediction API based on Kronos model
10
- ---
11
-
12
  # Kronos BTC Prediction API
13
 
14
  基于 Kronos 时序预测模型的 BTC 价格预测 API 服务。
@@ -25,26 +14,196 @@ short_description: BTC price prediction API based on Kronos model
25
 
26
  ## API 端点
27
 
28
- - `GET /health` - 健康检查
29
- - `POST /predict` - 价格预测
30
- - `POST /signal` - 交易信号
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  ## 快速开始
33
 
 
 
34
  ```python
35
  from client import KronosClient
36
 
37
- client = KronosClient("https://xianqiu-tslm.hf.space")
 
38
 
39
- # 价格预测
40
  prediction = client.predict(ohlcv_data, pred_len=24)
41
- print(f"上涨概率: {prediction.upside_probability:.1%}")
42
 
43
- # 交易信号
44
  signal = client.get_signal(ohlcv_data)
45
- print(f"信号: {signal.signal}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  ```
47
 
48
- ## API 文档
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- 访`/docs` 查看完整的 Swagger API 文档
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Kronos BTC Prediction API
2
 
3
  基于 Kronos 时序预测模型的 BTC 价格预测 API 服务。
 
14
 
15
  ## API 端点
16
 
17
+ ### 健康检查
18
+
19
+ ```
20
+ GET /health
21
+ ```
22
+
23
+ **响应示例:**
24
+ ```json
25
+ {
26
+ "status": "healthy",
27
+ "model_loaded": true,
28
+ "model_version": "iter5 (converged)",
29
+ "device": "cpu",
30
+ "timestamp": "2024-01-15T10:30:00.000000"
31
+ }
32
+ ```
33
+
34
+ ### 价格预测
35
+
36
+ ```
37
+ POST /predict
38
+ ```
39
+
40
+ 基于历史 OHLCV 数据预测未来价格走势。
41
+
42
+ **请求参数:**
43
+
44
+ | 字段 | 类型 | 必填 | 默认值 | 描述 |
45
+ |------|------|------|--------|------|
46
+ | data | List[OHLCVData] | Yes | - | 历史 K 线数据 (至少 100 条) |
47
+ | pred_len | int | No | 24 | 预测长度 (1-72 小时) |
48
+ | n_paths | int | No | 30 | Monte Carlo 路径数 (10-100) |
49
+ | temperature | float | No | 1.0 | 采样温度 (0.1-2.0) |
50
+ | top_p | float | No | 0.9 | Top-p 采样 (0.5-1.0) |
51
+
52
+ **OHLCVData 格式:**
53
+ ```json
54
+ {
55
+ "timestamp": "2024-01-15T10:00:00",
56
+ "open": 42000.0,
57
+ "high": 42500.0,
58
+ "low": 41800.0,
59
+ "close": 42300.0,
60
+ "volume": 1234.56,
61
+ "amount": 52000000.0 // 可选
62
+ }
63
+ ```
64
+
65
+ **响应示例:**
66
+ ```json
67
+ {
68
+ "current_price": 42300.0,
69
+ "mean_forecast": 42850.5,
70
+ "min_forecast": 41200.0,
71
+ "max_forecast": 44100.0,
72
+ "upside_probability": 0.65,
73
+ "expected_return": 0.013,
74
+ "volatility_amplification": 0.42,
75
+ "confidence": 0.78,
76
+ "forecast_prices": [42350.0, 42400.0, ...],
77
+ "timestamp": "2024-01-15T10:30:00.000000"
78
+ }
79
+ ```
80
+
81
+ ### 交易信号
82
+
83
+ ```
84
+ POST /signal
85
+ ```
86
+
87
+ 基于预测结果生成交易信号。
88
+
89
+ **请求参数:**
90
+
91
+ | 字段 | 类型 | 必填 | 默认值 | 描述 |
92
+ |------|------|------|--------|------|
93
+ | data | List[OHLCVData] | Yes | - | 历史 K 线数据 |
94
+ | buy_threshold | float | No | 0.58 | 买入阈值 (0.5-0.9) |
95
+ | sell_threshold | float | No | 0.42 | 卖出阈值 (0.1-0.5) |
96
+ | stop_loss | float | No | 0.03 | 止损比例 (0.01-0.1) |
97
+ | take_profit | float | No | 0.08 | 止盈比例 (0.02-0.2) |
98
+ | n_paths | int | No | 30 | Monte Carlo 路径数 |
99
+
100
+ **响应示例:**
101
+ ```json
102
+ {
103
+ "signal": "BUY",
104
+ "confidence": 0.78,
105
+ "current_price": 42300.0,
106
+ "target_price": 42850.5,
107
+ "stop_loss_price": 41031.0,
108
+ "take_profit_price": 45684.0,
109
+ "upside_probability": 0.65,
110
+ "expected_return": 0.013,
111
+ "suggested_position_size": 0.15,
112
+ "reason": "Upside probability 65.0% > 58%",
113
+ "timestamp": "2024-01-15T10:30:00.000000"
114
+ }
115
+ ```
116
+
117
+ **信号类型:**
118
+ - `STRONG_BUY`: 强烈买入 (上涨概率 > 70%, 低波动)
119
+ - `BUY`: 买入 (上涨概率 > buy_threshold)
120
+ - `HOLD`: 持有 (中性区间或低置信度)
121
+ - `SELL`: 卖出 (上涨概率 < sell_threshold)
122
+ - `STRONG_SELL`: 强烈卖出 (下跌概率 > 70%, 低波动)
123
 
124
  ## 快速开始
125
 
126
+ ### 使用 Python SDK
127
+
128
  ```python
129
  from client import KronosClient
130
 
131
+ # 连接到 API
132
+ client = KronosClient("https://your-space.hf.space")
133
 
134
+ # 获取价格预测
135
  prediction = client.predict(ohlcv_data, pred_len=24)
136
+ print(f"上涨概率: {prediction['upside_probability']:.1%}")
137
 
138
+ # 获取交易信号
139
  signal = client.get_signal(ohlcv_data)
140
+ print(f"信号: {signal['signal']}, 置信度: {signal['confidence']:.1%}")
141
+ ```
142
+
143
+ ### 使用 cURL
144
+
145
+ ```bash
146
+ # 健康检查
147
+ curl https://your-space.hf.space/health
148
+
149
+ # 价格预测
150
+ curl -X POST https://your-space.hf.space/predict \
151
+ -H "Content-Type: application/json" \
152
+ -d '{
153
+ "data": [...],
154
+ "pred_len": 24,
155
+ "n_paths": 30
156
+ }'
157
  ```
158
 
159
+ ## 本地部署
160
+
161
+ ```bash
162
+ # 安装依赖
163
+ pip install -r requirements.txt
164
+
165
+ # 启动服务
166
+ python app.py
167
+
168
+ # 服务将在 http://localhost:7860 启动
169
+ # API 文档: http://localhost:7860/docs
170
+ ```
171
+
172
+ ## HuggingFace Space 部署
173
+
174
+ 1. 创建新的 HuggingFace Space (选择 "Docker" 或 "Gradio" SDK)
175
+ 2. 上传所有文件:
176
+ ```
177
+ hf_space/
178
+ ├── app.py
179
+ ├── requirements.txt
180
+ ├── README.md
181
+ ├── client.py
182
+ ├── model/
183
+ │ ├── __init__.py
184
+ │ ├── kronos.py
185
+ │ └── module.py
186
+ └── models/
187
+ ├── tokenizer/
188
+ │ ├── config.json
189
+ │ └── model.safetensors
190
+ └── predictor/
191
+ ├── config.json
192
+ └── model.safetensors
193
+ ```
194
+ 3. Space 将自动构建和部署
195
+
196
+ ## 注意事项
197
+
198
+ - **最小数据要求**: 至少 100 条 OHLCV 数据点
199
+ - **时间间隔**: 建议使用 1 小时 K 线数据
200
+ - **CPU 推理**: HuggingFace Space 免费版使用 CPU,预测约需 5-10 秒
201
+ - **并发限制**: 免费版有请求频率限制,建议间隔 1 秒以上
202
+
203
+ ## 许可证
204
+
205
+ MIT License
206
+
207
+ ## 联系方式
208
 
209
+ 如有题或建议,请提交 Issue
client.py CHANGED
@@ -1,643 +1,410 @@
 
1
  """
2
- Kronos BTC Prediction API Client SDK
3
 
4
- 用于连Kronos BTC 预测 API 的 Python 客户端
5
 
6
- 使用示例:
7
- from client import KronosClient
 
8
 
9
- client = KronosClient("https://your-space.hf.space")
 
10
 
11
- # 健康检查
12
- health = client.health()
13
 
14
- # 价格预
15
- prediction = client.predict(ohlcv_data, pred_len=24)
16
 
17
- # 交易信号
18
- signal = client.get_signal(ohlcv_data)
19
  """
20
 
 
 
 
21
  import time
22
- from datetime import datetime
23
- from typing import List, Dict, Any, Optional, Union
24
- from dataclasses import dataclass
25
- from enum import Enum
26
 
27
- import httpx
28
- import pandas as pd
29
 
30
 
31
- class SignalType(str, Enum):
32
- """交易信号类型"""
33
- STRONG_BUY = "STRONG_BUY"
34
- BUY = "BUY"
35
- HOLD = "HOLD"
36
- SELL = "SELL"
37
- STRONG_SELL = "STRONG_SELL"
38
 
 
39
 
40
- @dataclass
41
- class OHLCVData:
42
- """OHLCV K线数据"""
43
- timestamp: str
44
- open: float
45
- high: float
46
- low: float
47
- close: float
48
- volume: float
49
- amount: Optional[float] = None
50
-
51
- def to_dict(self) -> Dict[str, Any]:
52
- return {
53
- "timestamp": self.timestamp,
54
- "open": self.open,
55
- "high": self.high,
56
- "low": self.low,
57
- "close": self.close,
58
- "volume": self.volume,
59
- "amount": self.amount
60
- }
61
 
62
 
63
- @dataclass
64
- class PredictResult:
65
- """预测结果"""
66
- current_price: float
67
- mean_forecast: float
68
- min_forecast: float
69
- max_forecast: float
70
- upside_probability: float
71
- expected_return: float
72
- volatility_amplification: float
73
- confidence: float
74
- forecast_prices: List[float]
75
- timestamp: str
76
-
77
- @classmethod
78
- def from_dict(cls, data: Dict[str, Any]) -> "PredictResult":
79
- return cls(**data)
80
-
81
- def __repr__(self) -> str:
82
- return (
83
- f"PredictResult(\n"
84
- f" current_price={self.current_price:.2f},\n"
85
- f" mean_forecast={self.mean_forecast:.2f},\n"
86
- f" upside_probability={self.upside_probability:.1%},\n"
87
- f" expected_return={self.expected_return:.2%},\n"
88
- f" confidence={self.confidence:.1%}\n"
89
- f")"
90
- )
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- @dataclass
94
- class SignalResult:
95
- """交易信号结果"""
96
- signal: SignalType
97
- confidence: float
98
- current_price: float
99
- target_price: float
100
- stop_loss_price: float
101
- take_profit_price: float
102
- upside_probability: float
103
- expected_return: float
104
- suggested_position_size: float
105
- reason: str
106
- timestamp: str
107
-
108
- @classmethod
109
- def from_dict(cls, data: Dict[str, Any]) -> "SignalResult":
110
- data["signal"] = SignalType(data["signal"])
111
- return cls(**data)
112
-
113
- def __repr__(self) -> str:
114
- return (
115
- f"SignalResult(\n"
116
- f" signal={self.signal.value},\n"
117
- f" confidence={self.confidence:.1%},\n"
118
- f" current_price={self.current_price:.2f},\n"
119
- f" target_price={self.target_price:.2f},\n"
120
- f" stop_loss={self.stop_loss_price:.2f},\n"
121
- f" take_profit={self.take_profit_price:.2f},\n"
122
- f" position_size={self.suggested_position_size:.1%},\n"
123
- f" reason='{self.reason}'\n"
124
- f")"
125
- )
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- @dataclass
129
- class HealthResult:
130
- """健康检查结果"""
131
- status: str
132
- model_loaded: bool
133
- model_version: str
134
- device: str
135
- timestamp: str
136
-
137
- @classmethod
138
- def from_dict(cls, data: Dict[str, Any]) -> "HealthResult":
139
- return cls(**data)
140
 
 
 
 
 
 
 
 
141
 
142
- class KronosClientError(Exception):
143
- """Kronos 客户端错误"""
144
- pass
145
 
 
146
 
147
- class KronosClient:
148
- """
149
- Kronos BTC 预测 API 客户端
150
-
151
- Args:
152
- base_url: API 服务地址 (如 "https://your-space.hf.space")
153
- timeout: 请求超时时间 (秒)
154
- max_retries: 最大重试次数
155
-
156
- Examples:
157
- >>> client = KronosClient("https://your-space.hf.space")
158
- >>> health = client.health()
159
- >>> print(f"Status: {health.status}")
160
- """
161
-
162
- def __init__(
163
- self,
164
- base_url: str,
165
- timeout: float = 60.0,
166
- max_retries: int = 3
167
- ):
168
- self.base_url = base_url.rstrip("/")
169
- self.timeout = timeout
170
- self.max_retries = max_retries
171
- self._client = httpx.Client(timeout=timeout)
172
-
173
- def __enter__(self):
174
- return self
175
-
176
- def __exit__(self, *args):
177
- self.close()
178
-
179
- def close(self):
180
- """关闭客户端"""
181
- self._client.close()
182
-
183
- def _request(
184
- self,
185
- method: str,
186
- endpoint: str,
187
- json: Optional[Dict] = None,
188
- retry_count: int = 0
189
- ) -> Dict[str, Any]:
190
- """发送 HTTP 请求"""
191
- url = f"{self.base_url}{endpoint}"
192
-
193
- try:
194
- response = self._client.request(method, url, json=json)
195
-
196
- if response.status_code == 503:
197
- # 模型未加载,等待重试
198
- if retry_count < self.max_retries:
199
- time.sleep(5)
200
- return self._request(method, endpoint, json, retry_count + 1)
201
- raise KronosClientError("Model not loaded after retries")
202
-
203
- response.raise_for_status()
204
- return response.json()
205
-
206
- except httpx.ConnectError as e:
207
- raise KronosClientError(f"Connection failed: {e}")
208
- except httpx.TimeoutException as e:
209
- raise KronosClientError(f"Request timeout: {e}")
210
- except httpx.HTTPStatusError as e:
211
- raise KronosClientError(f"HTTP error {e.response.status_code}: {e.response.text}")
212
-
213
- def health(self) -> HealthResult:
214
- """
215
- 健康检查
216
-
217
- Returns:
218
- HealthResult: 健康状态
219
- """
220
- data = self._request("GET", "/health")
221
- return HealthResult.from_dict(data)
222
-
223
- def predict(
224
- self,
225
- data: Union[List[Dict], List[OHLCVData], pd.DataFrame],
226
- pred_len: int = 24,
227
- n_paths: int = 30,
228
- temperature: float = 1.0,
229
- top_p: float = 0.9
230
- ) -> PredictResult:
231
- """
232
- 预测 BTC 价格走势
233
-
234
- Args:
235
- data: OHLCV 数据 (至少 100 条)
236
- - List[Dict]: 字典列表
237
- - List[OHLCVData]: OHLCVData 对象列表
238
- - pd.DataFrame: DataFrame (需包含 timestamp, open, high, low, close, volume 列)
239
- pred_len: 预测长度 (1-72 小时)
240
- n_paths: Monte Carlo 路径数 (10-100)
241
- temperature: 采样温度 (0.1-2.0)
242
- top_p: Top-p 采样 (0.5-1.0)
243
-
244
- Returns:
245
- PredictResult: 预测结果
246
-
247
- Examples:
248
- >>> result = client.predict(df, pred_len=24)
249
- >>> print(f"上涨概率: {result.upside_probability:.1%}")
250
- """
251
- ohlcv_list = self._convert_data(data)
252
-
253
- if len(ohlcv_list) < 100:
254
- raise KronosClientError(f"At least 100 data points required, got {len(ohlcv_list)}")
255
-
256
- request_data = {
257
- "data": ohlcv_list,
258
- "pred_len": pred_len,
259
- "n_paths": n_paths,
260
- "temperature": temperature,
261
- "top_p": top_p
262
- }
263
-
264
- response = self._request("POST", "/predict", json=request_data)
265
- return PredictResult.from_dict(response)
266
-
267
- def get_signal(
268
- self,
269
- data: Union[List[Dict], List[OHLCVData], pd.DataFrame],
270
- buy_threshold: float = 0.58,
271
- sell_threshold: float = 0.42,
272
- stop_loss: float = 0.03,
273
- take_profit: float = 0.08,
274
- n_paths: int = 30
275
- ) -> SignalResult:
276
- """
277
- 获取交易信号
278
-
279
- Args:
280
- data: OHLCV 数据 (至少 100 条)
281
- buy_threshold: 买入阈值 (0.5-0.9)
282
- sell_threshold: 卖出阈值 (0.1-0.5)
283
- stop_loss: 止损比例 (0.01-0.1)
284
- take_profit: 止盈比例 (0.02-0.2)
285
- n_paths: Monte Carlo 路径数 (10-100)
286
-
287
- Returns:
288
- SignalResult: 交易信号
289
-
290
- Examples:
291
- >>> signal = client.get_signal(df)
292
- >>> if signal.signal == SignalType.BUY:
293
- ... print(f"买入! 目标价: {signal.target_price:.2f}")
294
- """
295
- ohlcv_list = self._convert_data(data)
296
-
297
- if len(ohlcv_list) < 100:
298
- raise KronosClientError(f"At least 100 data points required, got {len(ohlcv_list)}")
299
-
300
- request_data = {
301
- "data": ohlcv_list,
302
- "buy_threshold": buy_threshold,
303
- "sell_threshold": sell_threshold,
304
- "stop_loss": stop_loss,
305
- "take_profit": take_profit,
306
- "n_paths": n_paths
307
- }
308
-
309
- response = self._request("POST", "/signal", json=request_data)
310
- return SignalResult.from_dict(response)
311
-
312
- def _convert_data(
313
- self,
314
- data: Union[List[Dict], List[OHLCVData], pd.DataFrame]
315
- ) -> List[Dict[str, Any]]:
316
- """转换数据格式"""
317
- if isinstance(data, pd.DataFrame):
318
- return self._dataframe_to_list(data)
319
- elif isinstance(data, list):
320
- if len(data) == 0:
321
- return []
322
- if isinstance(data[0], OHLCVData):
323
- return [d.to_dict() for d in data]
324
- elif isinstance(data[0], dict):
325
- return data
326
-
327
- raise KronosClientError(f"Unsupported data type: {type(data)}")
328
-
329
- def _dataframe_to_list(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
330
- """将 DataFrame 转换为列表"""
331
- required_cols = ["open", "high", "low", "close", "volume"]
332
- for col in required_cols:
333
- if col not in df.columns:
334
- raise KronosClientError(f"Missing required column: {col}")
335
-
336
- result = []
337
- for _, row in df.iterrows():
338
- # 处理时间戳
339
- if "timestamp" in df.columns:
340
- ts = row["timestamp"]
341
- if isinstance(ts, pd.Timestamp):
342
- ts = ts.isoformat()
343
- elif isinstance(ts, datetime):
344
- ts = ts.isoformat()
345
- else:
346
- ts = str(ts)
347
- else:
348
- ts = datetime.utcnow().isoformat()
349
 
350
- result.append({
351
- "timestamp": ts,
352
- "open": float(row["open"]),
353
- "high": float(row["high"]),
354
- "low": float(row["low"]),
355
- "close": float(row["close"]),
356
- "volume": float(row["volume"]),
357
- "amount": float(row["amount"]) if "amount" in df.columns else None
358
- })
359
-
360
- return result
361
 
362
 
363
- class AsyncKronosClient:
364
- """
365
- 异步 Kronos BTC 预测 API 客户端
366
-
367
- 用于异步场景 (如 asyncio 应用)
368
-
369
- Examples:
370
- >>> async with AsyncKronosClient("https://your-space.hf.space") as client:
371
- ... health = await client.health()
372
- ... prediction = await client.predict(df)
373
- """
374
-
375
- def __init__(
376
- self,
377
- base_url: str,
378
- timeout: float = 60.0,
379
- max_retries: int = 3
380
- ):
381
- self.base_url = base_url.rstrip("/")
382
- self.timeout = timeout
383
- self.max_retries = max_retries
384
- self._client: Optional[httpx.AsyncClient] = None
385
-
386
- async def __aenter__(self):
387
- self._client = httpx.AsyncClient(timeout=self.timeout)
388
- return self
389
-
390
- async def __aexit__(self, *args):
391
- await self.close()
392
-
393
- async def close(self):
394
- """关闭客户端"""
395
- if self._client:
396
- await self._client.aclose()
397
- self._client = None
398
-
399
- async def _request(
400
- self,
401
- method: str,
402
- endpoint: str,
403
- json: Optional[Dict] = None,
404
- retry_count: int = 0
405
- ) -> Dict[str, Any]:
406
- """发送 HTTP 请求"""
407
- if not self._client:
408
- raise KronosClientError("Client not initialized. Use 'async with' context.")
409
-
410
- url = f"{self.base_url}{endpoint}"
411
-
412
- try:
413
- response = await self._client.request(method, url, json=json)
414
 
415
- if response.status_code == 503:
416
- if retry_count < self.max_retries:
417
- import asyncio
418
- await asyncio.sleep(5)
419
- return await self._request(method, endpoint, json, retry_count + 1)
420
- raise KronosClientError("Model not loaded after retries")
421
 
422
- response.raise_for_status()
423
- return response.json()
424
-
425
- except httpx.ConnectError as e:
426
- raise KronosClientError(f"Connection failed: {e}")
427
- except httpx.TimeoutException as e:
428
- raise KronosClientError(f"Request timeout: {e}")
429
- except httpx.HTTPStatusError as e:
430
- raise KronosClientError(f"HTTP error {e.response.status_code}: {e.response.text}")
431
-
432
- async def health(self) -> HealthResult:
433
- """健康检查"""
434
- data = await self._request("GET", "/health")
435
- return HealthResult.from_dict(data)
436
-
437
- async def predict(
438
- self,
439
- data: Union[List[Dict], List[OHLCVData], pd.DataFrame],
440
- pred_len: int = 24,
441
- n_paths: int = 30,
442
- temperature: float = 1.0,
443
- top_p: float = 0.9
444
- ) -> PredictResult:
445
- """预测 BTC 价格走势"""
446
- ohlcv_list = self._convert_data(data)
447
-
448
- if len(ohlcv_list) < 100:
449
- raise KronosClientError(f"At least 100 data points required")
450
-
451
- request_data = {
452
- "data": ohlcv_list,
453
- "pred_len": pred_len,
454
- "n_paths": n_paths,
455
- "temperature": temperature,
456
- "top_p": top_p
457
- }
458
-
459
- response = await self._request("POST", "/predict", json=request_data)
460
- return PredictResult.from_dict(response)
461
-
462
- async def get_signal(
463
- self,
464
- data: Union[List[Dict], List[OHLCVData], pd.DataFrame],
465
- buy_threshold: float = 0.58,
466
- sell_threshold: float = 0.42,
467
- stop_loss: float = 0.03,
468
- take_profit: float = 0.08,
469
- n_paths: int = 30
470
- ) -> SignalResult:
471
- """获取交易信号"""
472
- ohlcv_list = self._convert_data(data)
473
-
474
- if len(ohlcv_list) < 100:
475
- raise KronosClientError(f"At least 100 data points required")
476
-
477
- request_data = {
478
- "data": ohlcv_list,
479
- "buy_threshold": buy_threshold,
480
- "sell_threshold": sell_threshold,
481
- "stop_loss": stop_loss,
482
- "take_profit": take_profit,
483
- "n_paths": n_paths
484
- }
485
-
486
- response = await self._request("POST", "/signal", json=request_data)
487
- return SignalResult.from_dict(response)
488
-
489
- def _convert_data(
490
- self,
491
- data: Union[List[Dict], List[OHLCVData], pd.DataFrame]
492
- ) -> List[Dict[str, Any]]:
493
- """转换数据格式"""
494
- if isinstance(data, pd.DataFrame):
495
- return self._dataframe_to_list(data)
496
- elif isinstance(data, list):
497
- if len(data) == 0:
498
- return []
499
- if isinstance(data[0], OHLCVData):
500
- return [d.to_dict() for d in data]
501
- elif isinstance(data[0], dict):
502
- return data
503
-
504
- raise KronosClientError(f"Unsupported data type: {type(data)}")
505
-
506
- def _dataframe_to_list(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
507
- """将 DataFrame 转换为列表"""
508
- required_cols = ["open", "high", "low", "close", "volume"]
509
- for col in required_cols:
510
- if col not in df.columns:
511
- raise KronosClientError(f"Missing required column: {col}")
512
-
513
- result = []
514
- for _, row in df.iterrows():
515
- if "timestamp" in df.columns:
516
- ts = row["timestamp"]
517
- if isinstance(ts, pd.Timestamp):
518
- ts = ts.isoformat()
519
- elif isinstance(ts, datetime):
520
- ts = ts.isoformat()
521
- else:
522
- ts = str(ts)
523
- else:
524
- ts = datetime.utcnow().isoformat()
525
 
526
- result.append({
527
- "timestamp": ts,
528
- "open": float(row["open"]),
529
- "high": float(row["high"]),
530
- "low": float(row["low"]),
531
- "close": float(row["close"]),
532
- "volume": float(row["volume"]),
533
- "amount": float(row["amount"]) if "amount" in df.columns else None
534
- })
535
-
536
- return result
537
 
538
 
539
- # ==================== 使用示例 ====================
540
-
541
- if __name__ == "__main__":
542
- import asyncio
543
-
544
- # 示例数据 (实际使用时替换为真实数据)
545
- def create_sample_data() -> pd.DataFrame:
546
- """创建示例数据"""
547
- import numpy as np
548
-
549
- np.random.seed(42)
550
- n = 120
551
-
552
- dates = pd.date_range(start="2024-01-01", periods=n, freq="1h")
553
- base_price = 42000
554
-
555
- prices = [base_price]
556
- for i in range(1, n):
557
- change = np.random.randn() * 100
558
- prices.append(prices[-1] + change)
559
-
560
- df = pd.DataFrame({
561
- "timestamp": dates,
562
- "open": prices,
563
- "high": [p + np.random.rand() * 50 for p in prices],
564
- "low": [p - np.random.rand() * 50 for p in prices],
565
- "close": [p + np.random.randn() * 20 for p in prices],
566
- "volume": np.random.rand(n) * 1000 + 100
567
- })
568
-
569
- return df
570
-
571
- # 同步示例
572
- def sync_example():
573
- print("=== 同步客户端示例 ===")
574
-
575
- # 创建客户端
576
- client = KronosClient("http://localhost:7860")
577
-
578
- try:
579
- # 健康检查
580
- health = client.health()
581
- print(f"Status: {health.status}")
582
- print(f"Model loaded: {health.model_loaded}")
583
 
584
- # 创建示例数据
585
- df = create_sample_data()
586
- print(f"\n数据点数: {len(df)}")
 
 
 
 
 
587
 
588
- # 预测
589
- prediction = client.predict(df, pred_len=24)
590
- print(f"\n预测结果:")
591
- print(prediction)
 
 
 
 
 
 
 
592
 
593
- # 交易信号
594
- signal = client.get_signal(df)
595
- print(f"\n交易信号:")
596
- print(signal)
 
 
 
 
597
 
598
- except KronosClientError as e:
599
- print(f"Error: {e}")
600
- finally:
601
- client.close()
602
-
603
- # 异步示例
604
- async def async_example():
605
- print("\n=== 异步客户端示例 ===")
606
-
607
- async with AsyncKronosClient("http://localhost:7860") as client:
608
- try:
609
- # 健康检查
610
- health = await client.health()
611
- print(f"Status: {health.status}")
612
-
613
- # 创建示例数据
614
- df = create_sample_data()
615
-
616
- # 并发预测和信号
617
- prediction, signal = await asyncio.gather(
618
- client.predict(df),
619
- client.get_signal(df)
620
- )
621
-
622
- print(f"\n预测结果:")
623
- print(prediction)
624
- print(f"\n交易信号:")
625
- print(signal)
626
-
627
- except KronosClientError as e:
628
- print(f"Error: {e}")
629
-
630
- # 运行示例
631
- print("Kronos Client SDK 示例\n")
632
- print("注意: 请确保 API 服务已启动 (python app.py)\n")
633
-
634
- # 仅打印帮助信息,不实际运行
635
- print("同步使用:")
636
- print(" client = KronosClient('http://localhost:7860')")
637
- print(" prediction = client.predict(df)")
638
- print(" signal = client.get_signal(df)")
639
- print()
640
- print("异步使用:")
641
- print(" async with AsyncKronosClient('http://localhost:7860') as client:")
642
- print(" prediction = await client.predict(df)")
643
- print(" signal = await client.get_signal(df)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
  """
3
+ Kronos BTC 预测 API 测试客户端
4
 
5
+ 可直运行来验证 HuggingFace Space API 是否正常工作
6
 
7
+ 使用方法:
8
+ # 测试健康检查
9
+ python client.py health
10
 
11
+ # 测试预测 API
12
+ python client.py predict
13
 
14
+ # 测试交易信号 API
15
+ python client.py signal
16
 
17
+ # 运行所有
18
+ python client.py all
19
 
20
+ # 使用自定义 URL
21
+ python client.py all --url https://your-space.hf.space
22
  """
23
 
24
+ import argparse
25
+ import json
26
+ import sys
27
  import time
28
+ from datetime import datetime, timedelta
29
+ from typing import List, Dict, Any, Optional
 
 
30
 
31
+ import requests
 
32
 
33
 
34
+ # ==================== 配置 ====================
 
 
 
 
 
 
35
 
36
+ DEFAULT_API_URL = "https://xianqiu-tslm.hf.space"
37
 
38
+ # 币安 API
39
+ BINANCE_API = "https://api.binance.com/api/v3/klines"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
+ # ==================== 辅助函数 ====================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ def fetch_btc_data(symbol: str = "BTCUSDT", interval: str = "1h", limit: int = 200) -> List[Dict]:
45
+ """
46
+ 从币安获取 BTC K线数据
47
+
48
+ Args:
49
+ symbol: 交易对
50
+ interval: K线周期 (1h, 4h, 1d 等)
51
+ limit: 获取条数 (最大 1000)
52
+
53
+ Returns:
54
+ OHLCV 数据列表
55
+ """
56
+ print(f"[Binance] 获取 {symbol} {interval} K线数据 (最近 {limit} 条)...")
57
+
58
+ params = {
59
+ "symbol": symbol,
60
+ "interval": interval,
61
+ "limit": limit
62
+ }
63
+
64
+ try:
65
+ response = requests.get(BINANCE_API, params=params, timeout=10)
66
+ response.raise_for_status()
67
+ data = response.json()
68
+ except requests.exceptions.RequestException as e:
69
+ print(f"[Error] 无法连接币安 API: {e}")
70
+ print("[Info] 使用模拟数据...")
71
+ return generate_mock_data(limit)
72
+
73
+ ohlcv_list = []
74
+ for item in data:
75
+ ohlcv_list.append({
76
+ "timestamp": datetime.fromtimestamp(item[0] / 1000).isoformat(),
77
+ "open": float(item[1]),
78
+ "high": float(item[2]),
79
+ "low": float(item[3]),
80
+ "close": float(item[4]),
81
+ "volume": float(item[5]),
82
+ "amount": float(item[7]) # Quote asset volume
83
+ })
84
+
85
+ print(f"[OK] 获取到 {len(ohlcv_list)} 条数据")
86
+ print(f" 时间范围: {ohlcv_list[0]['timestamp']} ~ {ohlcv_list[-1]['timestamp']}")
87
+ print(f" 当前价格: ${ohlcv_list[-1]['close']:,.2f}")
88
+
89
+ return ohlcv_list
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ def generate_mock_data(n: int = 200) -> List[Dict]:
93
+ """生成模拟 K线数据 (当币安 API 不可用时使用)"""
94
+ import random
95
+
96
+ base_price = 100000.0
97
+ data = []
98
+ current_time = datetime.utcnow() - timedelta(hours=n)
99
+
100
+ for i in range(n):
101
+ change = random.gauss(0, 0.01) # 1% 标准差
102
+ base_price *= (1 + change)
103
+
104
+ high = base_price * (1 + random.random() * 0.005)
105
+ low = base_price * (1 - random.random() * 0.005)
106
+ close = random.uniform(low, high)
107
+
108
+ data.append({
109
+ "timestamp": current_time.isoformat(),
110
+ "open": round(base_price, 2),
111
+ "high": round(high, 2),
112
+ "low": round(low, 2),
113
+ "close": round(close, 2),
114
+ "volume": round(random.uniform(100, 1000), 2),
115
+ "amount": round(random.uniform(1000000, 10000000), 2)
116
+ })
117
+
118
+ current_time += timedelta(hours=1)
119
+
120
+ return data
121
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ def print_json(data: Any, title: str = None):
124
+ """美化打印 JSON"""
125
+ if title:
126
+ print(f"\n{'='*60}")
127
+ print(f" {title}")
128
+ print(f"{'='*60}")
129
+ print(json.dumps(data, indent=2, ensure_ascii=False))
130
 
 
 
 
131
 
132
+ # ==================== API 测试函数 ====================
133
 
134
+ def test_health(base_url: str) -> bool:
135
+ """测试健康检查 API"""
136
+ print("\n" + "="*60)
137
+ print(" TEST: /health")
138
+ print("="*60)
139
+
140
+ url = f"{base_url}/health"
141
+ print(f"[Request] GET {url}")
142
+
143
+ try:
144
+ start = time.time()
145
+ response = requests.get(url, timeout=30)
146
+ elapsed = time.time() - start
147
+
148
+ print(f"[Response] Status: {response.status_code} ({elapsed:.2f}s)")
149
+
150
+ if response.status_code == 200:
151
+ data = response.json()
152
+ print(f"\n[Result]")
153
+ print(f" Status: {data.get('status', 'N/A')}")
154
+ print(f" Model Loaded: {data.get('model_loaded', 'N/A')}")
155
+ print(f" Model Version: {data.get('model_version', 'N/A')}")
156
+ print(f" Device: {data.get('device', 'N/A')}")
157
+ print(f" Timestamp: {data.get('timestamp', 'N/A')}")
158
+ return True
159
+ else:
160
+ print(f"[Error] {response.text}")
161
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+ except requests.exceptions.RequestException as e:
164
+ print(f"[Error] 请求失败: {e}")
165
+ return False
 
 
 
 
 
 
 
 
166
 
167
 
168
+ def test_predict(base_url: str, data: List[Dict] = None) -> bool:
169
+ """测试预测 API"""
170
+ print("\n" + "="*60)
171
+ print(" TEST: /predict")
172
+ print("="*60)
173
+
174
+ # 获取数据
175
+ if data is None:
176
+ data = fetch_btc_data(limit=200)
177
+
178
+ url = f"{base_url}/predict"
179
+ payload = {
180
+ "data": data,
181
+ "pred_len": 24,
182
+ "n_paths": 30,
183
+ "temperature": 1.0,
184
+ "top_p": 0.9
185
+ }
186
+
187
+ print(f"\n[Request] POST {url}")
188
+ print(f" 数据点数: {len(data)}")
189
+ print(f" 预测长度: {payload['pred_len']} 小时")
190
+ print(f" Monte Carlo: {payload['n_paths']} 路径")
191
+
192
+ try:
193
+ start = time.time()
194
+ response = requests.post(url, json=payload, timeout=120)
195
+ elapsed = time.time() - start
196
+
197
+ print(f"\n[Response] Status: {response.status_code} ({elapsed:.2f}s)")
198
+
199
+ if response.status_code == 200:
200
+ result = response.json()
201
+ print(f"\n[Result]")
202
+ print(f" 当前价格: ${result.get('current_price', 0):,.2f}")
203
+ print(f" 预测均值: ${result.get('mean_forecast', 0):,.2f}")
204
+ print(f" 预测范围: ${result.get('min_forecast', 0):,.2f} ~ ${result.get('max_forecast', 0):,.2f}")
205
+ print(f" 上涨概率: {result.get('upside_probability', 0)*100:.1f}%")
206
+ print(f" 预期收益: {result.get('expected_return', 0)*100:.2f}%")
207
+ print(f" 波动放大: {result.get('volatility_amplification', 0):.2f}x")
208
+ print(f" 置信度: {result.get('confidence', 0)*100:.1f}%")
209
+ print(f" 预测点数: {len(result.get('forecast_prices', []))} 个")
 
 
 
 
 
 
 
 
 
210
 
211
+ # 显示部分预测价格
212
+ prices = result.get('forecast_prices', [])
213
+ if prices:
214
+ print(f"\n 预测价格趋势 (每6小时):")
215
+ for i in range(0, len(prices), 6):
216
+ print(f" +{i}h: ${prices[i]:,.2f}")
217
 
218
+ return True
219
+ elif response.status_code == 503:
220
+ print(f"[Warning] 模型未加载,请稍后重试")
221
+ print(f" Response: {response.text}")
222
+ return False
223
+ else:
224
+ print(f"[Error] {response.text}")
225
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ except requests.exceptions.Timeout:
228
+ print(f"[Error] 请求超时 (>120s)")
229
+ return False
230
+ except requests.exceptions.RequestException as e:
231
+ print(f"[Error] 请求失败: {e}")
232
+ return False
 
 
 
 
 
233
 
234
 
235
+ def test_signal(base_url: str, data: List[Dict] = None) -> bool:
236
+ """测试交易信号 API"""
237
+ print("\n" + "="*60)
238
+ print(" TEST: /signal")
239
+ print("="*60)
240
+
241
+ # 获取数据
242
+ if data is None:
243
+ data = fetch_btc_data(limit=200)
244
+
245
+ url = f"{base_url}/signal"
246
+ payload = {
247
+ "data": data,
248
+ "buy_threshold": 0.58,
249
+ "sell_threshold": 0.42,
250
+ "stop_loss": 0.03,
251
+ "take_profit": 0.08,
252
+ "n_paths": 30
253
+ }
254
+
255
+ print(f"\n[Request] POST {url}")
256
+ print(f" 数据点数: {len(data)}")
257
+ print(f" 买入阈值: {payload['buy_threshold']}")
258
+ print(f" 卖出阈值: {payload['sell_threshold']}")
259
+ print(f" 止损比例: {payload['stop_loss']*100:.1f}%")
260
+ print(f" 止盈比例: {payload['take_profit']*100:.1f}%")
261
+
262
+ try:
263
+ start = time.time()
264
+ response = requests.post(url, json=payload, timeout=120)
265
+ elapsed = time.time() - start
266
+
267
+ print(f"\n[Response] Status: {response.status_code} ({elapsed:.2f}s)")
268
+
269
+ if response.status_code == 200:
270
+ result = response.json()
271
+ signal = result.get('signal', 'N/A')
 
 
 
 
 
 
 
272
 
273
+ # 信号颜色
274
+ signal_icons = {
275
+ 'STRONG_BUY': '[++]',
276
+ 'BUY': '[+]',
277
+ 'HOLD': '[=]',
278
+ 'SELL': '[-]',
279
+ 'STRONG_SELL': '[--]'
280
+ }
281
 
282
+ print(f"\n[Result]")
283
+ print(f" 信号: {signal_icons.get(signal, '')} {signal}")
284
+ print(f" 置信度: {result.get('confidence', 0)*100:.1f}%")
285
+ print(f" 当前价格: ${result.get('current_price', 0):,.2f}")
286
+ print(f" 目标价格: ${result.get('target_price', 0):,.2f}")
287
+ print(f" 止损价格: ${result.get('stop_loss_price', 0):,.2f}")
288
+ print(f" 止盈价格: ${result.get('take_profit_price', 0):,.2f}")
289
+ print(f" 上涨概率: {result.get('upside_probability', 0)*100:.1f}%")
290
+ print(f" 预期收益: {result.get('expected_return', 0)*100:.2f}%")
291
+ print(f" 建议仓位: {result.get('suggested_position_size', 0)*100:.1f}%")
292
+ print(f" 原因: {result.get('reason', 'N/A')}")
293
 
294
+ return True
295
+ elif response.status_code == 503:
296
+ print(f"[Warning] 模型未加载,请稍后重试")
297
+ print(f" Response: {response.text}")
298
+ return False
299
+ else:
300
+ print(f"[Error] {response.text}")
301
+ return False
302
 
303
+ except requests.exceptions.Timeout:
304
+ print(f"[Error] 请求超时 (>120s)")
305
+ return False
306
+ except requests.exceptions.RequestException as e:
307
+ print(f"[Error] 请求失败: {e}")
308
+ return False
309
+
310
+
311
+ def run_all_tests(base_url: str) -> bool:
312
+ """运行所有测试"""
313
+ print("\n" + "#"*60)
314
+ print("#")
315
+ print("# Kronos BTC 预测 API 测试")
316
+ print(f"# URL: {base_url}")
317
+ print(f"# 时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
318
+ print("#")
319
+ print("#"*60)
320
+
321
+ results = {}
322
+
323
+ # 1. 健康检查
324
+ results['health'] = test_health(base_url)
325
+
326
+ if not results['health']:
327
+ print("\n[Warning] 健康检查失败,API 可能未启动")
328
+ print(" 请检查 HuggingFace Space 是否正在运行")
329
+ return False
330
+
331
+ # 2. 获取数据 (只获取一次,两个测试共用)
332
+ print("\n" + "-"*60)
333
+ data = fetch_btc_data(limit=200)
334
+
335
+ # 3. 预测测试
336
+ results['predict'] = test_predict(base_url, data)
337
+
338
+ # 4. 信号测试
339
+ results['signal'] = test_signal(base_url, data)
340
+
341
+ # 汇总
342
+ print("\n" + "="*60)
343
+ print(" 测试结果汇总")
344
+ print("="*60)
345
+
346
+ for test_name, passed in results.items():
347
+ status = "PASS" if passed else "FAIL"
348
+ icon = "[OK]" if passed else "[X]"
349
+ print(f" {icon} {test_name}: {status}")
350
+
351
+ all_passed = all(results.values())
352
+
353
+ print("\n" + "-"*60)
354
+ if all_passed:
355
+ print(" 所有测试通过!")
356
+ else:
357
+ print(" 部分测试失败,请检查 API 状态")
358
+ print("-"*60)
359
+
360
+ return all_passed
361
+
362
+
363
+ # ==================== 主函数 ====================
364
+
365
+ def main():
366
+ parser = argparse.ArgumentParser(
367
+ description="Kronos BTC 预测 API 测试客户端",
368
+ formatter_class=argparse.RawDescriptionHelpFormatter,
369
+ epilog="""
370
+ 示例:
371
+ python client.py health # 测试健康检查
372
+ python client.py predict # 测试预测 API
373
+ python client.py signal # 测试交易信号 API
374
+ python client.py all # 运行所有测试
375
+ python client.py all --url http://localhost:7860 # 测试本地服务
376
+ """
377
+ )
378
+
379
+ parser.add_argument(
380
+ "command",
381
+ choices=["health", "predict", "signal", "all"],
382
+ help="要执行的测试命令"
383
+ )
384
+
385
+ parser.add_argument(
386
+ "--url",
387
+ default=DEFAULT_API_URL,
388
+ help=f"API 地址 (默认: {DEFAULT_API_URL})"
389
+ )
390
+
391
+ args = parser.parse_args()
392
+
393
+ # 执行测试
394
+ if args.command == "health":
395
+ success = test_health(args.url)
396
+ elif args.command == "predict":
397
+ success = test_predict(args.url)
398
+ elif args.command == "signal":
399
+ success = test_signal(args.url)
400
+ elif args.command == "all":
401
+ success = run_all_tests(args.url)
402
+ else:
403
+ parser.print_help()
404
+ sys.exit(1)
405
+
406
+ sys.exit(0 if success else 1)
407
+
408
+
409
+ if __name__ == "__main__":
410
+ main()
models/predictor/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:28271e7c248da93c5954ee72eb7c066b986958ba5159f0be4db5c8a423ab5d74
3
  size 16440776
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee1a282b63487e18f0b0c2fea391a4ea335ee79e61708fecd5e2ac1d37eb5644
3
  size 16440776
models/tokenizer/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:983c247a275272480aeb6d6b0cb21733fc26d1315aed6e7a4cd2ea0590bbaef5
3
  size 15842376
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0198724df098fd7f6088ed39277089afa0150c6f533c0ce067f4483a8e8ba6a7
3
  size 15842376
requirements.txt CHANGED
@@ -10,7 +10,6 @@ pydantic==2.5.2
10
  torch==2.1.0
11
  numpy>=1.24.0,<2.0.0
12
  pandas>=2.0.0
13
- einops>=0.7.0
14
 
15
  # Model loading
16
  safetensors>=0.4.0
 
10
  torch==2.1.0
11
  numpy>=1.24.0,<2.0.0
12
  pandas>=2.0.0
 
13
 
14
  # Model loading
15
  safetensors>=0.4.0