# Trae Assistant
# Localization: Update error messages to Chinese
# 71bca6f
import logging
import os
import uuid

import numpy as np
import pandas as pd
from flask import Flask, render_template, send_from_directory, request, jsonify
from scipy.optimize import minimize_scalar
from werkzeug.utils import secure_filename
# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Emit non-ASCII (e.g. Chinese) characters verbatim in JSON responses.
# The JSON_AS_ASCII config key is only honoured by older Flask releases
# (deprecated in 2.2, removed in 2.3); newer releases read the setting from
# the application's JSON provider instead, so configure both.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json'):
    app.json.ensure_ascii = False
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # cap uploads at 16 MB
app.config['UPLOAD_FOLDER'] = 'uploads'

# Make sure the upload directory exists.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
@app.route('/')
def index():
    """Serve the front-end page.

    The page is produced with ``render_template``, so the front end has to
    watch out for Vue delimiter conflicts with the template syntax.
    """
    landing_page = 'index.html'
    return render_template(landing_page)
@app.route('/static/<path:path>')
def send_static(path):
    """Return the file at *path* from the ``static`` directory."""
    static_dir = 'static'
    return send_from_directory(static_dir, path)
def calculate_demand_linear(price, base_price, base_demand, elasticity):
    """Evaluate the linear demand model ``D = a - b * P`` at *price*.

    The line is calibrated so that its elasticity at the reference point
    ``(base_price, base_demand)`` equals ``-abs(elasticity)``; the sign of
    the supplied elasticity is therefore ignored.

    Derivation:
        E = (dD/dP) * (P / D), and for a line dD/dP = -b, hence
        E = -b * (base_price / base_demand)
        b = -E * (base_demand / base_price)
        a = base_demand + b * base_price

    Returns:
        The demand at *price*, floored at zero.
    """
    # With E forced negative, -E is simply the magnitude of the elasticity.
    slope = abs(elasticity) * (base_demand / base_price)
    intercept = base_demand + slope * base_price

    # Demand can never be negative.
    return np.maximum(intercept - slope * price, 0)
def calculate_demand_constant_elasticity(price, base_price, base_demand, elasticity):
    """Evaluate the constant-elasticity demand model ``D = A * P**e``.

    The exponent is ``e = -abs(elasticity)`` (the sign of the argument is
    ignored). The scale factor ``A`` is chosen so that the curve passes
    through the reference point:
        base_demand = A * base_price**e  ->  A = base_demand / base_price**e

    Returns:
        The demand at *price*.
    """
    exponent = -abs(elasticity)
    scale = base_demand / (base_price ** exponent)
    return scale * (price ** exponent)
