Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from datetime import datetime, timedelta | |
| from flask import Flask, request, render_template, jsonify, Response, send_from_directory, session | |
| from flask_cors import CORS | |
| import uuid | |
| import debug | |
| import logging | |
| import time | |
| import copy | |
| import google.generativeai as genai | |
| from google.generativeai.types import HarmCategory, HarmBlockThreshold | |
| from werkzeug.utils import secure_filename | |
app = Flask(__name__)
CORS(app)
# The session-signing key should not live in source control: read it from the
# environment. The previous hard-coded value is kept only as a fallback so
# existing deployments (and their session cookies) keep working unchanged.
app.secret_key = os.environ.get('SECRET_KEY', '112edsafbfgd wrfgsdbdawe')

CHAT_HISTORY_DIR = '/tmp/chat_histories'
PRESETS_DIR = '/tmp/presets'
UPLOAD_FOLDER = '/tmp/uploads'
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(PRESETS_DIR, exist_ok=True)
os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config.update(
    SESSION_COOKIE_SECURE=True,
    PERMANENT_SESSION_LIFETIME=timedelta(days=31),  # lifetime of permanent sessions
    SESSION_COOKIE_HTTPONLY=True,
)
def make_session_permanent():
    """Flag the current Flask session as permanent.

    A permanent session cookie uses PERMANENT_SESSION_LIFETIME from the app
    config instead of ending with the browser session.

    NOTE(review): no registration (e.g. ``@app.before_request``) is visible
    for this function in this copy of the file; confirm it is invoked.
    """
    session.permanent = True
