Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
| 7 |
import torch
|
| 8 |
import psycopg2
|
| 9 |
from datetime import datetime
|
|
|
|
| 10 |
|
| 11 |
app = Flask(__name__)
|
| 12 |
CORS(app)
|
|
@@ -19,16 +20,33 @@ logger = logging.getLogger(__name__)
|
|
| 19 |
MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
|
| 20 |
SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
#
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
# سكيما قاعدة البيانات
|
| 32 |
DB_SCHEMA = """
|
| 33 |
CREATE TABLE public.biodata (
|
| 34 |
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
@@ -110,30 +128,61 @@ CREATE TABLE public.user_place (
|
|
| 110 |
);
|
| 111 |
""".strip()
|
| 112 |
|
| 113 |
-
def
|
| 114 |
"""
|
| 115 |
-
ت
|
| 116 |
"""
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
@app.route('/api/query', methods=['POST'])
|
| 132 |
def handle_query():
|
|
|
|
|
|
|
|
|
|
| 133 |
try:
|
| 134 |
data = request.get_json()
|
| 135 |
|
| 136 |
-
# التحقق من البيانات المدخلة
|
| 137 |
if not all(k in data for k in ['text', 'cam_mac']):
|
| 138 |
return jsonify({"error": "المعطيات ناقصة"}), 400
|
| 139 |
|
|
@@ -146,31 +195,45 @@ def handle_query():
|
|
| 146 |
1. يجب تضمين شرط WHERE: cam_mac = '{data['cam_mac']}'
|
| 147 |
2. مسموح فقط باستخدام SELECT
|
| 148 |
3. الجداول المتاحة: profiles, data, place
|
| 149 |
-
|
| 150 |
-
المثال:
|
| 151 |
-
السؤال: "ما عدد زياراتي؟"
|
| 152 |
-
SQL: SELECT COUNT(*) FROM data WHERE cam_mac = '{data['cam_mac']}';
|
| 153 |
"""
|
| 154 |
# توليد الاستعلام
|
| 155 |
sql = generate_sql(prompt, data['cam_mac'])
|
| 156 |
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
# تنفيذ الاستعلام
|
| 162 |
-
with psycopg2.connect(SUPABASE_DB_URL) as conn:
|
| 163 |
with conn.cursor() as cursor:
|
| 164 |
cursor.execute(sql)
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
return jsonify({
|
| 169 |
-
"data":
|
| 170 |
"sql": sql,
|
| 171 |
"timestamp": datetime.now().isoformat()
|
| 172 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
except Exception as e:
|
| 175 |
-
logger.error(f"
|
| 176 |
-
return jsonify({"error": "حدث خطأ في المعالجة"}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
import psycopg2
|
| 9 |
from datetime import datetime
|
| 10 |
+
from psycopg2 import pool
|
| 11 |
|
| 12 |
app = Flask(__name__)
|
| 13 |
CORS(app)
|
|
|
|
| 20 |
MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
|
| 21 |
SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
|
| 22 |
|
| 23 |
+
# Initialise the shared PostgreSQL connection pool once, at import time.
# ThreadedConnectionPool is used instead of SimpleConnectionPool because the
# pool is a module-level object used from request handlers, which Flask may
# run on several threads; psycopg2 documents SimpleConnectionPool as a pool
# that cannot be shared across threads. Both classes expose the same
# getconn()/putconn() API, so callers are unaffected.
connection_pool = None
try:
    connection_pool = psycopg2.pool.ThreadedConnectionPool(
        minconn=1,
        maxconn=5,
        dsn=SUPABASE_DB_URL
    )
    logger.info("تم إنشاء connection pool بنجاح")
except Exception as e:
    # Best effort: keep connection_pool as None so the module still imports;
    # code that needs the database must check for None before using it.
    logger.error(f"خطأ في إنشاء connection pool: {str(e)}")
|
| 34 |
|
| 35 |
+
# Load the tokenizer and the seq2seq model a single time at startup.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()
    logger.info("تم تحميل النموذج بنجاح على الجهاز: %s", device)
except Exception as e:
    # Log the failure, then re-raise so the import of this module fails.
    logger.error(f"خطأ في تحميل النموذج: {str(e)}")
    raise
|
| 48 |
|
| 49 |
+
# سكيما قاعدة البيانات (مختصرة لتحسين الأداء)
|
| 50 |
DB_SCHEMA = """
|
| 51 |
CREATE TABLE public.biodata (
|
| 52 |
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
|
|
| 128 |
);
|
| 129 |
""".strip()
|
| 130 |
|
| 131 |
+
def clean_sql(sql: str, cam_mac: str) -> str:
    """Validate a generated SQL statement and scope it to a single camera.

    The statement must be a single SELECT and must not contain any
    data-modifying keyword. A ``cam_mac = '<value>'`` condition is then
    injected into the top-level WHERE clause (or a WHERE clause is created
    in front of any trailing GROUP BY / HAVING / ORDER BY / LIMIT / OFFSET /
    FETCH clause).

    This is a regex-based safeguard, not a SQL parser: parentheses inside
    string literals and set operations such as UNION are not analysed.

    Args:
        sql: SQL text produced by the model.
        cam_mac: Camera identifier the query must be restricted to. It is
            treated as untrusted and single quotes in it are doubled.

    Returns:
        The scoped SQL statement. A trailing semicolon is kept only if the
        input had one.

    Raises:
        ValueError: If the statement contains a forbidden keyword, is not a
            SELECT, or contains more than one statement.
    """
    # Local import: the file's import block is not guaranteed to provide `re`.
    import re

    # Match whole words only, so identifiers such as `created_at` or
    # `updated_at` are not mistaken for CREATE / UPDATE statements.
    forbidden_keywords = ['insert', 'update', 'delete', 'drop', 'alter', 'create', 'truncate']
    for keyword in forbidden_keywords:
        if re.search(rf'\b{keyword}\b', sql, flags=re.IGNORECASE):
            raise ValueError(f"استعلام غير مسموح به يحتوي على {keyword}")

    stripped = sql.strip()
    if not stripped.lower().startswith('select'):
        raise ValueError("يسمح فقط باستعلامات SELECT")

    # Drop the statement terminator while rewriting; restore it at the end.
    had_semicolon = stripped.endswith(';')
    body = stripped.rstrip('; \t\r\n')
    if ';' in body:
        # A semicolon in the middle means several statements were generated.
        raise ValueError("يسمح باستعلام واحد فقط")

    # Double single quotes so the value cannot terminate the SQL literal.
    safe_mac = str(cam_mac).replace("'", "''")
    condition = f"cam_mac = '{safe_mac}'"

    def first_top_level(pattern, start=0):
        """Return the first match at parenthesis depth 0 at or after start."""
        for match in re.finditer(pattern, body, flags=re.IGNORECASE):
            pos = match.start()
            if pos >= start and body.count('(', 0, pos) == body.count(')', 0, pos):
                return match
        return None

    where = first_top_level(r'\bwhere\b')
    tail = first_top_level(
        r'\b(?:group\s+by|having|order\s+by|limit|offset|fetch)\b',
        where.end() if where else 0,
    )

    # Split into the part that may hold the WHERE clause and the trailing
    # clauses that must stay after it.
    cut = tail.start() if tail else len(body)
    head, rest = body[:cut].rstrip(), body[cut:]

    if where:
        predicate = head[where.end():].strip()
        head = f"{head[:where.start()].rstrip()} WHERE {condition}"
        if predicate:
            # Parenthesise the original predicate so an OR inside it cannot
            # bypass the cam_mac restriction.
            head += f" AND ({predicate})"
    else:
        head = f"{head} WHERE {condition}"

    scoped = f"{head} {rest}" if rest else head
    return scoped + ';' if had_semicolon else scoped
|
| 155 |
+
|
| 156 |
+
def generate_sql(prompt: str, cam_mac: str) -> str:
    """Generate a SQL query from a text prompt using the loaded model.

    The prompt is tokenised (truncated to 512 tokens), decoded with beam
    search, and the resulting text is passed through ``clean_sql``.

    Args:
        prompt: Full prompt text fed to the model.
        cam_mac: Camera identifier the resulting query is restricted to.

    Returns:
        The cleaned SQL statement. If tokenisation, generation or cleaning
        raises, the error is logged and a fixed fallback query selecting up
        to 10 rows from ``data`` for ``cam_mac`` is returned instead.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        with torch.no_grad():
            # Deterministic beam search. No `temperature` is passed: it only
            # applies when sampling (do_sample=True), which is not used here.
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=4,
                early_stopping=True,
            )

        sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return clean_sql(sql, cam_mac)
    except Exception as e:
        # Deliberate best effort: fall back to a fixed, safe query.
        logger.error(f"خطأ في توليد SQL: {str(e)}")
        # cam_mac is untrusted; double single quotes so it cannot break out
        # of the SQL string literal.
        safe_mac = str(cam_mac).replace("'", "''")
        return f"SELECT * FROM data WHERE cam_mac = '{safe_mac}' LIMIT 10;"