def run_optimization(cost, base_price, base_demand, elasticity, model_type='linear'):
    """Search for the profit-maximising price under a demand model.

    A coarse 100-point scan over a price grid produces the curve data and a
    first estimate of the optimum. SciPy's bounded scalar minimiser then
    refines that estimate whenever it has a valid interval to search.

    Args:
        cost: Unit cost.
        base_price: Reference price; must be greater than zero.
        base_demand: Demand at ``base_price``.
        elasticity: Price elasticity of demand; its sign is ignored (it is
            always treated as negative).
        model_type: ``'constant'`` selects the constant-elasticity model;
            any other value selects the linear model.

    Returns:
        A dict with two keys:
        ``results`` - one dict per scanned price with ``price``, ``demand``,
        ``revenue`` and ``profit``, each rounded to 2 decimals;
        ``optimal`` - ``price``, ``profit``, ``demand`` and ``margin``
        (percentage of the price, 0 when the price is not positive).

    Raises:
        ValueError: If ``base_price`` is not a positive number.
    """
    # Both demand models divide by base_price (and the constant-elasticity
    # model raises it to a negative power), so reject unusable values with a
    # clear error instead of a ZeroDivisionError or complex-number results.
    # Written as `not >` so that NaN is rejected as well.
    if not base_price > 0:
        raise ValueError('基准价格必须大于 0')

    use_constant_model = model_type == 'constant'

    def demand_at(p):
        # Single dispatch point for the selected demand model.
        if use_constant_model:
            return calculate_demand_constant_elasticity(p, base_price, base_demand, elasticity)
        return calculate_demand_linear(p, base_price, base_demand, elasticity)

    # Scan range: from 0.8x cost (but at least 0.01) up to the larger of
    # 2x the base price and 3x the cost.
    min_p = max(cost * 0.8, 0.01)
    max_p = max(base_price * 2.0, cost * 3.0)

    results = []
    max_profit = -float('inf')
    optimal_price = base_price

    # Pass 1: coarse scan over the price grid.
    for p in np.linspace(min_p, max_p, 100):
        d = demand_at(p)
        revenue = p * d
        profit = (p - cost) * d
        if profit > max_profit:
            max_profit = profit
            optimal_price = p
        results.append({
            'price': round(p, 2),
            'demand': round(d, 2),
            'revenue': round(revenue, 2),
            'profit': round(profit, 2),
        })

    # Pass 2: refinement with SciPy's bounded minimize_scalar.
    # Work out the upper end of the search interval first.
    if use_constant_model:
        upper_bound = max_p * 1.5
    else:
        # Every non-'constant' model type is evaluated with the linear
        # demand curve, so it also gets the linear bound: the price at which
        # demand reaches zero (a - b*P = 0 -> P = a / b), plus 10%.
        zero_demand_price = max_p
        slope = abs(elasticity) * (base_demand / base_price)
        if slope != 0:
            intercept = base_demand + slope * base_price
            zero_demand_price = intercept / slope
        upper_bound = min(max_p * 1.5, zero_demand_price * 1.1)

    def negative_profit(p):
        if p <= 0:
            return 1e9  # penalty for non-positive prices
        # Maximising profit is the same as minimising its negative.
        return -((p - cost) * demand_at(p))

    refined_optimal_price = optimal_price
    refined_max_profit = max_profit

    # minimize_scalar raises ValueError when the lower bound exceeds the
    # upper bound (e.g. the cost lies above the zero-demand price) or when a
    # bound is not finite. In those cases there is nothing to refine, so the
    # coarse-scan optimum is kept instead of failing the whole analysis.
    if np.isfinite(cost) and np.isfinite(upper_bound) and cost <= upper_bound:
        res = minimize_scalar(negative_profit, bounds=(cost, upper_bound), method='bounded')
        if res.success:
            opt_prof = -res.fun
            # Only adopt the refined result when it is at least as good as
            # the scan (guards against a poor local optimum).
            if opt_prof >= max_profit:
                refined_optimal_price = res.x
                refined_max_profit = opt_prof

    # Demand at the final optimal price.
    opt_d = demand_at(refined_optimal_price)

    if refined_optimal_price > 0:
        margin = round(((refined_optimal_price - cost) / refined_optimal_price) * 100, 2)
    else:
        margin = 0

    return {
        'results': results,
        'optimal': {
            'price': round(refined_optimal_price, 2),
            'profit': round(refined_max_profit, 2),
            'demand': round(opt_d, 2),
            'margin': margin,
        },
    }
@app.route('/api/analyze', methods=['POST'])
def analyze():
    """Run one price optimisation from the parameters in a JSON request body.

    Reads ``cost``, ``base_price``, ``base_demand``, ``elasticity`` and
    ``model_type`` from the JSON object; missing keys fall back to defaults.

    Returns:
        A JSON response. On success it carries ``success``, ``data`` (the
        scanned price curve) and ``optimal``. The status is 400 when the
        body is missing, is not a JSON object, or holds a parameter that
        cannot be converted to a number, and 500 for any other failure.
    """
    try:
        # silent=True yields None for a missing, malformed or non-JSON body
        # instead of raising, so those client errors get the 400 below
        # rather than falling through to the generic 500 handler.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'success': False, 'error': '未提供数据'}), 400
        if not isinstance(data, dict):
            return jsonify({'success': False, 'error': '请求体必须是 JSON 对象'}), 400

        # Read the parameters, falling back to defaults. Values that cannot
        # be converted (e.g. null or non-numeric text) are a client error.
        try:
            cost = float(data.get('cost', 10.0))
            base_price = float(data.get('base_price', 20.0))
            base_demand = float(data.get('base_demand', 100.0))
            elasticity = float(data.get('elasticity', -1.5))
        except (TypeError, ValueError) as exc:
            return jsonify({'success': False, 'error': f'参数无效: {exc}'}), 400
        model_type = data.get('model_type', 'linear')

        result = run_optimization(cost, base_price, base_demand, elasticity, model_type)

        return jsonify({
            'success': True,
            'data': result['results'],
            'optimal': result['optimal']
        })
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure.
        logger.exception("Analysis Error: %s", e)
        return jsonify({'success': False, 'error': str(e)}), 500
