Trae Assistant
后端:新增 /api/upload,完善异常处理与中文文案;本地验证通过。
7afd68c
from flask import Flask, render_template, request, jsonify, send_from_directory
import numpy as np
from scipy import stats
import os
import json
from werkzeug.exceptions import RequestEntityTooLarge
app = Flask(__name__)
# Flask rejects any request body larger than this with RequestEntityTooLarge
# (handled below as HTTP 413). It sits above the 5 MB per-file check done in
# /api/upload — presumably so multipart overhead does not trip this limit
# before the friendlier per-file message can be returned.
app.config.update(MAX_CONTENT_LENGTH=6 * 1024 * 1024)
@app.route('/')
def index():
    """Render the front-end entry page (the ``index.html`` template)."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/health')
def health():
    """Lightweight health check: plain-text body "OK" with HTTP status 200."""
    body, status = "OK", 200
    return body, status
def _parse_demand_history(data):
    """Validate an /api/optimize payload and extract the regression inputs.

    Args:
        data: The decoded JSON request body (None if missing or not JSON).

    Returns:
        (prices, sales, marginal_cost): two float arrays of equal length and
        the marginal cost as a float (``mc`` defaults to 0).

    Raises:
        ValueError: with a user-facing (Chinese) message describing the first
            problem found in the payload.
    """
    if not isinstance(data, dict):
        raise ValueError("请求体必须是 JSON 对象")
    history = data.get('history', [])
    if not isinstance(history, list):
        raise ValueError("history 字段必须是数组")
    if len(history) < 3:
        raise ValueError("至少需要 3 条数据点")
    try:
        prices = np.array([float(h['price']) for h in history])
        sales = np.array([float(h['sales']) for h in history])
        marginal_cost = float(data.get('mc', 0))
    except (KeyError, TypeError, ValueError, OverflowError):
        raise ValueError("每条数据必须包含数值型 price 与 sales,mc 必须为数值") from None
    if not (np.isfinite(prices).all() and np.isfinite(sales).all()
            and np.isfinite(marginal_cost)):
        raise ValueError("数据中包含 NaN 或无穷大等非法数值")
    # linregress cannot fit a line when every x value is identical.
    if prices.max() == prices.min():
        raise ValueError("价格至少需要两个不同的取值才能拟合需求曲线")
    return prices, sales, marginal_cost


@app.route('/api/optimize', methods=['POST'])
def optimize():
    """Fit a linear demand curve Q = a + bP and recommend a profit-maximising price.

    Request JSON: {"history": [{"price": 10, "sales": 50}, ...], "mc": 5}
        At least 3 points with at least two distinct prices; ``mc`` (marginal
        cost) defaults to 0.

    Responses:
        200 {"status": "success", "slope", "intercept", "r_squared",
             "optimal_price", "elasticity", "message"}
        200 {"status": "warning", "message", "slope", "optimal_price"} when the
            fitted slope is >= 0; optimal_price is then the last price in history.
        400 {"error": ...} for a malformed payload.
    """
    try:
        prices, sales, marginal_cost = _parse_demand_history(request.get_json(silent=True))
    except ValueError as exc:
        return jsonify({"error": str(exc)}), 400

    slope, intercept, r_value, _p_value, _std_err = stats.linregress(prices, sales)
    # Cast NumPy scalars to built-in floats so jsonify can always serialise them.
    slope, intercept, r_value = float(slope), float(intercept), float(r_value)

    # Law of demand: the slope should be negative. With b >= 0 profit has no
    # interior maximum, so recommend keeping the last observed price.
    if slope >= 0:
        return jsonify({
            "status": "warning",
            "message": "需求曲线异常或过于平坦(斜率≥0),无法优化,建议保持当前价格。",
            "slope": slope,
            "optimal_price": float(prices[-1]),
        })

    # Profit(P) = (P - MC)(a + bP);  dProfit/dP = 2bP + a - b*MC = 0
    # => P* = (b*MC - a) / (2b), a maximum because b < 0. Clamp at 0.
    optimal_price = max(0.0, (slope * marginal_cost - intercept) / (2 * slope))

    # Point price elasticity evaluated at the mean observed price.
    avg_p = float(np.mean(prices))
    avg_q = slope * avg_p + intercept
    elasticity = (slope * avg_p) / avg_q if avg_q != 0 else 0.0

    return jsonify({
        "status": "success",
        "slope": slope,
        "intercept": intercept,
        "r_squared": r_value ** 2,
        "optimal_price": round(optimal_price, 2),
        "elasticity": round(elasticity, 2),
        "message": "线性需求模型优化成功,已生成建议价格。",
    })
def _finite_float(value, default):
    """Return *value* as a finite float, or *default* if it is absent,
    non-numeric, NaN or infinite."""
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if np.isfinite(number) else default


def _clean_history_row(item, day):
    """Normalise one uploaded record into the standard history row shape.

    Returns None when the record must be skipped: it is not a JSON object,
    price/sales are missing or non-numeric, price is not a positive finite
    number, or sales is negative. *day* is used when the record has no "day".
    """
    if not isinstance(item, dict):
        return None
    try:
        price = float(item.get('price'))
        sales = int(item.get('sales'))  # int() truncates fractional numbers
    except (TypeError, ValueError, OverflowError):
        return None
    # json.loads accepts NaN/Infinity; they must not leak into the response,
    # because jsonify would emit them as invalid JSON tokens.
    if not np.isfinite(price) or price <= 0 or sales < 0:
        return None
    return {
        "day": item.get('day', day),
        "price": price,
        "sales": sales,
        "revenue": _finite_float(item.get('revenue'), price * sales),
        "competitorPrice": _finite_float(item.get('competitorPrice', 50), 50),
    }


@app.route('/api/upload', methods=['POST'])
def upload():
    """Upload a JSON file of historical data and return a normalised history.

    Accepted layouts:
        1) {"history": [{price, sales, ...}, ...]}
        2) [{price, sales, ...}, ...]
    Constraints: multipart field ``file``, name ending in ``.json``,
    size <= 5 MB, UTF-8 text (a leading UTF-8 BOM is tolerated).

    Each output row has day, price, sales, revenue and competitorPrice;
    records that cannot be parsed are skipped (see _clean_history_row).

    Responses: 200 {"history": [...]}, 400 {"error": ...} for bad input,
    500 {"error": ...} for unexpected failures.
    """
    if 'file' not in request.files:
        return jsonify({"error": "未找到文件字段"}), 400
    file = request.files['file']
    filename = (file.filename or '').lower()
    if not filename.endswith('.json'):
        return jsonify({"error": "只支持 JSON 文件"}), 400

    # Measure the stream length without reading it into memory.
    file.seek(0, os.SEEK_END)
    size = file.tell()
    file.seek(0)
    if size > 5 * 1024 * 1024:
        return jsonify({"error": "文件大小不能超过 5MB"}), 400

    try:
        raw = file.read()
        try:
            # utf-8-sig also strips a BOM, which json.loads would otherwise reject.
            text = raw.decode('utf-8-sig')
        except UnicodeDecodeError:
            return jsonify({"error": "文件内容不是有效的 UTF-8 文本"}), 400
        payload = json.loads(text)

        if isinstance(payload, dict) and 'history' in payload:
            incoming = payload['history']
        elif isinstance(payload, list):
            incoming = payload
        else:
            incoming = None
        # Also covers {"history": <non-list>}, which previously ended in a 500.
        if not isinstance(incoming, list):
            return jsonify({"error": "JSON 格式不正确,应为 {history: [...]} 或数组"}), 400

        cleaned = []
        for item in incoming:
            # Default day numbers count accepted rows only, starting at 1.
            row = _clean_history_row(item, len(cleaned) + 1)
            if row is not None:
                cleaned.append(row)
        if not cleaned:
            return jsonify({"error": "数据为空或格式不正确"}), 400
        return jsonify({"history": cleaned})
    except json.JSONDecodeError:
        return jsonify({"error": "JSON 解析失败,请检查文件内容"}), 400
    except Exception as e:
        # Top-level boundary: record the traceback before answering 500.
        app.logger.exception("Failed to parse uploaded history file")
        return jsonify({"error": f"服务器解析失败: {str(e)}"}), 500
@app.errorhandler(404)
def handle_404(e):
    """Answer unknown paths: JSON for ``/api/`` routes, the index page otherwise (both 404)."""
    if not request.path.startswith('/api/'):
        return render_template('index.html'), 404
    return jsonify({"error": "未找到接口", "path": request.path}), 404
@app.errorhandler(500)
def handle_500(e):
    """Answer internal errors: JSON for ``/api/`` routes, the index page otherwise (both 500)."""
    if not request.path.startswith('/api/'):
        return render_template('index.html'), 500
    return jsonify({"error": "服务器内部错误"}), 500
@app.errorhandler(RequestEntityTooLarge)
def handle_file_too_large(e):
    """Answer oversized request bodies with a JSON error and HTTP 413.

    The limit quoted in the message is derived from MAX_CONTENT_LENGTH, so the
    text cannot drift from the configured value (6 MB renders as "6MB").
    """
    limit = app.config.get('MAX_CONTENT_LENGTH')
    if limit:
        message = f"文件过大,服务端限制为 {limit / (1024 * 1024):g}MB"
    else:
        message = "文件过大"
    return jsonify({"error": message}), 413
if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it must never be on by default while bound to all
    # interfaces. Opt in for local work with FLASK_DEBUG=1; PORT overrides 7860.
    debug_mode = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    port = int(os.environ.get('PORT', '7860'))
    app.run(host='0.0.0.0', port=port, debug=debug_mode)