# Module logger; the root logger is configured to emit DEBUG and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
# Safety settings passed to every GenerativeModel: each listed harm category
# uses the BLOCK_NONE threshold.
safety_settings = {
    category: HarmBlockThreshold.BLOCK_NONE
    for category in (
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
}
# Built-in presets as (id, display name, system prompt) rows. They are
# expanded below into the {'id', 'name', 'content'} dicts the app reads.
_PRESET_ROWS = (
    # "Default"
    ('default', '默认', "You are a helpful assistant."),
    # "Article polishing"
    ('article_editing', '文章润色', "As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations."),
    # "Translation"
    ('translation', '翻译', "我希望你能担任翻译、拼写校对和修辞改进的角色。我会用任何语言和你交流,你会识别语言,将其翻译并用更为优美和精炼的中文或英文回答我。请将我简单的词汇和句子替换成更为优美和高雅的表达方式,确保意思不变,但使其更具文学性。请仅回答更正和改进的部分,不要写解释。"),
    # "Debate"
    ('discussion', '辩论', "I want you to act as a debater. I will provide you with some topics related to current events and your task is to research both sides of the debates, present valid arguments for each side, refute opposing points of view, and draw persuasive conclusions based on evidence. Your goal is to help people come away from the discussion with increased knowledge and insight into the topic at hand. The entire conversation and instructions should be provided in Chinese."),
    # "Regex generator"
    ('regex_generator', '正则表达式生成器', "I want you to act as a regex generator. Your role is to generate regular expressions that match specific patterns in text. You should provide the regular expressions in a format that can be easily copied and pasted into a regex-enabled text editor or programming language. Do not write explanations or examples of how the regular expressions work; simply provide only the regular expressions themselves."),
    # "Front-end development"
    ('front_end_developer', '前端开发', "I want you to act as a Senior Frontend developer. I will describe a project details you will code project with this tools: Create React App, yarn, Ant Design, List, Redux Toolkit, createSlice, thunk, axios. You should merge files in single index.js file and nothing else. Do not write explanations."),
)

PREDEFINED_PRESETS = [
    {'id': preset_id, 'name': display_name, 'content': prompt}
    for preset_id, display_name, prompt in _PRESET_ROWS
]
class APIKeyManager:
    """Round-robin pool of API keys with a per-key daily usage cap.

    Each call to ``get_available_key`` counts as one use of the key it
    returns. All counters reset on the first call after the local calendar
    date changes.
    """

    def __init__(self, api_keys=None, daily_limit=200):
        """Create the key pool.

        Args:
            api_keys: Iterable of key strings. When None (the default), keys
                are read from the comma-separated ``API_KEYS`` environment
                variable. Surrounding whitespace and empty entries are
                dropped and duplicates collapsed. A missing or empty variable
                therefore yields an empty pool instead of crashing at import.
            daily_limit: Maximum number of uses handed out per key per day.
        """
        if api_keys is None:
            api_keys = os.environ.get('API_KEYS', '').split(',')
        # dict.fromkeys de-duplicates while preserving the configured order.
        self.api_keys = list(dict.fromkeys(
            key for key in (raw.strip() for raw in api_keys) if key
        ))
        self.daily_limit = daily_limit
        self.daily_uses = {key: 0 for key in self.api_keys}
        self.last_reset = datetime.now().date()
        self.current_index = 0  # position of the next key to try

    def reset_daily_uses(self):
        """Zero every usage counter once the local date has advanced."""
        today = datetime.now().date()
        if today > self.last_reset:
            self.daily_uses = {key: 0 for key in self.api_keys}
            self.last_reset = today

    def get_available_key(self):
        """Return the next key still under its daily limit.

        Returns:
            A key string (its use counter is incremented), or None when the
            pool is empty or every key has reached ``daily_limit`` today.
        """
        self.reset_daily_uses()
        key_count = len(self.api_keys)
        # One full pass is enough: if no key qualifies, all are exhausted.
        # (An empty pool skips the loop, so the modulo never divides by 0.)
        for _ in range(key_count):
            key = self.api_keys[self.current_index]
            self.current_index = (self.current_index + 1) % key_count
            if self.daily_uses[key] < self.daily_limit:
                self.daily_uses[key] += 1
                return key
        return None
# Process-wide key pool shared by every request handler.
key_manager = APIKeyManager()
# Key used by the Gemini calls in this module. It is fetched (and counted as
# one use) at import time, and is None when no key is available.
current_api_key = key_manager.get_available_key()
def make_session_permanent():
    """Flag the current Flask session as permanent."""
    # NOTE(review): this re-defines an identical make_session_permanent
    # declared earlier in the module. Because it comes later, this is the
    # definition the name ends up bound to; one of the two is redundant.
    session.permanent = True
def init_session():
    """Echo the client's session id, or mint a new one.

    Reads ``session_id`` from the JSON request body. A missing, empty or
    non-JSON body is treated like an absent id instead of raising.

    Returns:
        A JSON response ``{"session_id": <id>}``.
    """
    payload = request.get_json(silent=True)
    client_session_id = payload.get('session_id') if isinstance(payload, dict) else None
    return jsonify({'session_id': client_session_id or str(uuid.uuid4())})
def get_or_create_session_id():
    """Return ``session_id`` from the JSON request body, or a new UUID4 string.

    A missing, empty or non-JSON body yields a fresh id instead of raising
    AttributeError on ``None.get``.
    """
    data = request.get_json(silent=True)
    session_id = data.get('session_id') if isinstance(data, dict) else None
    return session_id or str(uuid.uuid4())
def get_chat_history_path(session_id, base_dir=None):
    """Return the JSON file path that stores the history for *session_id*.

    The id comes straight from client requests, so every character outside
    ``[A-Za-z0-9_-]`` is replaced with ``_``. Otherwise an id such as
    ``../presets/<id>`` would let callers read, overwrite or delete files
    outside the history folder. UUID-style ids are returned unchanged.

    Args:
        session_id: Client-supplied session id (converted with ``str``).
        base_dir: Directory to place the file in; defaults to
            CHAT_HISTORY_DIR.
    """
    import re  # local import keeps this fix self-contained

    if base_dir is None:
        base_dir = CHAT_HISTORY_DIR
    safe_id = re.sub(r'[^A-Za-z0-9_-]', '_', str(session_id))
    return os.path.join(base_dir, f'{safe_id}.json')
def load_chat_history(session_id):
    """Load the stored chat history for *session_id*.

    Returns:
        The decoded JSON content of the session's history file (normally the
        message list written by save_chat_history), or ``[]`` when the file
        does not exist or cannot be read/decoded. Failures are logged, never
        raised, so a bad file cannot break a request.
    """
    try:
        history_path = get_chat_history_path(session_id)
        with open(history_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return []  # new session: nothing stored yet
    except Exception as e:  # deliberate best effort (corrupt file, bad id, ...)
        logger.error(f"Error loading chat history: {e}")
        return []
def save_chat_history(session_id, history):
    """Persist *history* for *session_id* as pretty-printed UTF-8 JSON.

    The data goes to a uniquely named temporary file first and is then moved
    into place with os.replace. A failure mid-write (e.g. an unserializable
    value) therefore leaves the previous history intact instead of a
    truncated file. Errors are logged, not raised.
    """
    history_path = get_chat_history_path(session_id)
    tmp_path = f'{history_path}.{uuid.uuid4().hex}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, history_path)
        debug.log_message(f"Saved chat history to {history_path}")
    except Exception as e:
        logger.error(f"Error saving chat history: {e}")
        try:
            os.remove(tmp_path)  # drop the partial temp file, if any
        except OSError:
            pass
def upload_to_gemini(path, mime_type=None):
    """Upload a local file through the Gemini Files API.

    The client is configured with the module's current API key first.

    Args:
        path: Local path of the file to upload.
        mime_type: Optional MIME type, passed through to genai.upload_file.

    Returns:
        The file object returned by genai.upload_file.
    """
    genai.configure(api_key=current_api_key)
    uploaded = genai.upload_file(path, mime_type=mime_type)
    debug.log_message(f"Uploaded file '{uploaded.display_name}' as: {uploaded.uri}")
    return uploaded
def wait_for_files_active(files, poll_interval=10, timeout=None):
    """Block until every uploaded Gemini file has finished processing.

    Args:
        files: Objects with a ``name`` attribute, e.g. results of
            upload_to_gemini.
        poll_interval: Seconds to sleep between status checks.
        timeout: Optional overall limit in seconds; None (the default) waits
            indefinitely, as before.

    Raises:
        TimeoutError: if ``timeout`` elapses while a file is still PROCESSING.
        Exception: if a file ends in any state other than ACTIVE.
    """
    genai.configure(api_key=current_api_key)
    debug.log_message("Waiting for file processing...")
    deadline = None if timeout is None else time.monotonic() + timeout
    for pending in files:
        name = pending.name
        remote = genai.get_file(name)
        while remote.state.name == "PROCESSING":
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"Timed out waiting for file {name} to process")
            # debug.log_message's signature is not visible here. The original
            # passed print-only keywords (end="", flush=True), which raise
            # TypeError unless it forwards them, so only the text is passed.
            debug.log_message(".")
            time.sleep(poll_interval)
            remote = genai.get_file(name)
        if remote.state.name != "ACTIVE":
            raise Exception(f"File {remote.name} failed to process")
        # Logged per file. The original logged once after the loop, which
        # covered only the last file and raised UnboundLocalError when
        # *files* was empty.
        debug.log_message(f"File {remote.name} processed successfully")
def get_upload_file():
    """Accept a multipart upload, then push it to the Gemini Files API.

    Expects the file in the ``file`` form field. The file is saved under
    UPLOAD_FOLDER only for the duration of the Gemini upload, then removed.

    Returns:
        JSON ``{"success": True, "filename": ..., "gemini_uri": ...}`` on
        success; ``({"error": ...}, 400)`` when no file was sent;
        ``({"error": ...}, 500)`` when saving, uploading or processing fails.
    """
    # Log only the key's last characters: the full key is a secret.
    debug.log_message(f"upload request with api_key: ...{(current_api_key or '')[-4:]}")
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if not file.filename:  # covers both '' and a missing (None) filename
        return jsonify({'error': 'No selected file'}), 400
    # secure_filename() can return '' (e.g. for names made only of non-ASCII
    # characters). Fall back to a random name so we never save onto the
    # folder path itself.
    filename = secure_filename(file.filename) or uuid.uuid4().hex
    # Random prefix so concurrent uploads with the same name don't collide.
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f'{uuid.uuid4().hex}_{filename}')
    mime_type = file.content_type
    try:
        file.save(filepath)
        gemini_file = upload_to_gemini(filepath, mime_type=mime_type)
        wait_for_files_active([gemini_file])
        return jsonify({
            'success': True,
            'filename': filename,
            'gemini_uri': gemini_file.uri
        })
    except Exception as e:
        logger.error(f"Error uploading file: {str(e)}")
        return jsonify({'error': str(e)}), 500
    finally:
        # The local copy is only needed for the upload; don't let /tmp fill up.
        try:
            os.remove(filepath)
        except OSError:
            pass