@app.route('/api/upload', methods=['POST'])
def upload_file():
    """Handle a CSV/Excel upload and run the optimisation for each row.

    The uploaded table must contain the columns ``cost``, ``base_price`` and
    ``base_demand``; ``elasticity`` and ``model_type`` are optional and
    default to -1.5 and ``'linear'``. Rows that cannot be processed are
    logged and skipped.

    Returns:
        A JSON response. On success it carries ``success`` and
        ``batch_results`` (at most 50 entries, each with the row's ``input``
        and its ``optimal`` result). The status is 400 for a missing file,
        an unsupported format, an unreadable file or missing columns, and
        500 for any other failure.
    """
    max_returned = 50  # cap on returned rows to keep the response small
    filepath = None
    try:
        if 'file' not in request.files:
            return jsonify({'success': False, 'error': '未找到上传文件'}), 400

        file = request.files['file']
        if not file.filename:
            return jsonify({'success': False, 'error': '未选择文件'}), 400

        # Take the extension from the original name and compare it
        # case-insensitively. secure_filename() drops non-ASCII characters,
        # so a name such as "数据.csv" would be reduced to "csv" and then be
        # rejected as an unsupported format.
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in ('.csv', '.xls', '.xlsx'):
            return jsonify({'success': False, 'error': '不支持的文件格式 (仅支持 .csv, .xls, .xlsx)'}), 400

        # Store under a generated name: the client-supplied name never
        # reaches the filesystem and concurrent uploads cannot overwrite
        # each other.
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], f'{uuid.uuid4().hex}{ext}')
        file.save(filepath)

        # Read the file.
        try:
            if ext == '.csv':
                df = pd.read_csv(filepath)
            else:
                df = pd.read_excel(filepath)
        except Exception as e:
            return jsonify({'success': False, 'error': f'文件读取错误: {str(e)}'}), 400

        # Check the required columns.
        required_cols = ['cost', 'base_price', 'base_demand']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            return jsonify({'success': False, 'error': f'Missing columns: {", ".join(missing_cols)}'}), 400

        # Batch processing. Only the first `max_returned` successful rows
        # are ever returned, so stop once that many have been collected.
        results = []
        for _, row in df.iterrows():
            if len(results) >= max_returned:
                break
            try:
                cost = float(row['cost'])
                base_price = float(row['base_price'])
                base_demand = float(row['base_demand'])

                # A blank cell in an optional column is read as NaN; treat
                # it like a missing column and fall back to the default.
                raw_elasticity = row.get('elasticity', -1.5)
                elasticity = -1.5 if pd.isna(raw_elasticity) else float(raw_elasticity)
                raw_model_type = row.get('model_type', 'linear')
                model_type = 'linear' if pd.isna(raw_model_type) else str(raw_model_type)

                # NaN/inf inputs would propagate into the response, where
                # they are not valid JSON; skip such rows.
                if not np.all(np.isfinite([cost, base_price, base_demand, elasticity])):
                    logger.warning("Row processing error: non-finite input value")
                    continue

                opt_res = run_optimization(cost, base_price, base_demand, elasticity, model_type)

                results.append({
                    'input': {
                        'cost': cost,
                        'base_price': base_price,
                        'base_demand': base_demand,
                        'elasticity': elasticity
                    },
                    'optimal': opt_res['optimal']
                })
            except Exception as row_err:
                logger.warning(f"Row processing error: {row_err}")
                continue

        return jsonify({
            'success': True,
            'batch_results': results
        })

    except Exception as e:
        logger.exception("Upload Error: %s", e)
        return jsonify({'success': False, 'error': f"服务器内部错误: {str(e)}"}), 500
    finally:
        # Clean up the stored upload on every exit path, including the
        # early 400 returns above.
        if filepath is not None:
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass  # saving failed before the file was created
            except OSError as cleanup_err:
                logger.warning("Failed to remove uploaded file %s: %s", filepath, cleanup_err)
if __name__ == '__main__':
    # Listen on every interface; the port comes from the PORT environment
    # variable and falls back to 7860.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(port=listen_port, host='0.0.0.0')