|
| 177 |
|
| 178 |
@app.route('/api/query', methods=['POST'])
|
| 179 |
def handle_query():
|
| 180 |
+
if not connection_pool:
|
| 181 |
+
return jsonify({"error": "لا يوجد اتصال بقاعدة البيانات"}), 500
|
| 182 |
+
|
| 183 |
try:
|
| 184 |
data = request.get_json()
|
| 185 |
|
|
|
|
| 186 |
if not all(k in data for k in ['text', 'cam_mac']):
|
| 187 |
return jsonify({"error": "المعطيات ناقصة"}), 400
|
| 188 |
|
|
|
|
| 195 |
1. يجب تضمين شرط WHERE: cam_mac = '{data['cam_mac']}'
|
| 196 |
2. مسموح فقط باستخدام SELECT
|
| 197 |
3. الجداول المتاحة: profiles, data, place
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
"""
|
| 199 |
# توليد الاستعلام
|
| 200 |
sql = generate_sql(prompt, data['cam_mac'])
|
| 201 |
|
| 202 |
+
# تنفيذ الاستعلام باستخدام connection pool
|
| 203 |
+
conn = None
|
| 204 |
+
try:
|
| 205 |
+
conn = connection_pool.getconn()
|
|
|
|
|
|
|
| 206 |
with conn.cursor() as cursor:
|
| 207 |
cursor.execute(sql)
|
| 208 |
+
|
| 209 |
+
# إذا كان الاستعلام لا يعيد بيانات (مثل COUNT)
|
| 210 |
+
if cursor.description:
|
| 211 |
+
columns = [desc[0] for desc in cursor.description]
|
| 212 |
+
rows = cursor.fetchall()
|
| 213 |
+
result = [dict(zip(columns, row)) for row in rows]
|
| 214 |
+
else:
|
| 215 |
+
result = {"message": "تم تنفيذ الاستعلام بنجاح"}
|
| 216 |
|
| 217 |
return jsonify({
|
| 218 |
+
"data": result,
|
| 219 |
"sql": sql,
|
| 220 |
"timestamp": datetime.now().isoformat()
|
| 221 |
})
|
| 222 |
+
except Exception as e:
|
| 223 |
+
logger.error(f"خطأ في قاعدة البيانات: {str(e)}")
|
| 224 |
+
return jsonify({"error": "حدث خطأ في معالجة الاستعلام"}), 500
|
| 225 |
+
finally:
|
| 226 |
+
if conn:
|
| 227 |
+
connection_pool.putconn(conn)
|
| 228 |
|
| 229 |
except Exception as e:
|
| 230 |
+
logger.error(f"خطأ عام: {str(e)}")
|
| 231 |
+
return jsonify({"error": "حدث خطأ في المعالجة"}), 500
|
| 232 |
+
|
| 233 |
+
@app.route('/health', methods=['GET'])
def health_check():
    """Return a JSON body with a fixed "healthy" status and the current time."""
    payload = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
    }
    return jsonify(payload)
|
| 236 |
+
|
| 237 |
+
if __name__ == '__main__':
    # Listen on all interfaces; the port is read from $PORT (default 8080).
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 8080)))
|