Custom-Gemini / app.py
Moonfanz's picture
Upload app.py
0d41f3c verified
raw
history blame
14.9 kB
import os
import json
from datetime import datetime
from flask import Flask, request, render_template, jsonify, Response, send_from_directory, session
from flask_cors import CORS
import uuid
import debug
import logging
import time
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from werkzeug.utils import secure_filename
app = Flask(__name__)
CORS(app)
# os.urandom() alone yields a new key on every start (and a different one in
# each worker process, if run under a multi-process server), which invalidates
# every session cookie and with it the session -> chat-history mapping.
# Prefer a stable key from the environment; keep the old random fallback.
app.secret_key = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)
CHAT_HISTORY_DIR = '/tmp/chat_histories'
PRESETS_DIR = '/tmp/presets'
os.makedirs(PRESETS_DIR, exist_ok=True)
os.makedirs(CHAT_HISTORY_DIR, exist_ok=True)
UPLOAD_FOLDER = '/tmp/uploads'
# exist_ok avoids the check-then-create race of an exists()/makedirs() pair.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Passed to every GenerativeModel: blocking is disabled for all four
# listed harm categories.
safety_settings = {
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
# Built-in system-prompt presets offered to every user. Each entry is a dict
# with keys (in this order) 'id', 'name' (shown in the UI) and 'content'
# (the system instruction sent to the model). They cannot be deleted.
PREDEFINED_PRESETS = [
    {'id': preset_id, 'name': display_name, 'content': instruction}
    for preset_id, display_name, instruction in (
        (
            'default',
            '默认',
            "You are a helpful assistant.",
        ),
        (
            'article_editing',
            '文章润色',
            "As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations.",
        ),
        (
            'translation',
            '翻译',
            "我希望你能担任翻译、拼写校对和修辞改进的角色。我会用任何语言和你交流,你会识别语言,将其翻译并用更为优美和精炼的中文或英文回答我。请将我简单的词汇和句子替换成更为优美和高雅的表达方式,确保意思不变,但使其更具文学性。请仅回答更正和改进的部分,不要写解释。",
        ),
        (
            'discussion',
            '辩论',
            "I want you to act as a debater. I will provide you with some topics related to current events and your task is to research both sides of the debates, present valid arguments for each side, refute opposing points of view, and draw persuasive conclusions based on evidence. Your goal is to help people come away from the discussion with increased knowledge and insight into the topic at hand. The entire conversation and instructions should be provided in Chinese.",
        ),
        (
            'regex_generator',
            '正则表达式生成器',
            "I want you to act as a regex generator. Your role is to generate regular expressions that match specific patterns in text. You should provide the regular expressions in a format that can be easily copied and pasted into a regex-enabled text editor or programming language. Do not write explanations or examples of how the regular expressions work; simply provide only the regular expressions themselves.",
        ),
        (
            'front_end_developer',
            '前端开发',
            "I want you to act as a Senior Frontend developer. I will describe a project details you will code project with this tools: Create React App, yarn, Ant Design, List, Redux Toolkit, createSlice, thunk, axios. You should merge files in single index.js file and nothing else. Do not write explanations.",
        ),
    )
]
class APIKeyManager:
    """Hand out Gemini API keys under a per-key daily usage cap.

    Keys are tried in the configured order: the first key whose count for the
    current local date is below ``daily_limit`` is returned (and its count
    incremented). Counts reset when the local date advances.
    """

    def __init__(self, api_keys=None, daily_limit=200):
        """Create the manager.

        Args:
            api_keys: comma-separated string or iterable of key strings. When
                omitted, read from the ``GEMINI_API_KEYS`` environment
                variable, falling back to ``GOOGLE_API_KEY``.
            daily_limit: how many times each key may be handed out per day.
        """
        if api_keys is None:
            # Keys must not live in source control. The key that used to be
            # hard-coded here was published with this file and should be
            # revoked.
            api_keys = (os.environ.get('GEMINI_API_KEYS')
                        or os.environ.get('GOOGLE_API_KEY', ''))
        if isinstance(api_keys, str):
            api_keys = api_keys.split(',')
        # Strip whitespace and drop empty entries: ''.split(',') yields [''].
        self.api_keys = [key.strip() for key in api_keys if key and key.strip()]
        self.daily_limit = daily_limit
        self.daily_uses = {key: 0 for key in self.api_keys}
        self.last_reset = datetime.now().date()

    def reset_daily_uses(self):
        """Zero all usage counters once the local date has moved past the last reset."""
        today = datetime.now().date()
        if today > self.last_reset:
            self.daily_uses = {key: 0 for key in self.api_keys}
            self.last_reset = today

    def get_available_key(self):
        """Return the first key still under its daily limit, counting this use.

        Returns None when every key is exhausted or no keys are configured.
        """
        self.reset_daily_uses()
        for key in self.api_keys:
            if self.daily_uses[key] < self.daily_limit:
                self.daily_uses[key] += 1
                return key
        return None
key_manager = APIKeyManager()
# Configure the google.generativeai client once at import time. Fetching the
# key goes through get_available_key(), so it counts as one daily use.
_initial_key = key_manager.get_available_key()
genai.configure(api_key=_initial_key)
def get_or_create_session_id():
    """Return the id stored in the Flask session, minting a UUID4 string on first use."""
    if 'session_id' in session:
        return session['session_id']
    session['session_id'] = str(uuid.uuid4())
    return session['session_id']
def get_chat_history_path(session_id):
    """Return the path of the JSON chat-history file for *session_id*."""
    return os.path.join(CHAT_HISTORY_DIR, f'{session_id}.json')


def load_chat_history(session_id):
    """Load the saved chat history for *session_id*.

    Returns the stored list of ``{"role": ..., "parts": [...]}`` messages, or
    an empty list when nothing has been saved yet or the file cannot be read
    or parsed (the error is logged).
    """
    history_path = get_chat_history_path(session_id)
    try:
        with open(history_path, 'r', encoding='utf-8') as f:
            history = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        # ValueError covers malformed JSON and undecodable bytes.
        logger.error(f"Error loading chat history: {e}")
        return []
    # Callers append to the result, so anything but a list is treated as corrupt.
    if not isinstance(history, list):
        logger.error(f"Error loading chat history: expected a list in {history_path}")
        return []
    return history
def save_chat_history(session_id, history):
    """Persist *history* as JSON for *session_id*; errors are logged, not raised.

    The data is written to a uniquely named temporary file next to the target
    and then swapped in with os.replace(). If serialization fails part-way
    (e.g. a non-JSON-serializable object in *history*), the previous history
    file is left intact instead of being truncated.
    """
    history_path = get_chat_history_path(session_id)
    # Unique suffix so two concurrent saves for one session don't share a temp file.
    tmp_path = f'{history_path}.{uuid.uuid4().hex}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(history, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, history_path)
        debug.log_message(f"Saved chat history to {history_path}")
    except Exception as e:
        logger.error(f"Error saving chat history: {e}")
        # Best-effort removal of the partial temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def upload_to_gemini(path, mime_type=None):
    """Upload the local file at *path* via ``genai.upload_file`` and return the resulting handle."""
    uploaded = genai.upload_file(path, mime_type=mime_type)
    logger.debug(f"Uploaded file '{uploaded.display_name}' as: {uploaded.uri}")
    return uploaded
def wait_for_files_active(files, poll_interval=10):
    """Block until every given Gemini file has left the PROCESSING state.

    Args:
        files: file handles (e.g. from upload_to_gemini); only ``.name`` is read.
        poll_interval: seconds to sleep between status checks.

    Raises:
        RuntimeError: if a file ends in any state other than ACTIVE.
    """
    logger.debug("Waiting for file processing...")
    for name in [f.name for f in files]:
        remote = genai.get_file(name)
        while remote.state.name == "PROCESSING":
            # Logger.debug() has no print()-style ``end``/``flush`` kwargs;
            # passing them raised TypeError whenever a file was still processing.
            logger.debug(f"File {name} still processing...")
            time.sleep(poll_interval)
            remote = genai.get_file(name)
        if remote.state.name != "ACTIVE":
            raise RuntimeError(f"File {remote.name} failed to process")
    logger.debug("...all files ready")
@app.route('/upload', methods=['POST'])
def get_upload_file():
    """Receive a multipart ``file`` upload and forward it to the Gemini Files API.

    Returns JSON ``{'success': True, 'filename': <sanitized name>,
    'gemini_uri': <uri>}`` once Gemini reports the file ACTIVE, 400 when no
    file was sent, and 500 with the error text if saving or uploading fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    # secure_filename() returns '' when no ASCII-safe characters remain
    # (e.g. a purely Chinese name). Fall back to a generated name instead of
    # trying to save onto the upload folder itself.
    filename = secure_filename(file.filename) or f'upload-{uuid.uuid4().hex}'
    # A per-request directory keeps concurrent uploads of the same name from
    # overwriting each other, while preserving the file's base name.
    upload_dir = os.path.join(app.config['UPLOAD_FOLDER'], uuid.uuid4().hex)
    filepath = os.path.join(upload_dir, filename)
    try:
        os.makedirs(upload_dir, exist_ok=True)
        file.save(filepath)
        gemini_file = upload_to_gemini(filepath, mime_type=file.content_type)
        wait_for_files_active([gemini_file])
        return jsonify({
            'success': True,
            'filename': filename,
            'gemini_uri': gemini_file.uri,
        })
    except Exception as e:
        logger.error(f"Error uploading file: {str(e)}")
        return jsonify({'error': str(e)}), 500
    finally:
        # The local copy is only needed for the upload; don't let /tmp fill up.
        try:
            os.remove(filepath)
        except OSError:
            pass
        try:
            os.rmdir(upload_dir)
        except OSError:
            pass
@app.route('/presets', methods=['GET'])
def get_presets():
    """Return the predefined presets followed by the user presets stored in PRESETS_DIR.

    User presets are read from ``*.json`` files in file-name order.
    Unreadable or malformed files are logged and skipped.
    """
    presets = []
    # os.listdir() order is arbitrary; sort for a stable response.
    for filename in sorted(os.listdir(PRESETS_DIR)):
        if not filename.endswith('.json'):
            continue
        path = os.path.join(PRESETS_DIR, filename)
        try:
            with open(path, 'r', encoding='utf-8') as f:
                presets.append(json.load(f))
        except (OSError, ValueError) as e:
            # One bad file should not turn the whole listing into a 500.
            logger.error(f"Error loading preset {path}: {e}")
    return jsonify(PREDEFINED_PRESETS + presets)
@app.route('/add_preset', methods=['POST'])
def add_preset():
    """Create a user preset from JSON ``{"name": str, "content": str}``.

    The preset is written to ``<uuid4>.json`` in PRESETS_DIR, and the new id
    is returned. Responds 400 when the body is not an object or name/content
    are missing or not strings.
    """
    data = request.get_json()
    if (not isinstance(data, dict)
            or not isinstance(data.get('name'), str)
            or not isinstance(data.get('content'), str)):
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    preset_id = str(uuid.uuid4())
    preset = {
        'id': preset_id,
        'name': data['name'],
        'content': data['content'],
    }
    with open(os.path.join(PRESETS_DIR, f'{preset_id}.json'), 'w', encoding='utf-8') as f:
        json.dump(preset, f, ensure_ascii=False, indent=2)
    return jsonify({'status': 'success', 'id': preset_id})
@app.route('/delete_preset', methods=['POST'])
def delete_preset():
    """Delete a user-created preset by id.

    Responds 400 for a malformed body or id, 403 for a predefined preset,
    404 when no such preset file exists, otherwise ``{'status': 'success'}``.
    """
    data = request.get_json()
    if not isinstance(data, dict) or 'id' not in data:
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    preset_id = data['id']
    # Predefined presets exist only in memory and may not be deleted.
    if any(preset['id'] == preset_id for preset in PREDEFINED_PRESETS):
        return jsonify({'status': 'error', 'error': 'Cannot delete predefined preset'}), 403
    # The id is client-supplied and becomes part of a path. Accept only a bare
    # file name; otherwise e.g. '../chat_histories/<session>' would delete
    # another user's chat history.
    if not isinstance(preset_id, str) or not preset_id or os.path.basename(preset_id) != preset_id:
        return jsonify({'status': 'error', 'error': 'Invalid preset id'}), 400
    preset_file = os.path.join(PRESETS_DIR, f'{preset_id}.json')
    try:
        os.remove(preset_file)
    except FileNotFoundError:
        return jsonify({'status': 'error', 'error': 'Preset not found'}), 404
    return jsonify({'status': 'success'})
@app.route('/upload', methods=['POST'])
def upload_file():
    """Second view registered on POST /upload; delegates to get_upload_file().

    This view used to be a verbatim copy of get_upload_file() on the same URL
    rule. Only one of the two can ever serve /upload, so the copies could
    silently drift apart. Delegating keeps the 'upload_file' endpoint (and
    any url_for('upload_file')) working with a single implementation.
    """
    return get_upload_file()
@app.route('/')
def index():
    """Render and return the ``index.html`` template for the site root."""
    template_name = 'index.html'
    return render_template(template_name)
def _resolve_preset(preset_id):
    """Return ``(system_instruction, preset_name)`` for *preset_id*.

    Lookup order: PREDEFINED_PRESETS, then ``<preset_id>.json`` in
    PRESETS_DIR, then the predefined 'default' preset.
    """
    for preset in PREDEFINED_PRESETS:
        if preset['id'] == preset_id:
            return preset['content'], preset['name']
    # The id is client-supplied: only a bare file name may be turned into a path.
    if isinstance(preset_id, str) and preset_id and os.path.basename(preset_id) == preset_id:
        preset_file = os.path.join(PRESETS_DIR, f'{preset_id}.json')
        try:
            with open(preset_file, 'r', encoding='utf-8') as f:
                preset_data = json.load(f)
            return preset_data['content'], preset_data['name']
        except FileNotFoundError:
            pass
        except (OSError, ValueError, KeyError, TypeError) as e:
            logger.error(f"Error loading preset {preset_file}: {e}")
    default_preset = next(p for p in PREDEFINED_PRESETS if p['id'] == 'default')
    return default_preset['content'], default_preset['name']


@app.route('/chat', methods=['POST'])
def chat():
    """Stream a Gemini reply to the posted message and append the exchange to the session history.

    Expects JSON ``{"userMessage": {"parts": [...]}, "preset": <preset id>}``.
    The last entry of ``parts`` is sent as the user's text. Earlier entries are
    treated as Gemini file URIs (presumably the ``gemini_uri`` values returned
    by /upload). Reply text is streamed as it arrives; failures are streamed as
    ``data: {"error": ...}``. Responds 400 for a malformed body.
    """
    data = request.get_json()
    if not isinstance(data, dict) or 'userMessage' not in data or 'preset' not in data:
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    userMessage = data['userMessage']
    # Validate before streaming starts: an empty or missing parts list used to
    # raise IndexError inside the generator, after the 200 had been sent.
    parts = userMessage.get('parts') if isinstance(userMessage, dict) else None
    if not isinstance(parts, list) or not parts:
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    logger.debug(f"User message: {userMessage}")
    preset_id = data['preset']
    session_id = get_or_create_session_id()
    chat_history = load_chat_history(session_id)
    # Previously system_instruction started as "" and was tested with
    # "is None", so custom presets and the default fallback were never used.
    system_instruction, preset_name = _resolve_preset(preset_id)
    # /switch_models stores the choice under session['model']. The generator
    # below runs after the request context is gone, so read it here.
    model_name = session.get('model')
    if not isinstance(model_name, str) or not model_name:
        model_name = "gemini-1.5-pro-002"

    def generate():
        key = key_manager.get_available_key()
        if not key:
            yield "data: {\"error\": \"No available API keys.\"}\n\n"
            return
        # genai.configure() is process-wide. Point it at the key that was just
        # counted so key_manager's per-key quota reflects real usage.
        genai.configure(api_key=key)
        logger.debug(f"Parts: {parts}")
        user_text = parts[-1]
        file_uris = parts[:-1]
        generation_config = {
            "temperature": 1,
            "top_p": 0.95,
            "top_k": 40,
            "max_output_tokens": 8192,
            "response_mime_type": "text/plain",
        }
        model_message = {"role": "model", "parts": [""]}
        try:
            # Resolved inside the try so a bad URI is reported to the client.
            files = [genai.get_file(uri.split('/')[-1]) for uri in file_uris]
            model = genai.GenerativeModel(
                model_name=model_name,
                generation_config=generation_config,
                system_instruction=system_instruction,
                safety_settings=safety_settings,
            )
            chat_session = model.start_chat(history=chat_history)
            # Files and text go out together as ONE user turn, rather than the
            # files being injected into history as a separate user turn.
            response = chat_session.send_message(files + [user_text], stream=True)
            for chunk in response:
                if chunk.text:
                    model_message['parts'][0] += chunk.text
                    yield f"{chunk.text}"
            # Persist only JSON-serializable parts: the file URI strings the
            # client sent plus the text. Storing genai File objects made
            # save_chat_history fail on every message with attachments.
            # NOTE(review): on later turns these URIs reach the model as plain
            # text, not as file references.
            chat_history.append({"role": "user", "parts": file_uris + [user_text]})
            chat_history.append(model_message)
            save_chat_history(session_id, chat_history)
            debug.log_prompt(chat_history, preset_name)
            logger.debug(f"Chat history after response: {chat_history}")
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return Response(generate(), mimetype='text/event-stream')
@app.route('/history', methods=['GET'])
def get_history():
    """Return the current session's saved chat history as JSON."""
    return jsonify(load_chat_history(get_or_create_session_id()))
@app.route('/clear_history', methods=['POST'])
def clear_history():
    """Delete the current session's saved chat history file.

    Clearing is idempotent: a session with no saved history is reported as
    success. Other filesystem errors return ``{"status": "error", "message": ...}``.
    """
    session_id = get_or_create_session_id()
    try:
        os.remove(get_chat_history_path(session_id))
    except FileNotFoundError:
        # Nothing saved yet (e.g. a new session), so it is already cleared.
        pass
    except OSError as e:
        return jsonify({"status": "error", "message": str(e)})
    return jsonify({"status": "success"})
@app.route('/switch_models', methods=['POST'])
def switch_models():
    """Store the client-selected model name in ``session['model']``.

    Expects JSON ``{"model": <non-empty string>}``; responds 400 otherwise.
    """
    data = request.get_json()
    if not isinstance(data, dict) or 'model' not in data:
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    model = data['model']
    # Reject null, numbers or blank strings instead of storing them as a model name.
    if not isinstance(model, str) or not model.strip():
        return jsonify({'status': 'error', 'error': 'Invalid request data'}), 400
    session['model'] = model
    return jsonify({'status': 'success', 'model': model})
if __name__ == '__main__':
    # debug=True combined with host='0.0.0.0' exposes the Werkzeug interactive
    # debugger (arbitrary code execution) to the network, so it is opt-in via
    # FLASK_DEBUG=1/true/yes.
    debug_mode = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug_mode, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))