def get_presets():
    """Return the built-in presets followed by user-created ones as JSON.

    User presets are read from the ``*.json`` files in PRESETS_DIR, in sorted
    filename order so the listing is stable. A file that cannot be read or
    decoded is logged and skipped, so one bad file does not break the list.
    """
    custom_presets = []
    for filename in sorted(os.listdir(PRESETS_DIR)):
        if not filename.endswith('.json'):
            continue
        try:
            with open(os.path.join(PRESETS_DIR, filename), 'r', encoding='utf-8') as f:
                custom_presets.append(json.load(f))
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
            logger.error(f"Error loading preset {filename}: {e}")
    return jsonify(PREDEFINED_PRESETS + custom_presets)
def add_preset():
    """Create a user preset and store it as ``<id>.json`` in PRESETS_DIR.

    Expects JSON ``{"name": str, "content": str}``; the id is a new UUID4.

    Returns:
        ``{"status": "success", "id": <id>}``, or a 400 error when the body
        is not a JSON object with string ``name`` and ``content`` fields.
    """
    data = request.get_json(silent=True)
    # Non-string values would later be used as a system prompt, so reject them.
    if not isinstance(data, dict) or not all(
        isinstance(data.get(field), str) for field in ('name', 'content')
    ):
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    preset_id = str(uuid.uuid4())
    preset = {
        'id': preset_id,
        'name': data['name'],
        'content': data['content']
    }
    with open(os.path.join(PRESETS_DIR, f'{preset_id}.json'), 'w', encoding='utf-8') as f:
        json.dump(preset, f, ensure_ascii=False, indent=2)
    return jsonify({'status': 'success', 'id': preset_id})
def delete_preset():
    """Delete a user-created preset by id.

    Expects JSON ``{"id": <preset id>}``.

    Returns:
        ``{"status": "success"}`` after removing the preset file; 400 for a
        malformed body; 403 for a built-in preset id; 404 when no such user
        preset exists (including ids that are not plain file-name tokens).
    """
    import re  # local import keeps this fix self-contained

    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'id' not in data:
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    preset_id = data['id']
    # Built-in presets live in code, not on disk.
    if any(preset['id'] == preset_id for preset in PREDEFINED_PRESETS):
        return jsonify({'status': 'error', 'error': 'Cannot delete predefined preset'}), 403
    # User presets are created as <uuid4>.json. Rejecting ids with characters
    # outside [A-Za-z0-9_-] stops values like '../chat_histories/<id>' from
    # deleting files outside PRESETS_DIR.
    if isinstance(preset_id, str) and re.fullmatch(r'[A-Za-z0-9_-]+', preset_id):
        preset_file = os.path.join(PRESETS_DIR, f'{preset_id}.json')
        if os.path.exists(preset_file):
            os.remove(preset_file)
            return jsonify({'status': 'success'})
    return jsonify({'status': 'error', 'error': 'Preset not found'}), 404
def index():
    """Render and return the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
def _resolve_system_instruction(preset_id):
    """Return the system prompt text for *preset_id*.

    Built-in presets are checked first, then a user preset file in
    PRESETS_DIR. Anything else falls back to the 'default' preset.
    """
    import re  # local import keeps this block self-contained

    for preset in PREDEFINED_PRESETS:
        if preset['id'] == preset_id:
            return preset['content']
    # Only plain ids may name a file in PRESETS_DIR. Without this check an id
    # such as '../chat_histories/<id>' would read a file outside that folder.
    if isinstance(preset_id, str) and re.fullmatch(r'[A-Za-z0-9_-]+', preset_id):
        preset_file = os.path.join(PRESETS_DIR, f'{preset_id}.json')
        if os.path.exists(preset_file):
            with open(preset_file, 'r', encoding='utf-8') as f:
                return json.load(f)['content']
    default_preset = next(p for p in PREDEFINED_PRESETS if p['id'] == 'default')
    return default_preset['content']


def chat():
    """Stream a Gemini reply for one user turn and persist the conversation.

    Expects a JSON body with:
      * ``userMessage``: an object whose ``parts`` list ends with the user's
        text. Earlier parts are file references such as ``{"uri": ...}``.
      * ``preset``: a preset id.
      * ``session_id``: the conversation id.

    Returns:
        A ``text/event-stream`` response. Reply text is yielded raw, chunk by
        chunk; errors are yielded as ``data: {"error": ...}`` events. A 400
        JSON error is returned up front for malformed requests.
    """
    data = request.get_json(silent=True)
    if (not isinstance(data, dict) or 'userMessage' not in data
            or 'preset' not in data or 'session_id' not in data):
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    userMessage = data['userMessage']
    preset_id = data['preset']
    session_id = data['session_id']

    # Validate the message shape before streaming starts. Inside the generator
    # an empty/missing 'parts' list raised IndexError and cut the stream off
    # without any error event.
    message_parts = userMessage.get("parts", []) if isinstance(userMessage, dict) else None
    if not isinstance(message_parts, list) or not message_parts:
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    userTextMessage = message_parts[-1]
    processing_files = message_parts[:-1]  # earlier parts reference uploaded files

    chat_history = load_chat_history(session_id)
    # Keep only the 10 most recent entries; the trimmed list is what gets saved.
    if len(chat_history) > 10:
        chat_history = chat_history[-10:]
    system_instruction = _resolve_system_instruction(preset_id)

    def generate():
        global current_api_key
        # Once every key had been exhausted, current_api_key stayed None and
        # was never refreshed, so chat stayed broken even after the daily
        # reset. Ask the manager again before giving up.
        if not current_api_key:
            current_api_key = key_manager.get_available_key()
        # Log only the key's last characters: the full key is a secret.
        debug.log_message(f"chat_api_key: ...{(current_api_key or '')[-4:]}")
        if not current_api_key:
            yield "data: {\"error\": \"No available API keys.\"}\n\n"
            return
        try:
            genai.configure(api_key=current_api_key)
            if processing_files:
                chat_history.append({"role": "user", "parts": processing_files})
            # Send a copy in which file references are resolved to Gemini file
            # objects; the stored history keeps the plain {"uri": ...} dicts.
            send_chat_history = copy.deepcopy(chat_history)
            for message in send_chat_history:
                if message["role"] == "user":
                    message["parts"] = [
                        genai.get_file(part["uri"].split("/")[-1])
                        if isinstance(part, dict) and "uri" in part
                        else part
                        for part in message["parts"]
                    ]
            generation_config = {
                "temperature": 1,
                "top_p": 0.95,
                "top_k": 40,
                "max_output_tokens": 8192,
                "response_mime_type": "text/plain",
            }
            model = genai.GenerativeModel(
                model_name="gemini-exp-1206",
                generation_config=generation_config,
                system_instruction=system_instruction,
                safety_settings=safety_settings,
            )
            model_message = {"role": "model", "parts": [""]}
            chat_session = model.start_chat(history=send_chat_history)
            response = chat_session.send_message(userTextMessage, stream=True)
            for chunk in response:
                if chunk.text:
                    model_message['parts'][0] += chunk.text
                    yield chunk.text
            # Record the user turn: attach the text to the files-only entry
            # appended above, or add a standalone text entry.
            if processing_files:
                chat_history[-1]["parts"].append(userTextMessage)
            else:
                chat_history.append({"role": "user", "parts": [userTextMessage]})
            chat_history.append(model_message)
            save_chat_history(session_id, chat_history)
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Rotate to the next key after every attempt. Doing it in
            # `finally` also covers a client that disconnects mid-stream
            # (GeneratorExit).
            current_api_key = key_manager.get_available_key()
            debug.log_message(f"switched to api_key: ...{(current_api_key or '')[-4:]}")

    return Response(generate(), mimetype='text/event-stream')
def get_history():
    """Return the stored chat history for the ``session_id`` query parameter as JSON."""
    return jsonify(load_chat_history(request.args.get('session_id')))
def clear_history():
    """Delete the stored chat history for the ``session_id`` query parameter.

    Returns:
        ``{"status": "success"}`` once no history file remains, including
        when none existed. Otherwise ``{"status": "error", "message": ...}``.
    """
    session_id = request.args.get('session_id')
    try:
        os.remove(get_chat_history_path(session_id))
    except FileNotFoundError:
        pass  # nothing stored for this session: it is already clear
    except Exception as e:
        logger.error(f"Error clearing chat history: {e}")
        return jsonify({"status": "error", "message": str(e)})
    return jsonify({"status": "success"})
def switch_models():
    """Remember the requested model name in the Flask session.

    Expects JSON ``{"model": <name>}``.

    Returns:
        ``{"status": "success", "model": <name>}``, or a 400 error when the
        body is empty or lacks ``model``.

    NOTE(review): chat() in this module passes a fixed model_name to
    genai.GenerativeModel and never reads session['model'], so this setting
    currently does not affect replies.
    """
    payload = request.get_json()
    if payload and 'model' in payload:
        chosen_model = payload['model']
        session['model'] = chosen_model
        return jsonify({'status': 'success', 'model': chosen_model})
    return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
if __name__ == '__main__':
    # debug=True turns on Werkzeug's interactive debugger and the reloader.
    # Bound to 0.0.0.0, that debugger (a code-execution surface, even if
    # PIN-protected) is reachable from the network, so it is opt-in via
    # FLASK_DEBUG=1.
    app.run(
        debug=os.environ.get('FLASK_DEBUG') == '1',
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
    )