|
|
""" |
|
|
Backend API cho HT_MATH_WEB - Chạy trên Hugging Face Spaces (Docker Version) |
|
|
Phiên bản: 9.1 (Strict OCR - Anti-Hallucination) |
|
|
Tác giả: Hoàng Tấn Thiên |
|
|
""" |
|
|
|
|
|
import os |
|
|
import io |
|
|
import time |
|
|
import asyncio |
|
|
import re |
|
|
import tempfile |
|
|
import hashlib |
|
|
import secrets |
|
|
import uuid |
|
|
import math |
|
|
from typing import List, Optional |
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse, FileResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from PIL import Image |
|
|
import fitz |
|
|
import google.generativeai as genai |
|
|
import json |
|
|
|
|
|
|
|
|
try: |
|
|
import pypandoc |
|
|
print(f"INFO: Pandoc version detected: {pypandoc.get_pandoc_version()}") |
|
|
except ImportError: |
|
|
print("CRITICAL WARNING: pypandoc module not found.") |
|
|
except OSError: |
|
|
print("CRITICAL WARNING: pandoc binary not found in system path.") |
|
|
|
|
|
|
|
|
try: |
|
|
import pytesseract |
|
|
|
|
|
|
|
|
print("INFO: Tesseract OCR module loaded.") |
|
|
except ImportError: |
|
|
print("WARNING: pytesseract not found. Fallback OCR will not work.") |
|
|
pytesseract = None |
|
|
|
|
|
|
|
|
try: |
|
|
import firebase_admin |
|
|
from firebase_admin import credentials, firestore |
|
|
FIREBASE_AVAILABLE = True |
|
|
except ImportError: |
|
|
print("WARNING: firebase-admin not found. Authentication will not work.") |
|
|
FIREBASE_AVAILABLE = False |
|
|
firebase_admin = None |
|
|
credentials = None |
|
|
firestore = None |
|
|
|
|
|
|
|
|
GEMINI_API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",") |
|
|
GEMINI_MODELS = os.getenv("GEMINI_MODELS", "gemini-2.5-flash,gemini-1.5-pro").split(",") |
|
|
FIREBASE_CREDENTIALS = os.getenv("FIREBASE_CREDENTIALS", "") |
|
|
MAX_THREADS = int(os.getenv("MAX_THREADS", "5")) |
|
|
ADMIN_SECRET_KEY = os.getenv("ADMIN_SECRET_KEY", "admin123") |
|
|
|
|
|
|
|
|
db = None |
|
|
if FIREBASE_AVAILABLE and FIREBASE_CREDENTIALS: |
|
|
try: |
|
|
|
|
|
cred_dict = json.loads(FIREBASE_CREDENTIALS) |
|
|
cred = credentials.Certificate(cred_dict) |
|
|
firebase_admin.initialize_app(cred) |
|
|
db = firestore.client() |
|
|
print("INFO: Firebase Firestore connected successfully") |
|
|
except Exception as e: |
|
|
print(f"Warning: Không thể kết nối Firebase: {e}") |
|
|
else: |
|
|
if not FIREBASE_AVAILABLE: |
|
|
print("Warning: firebase-admin package not installed") |
|
|
if not FIREBASE_CREDENTIALS: |
|
|
print("Warning: FIREBASE_CREDENTIALS environment variable not set") |
|
|
|
|
|
app = FastAPI(title="HT_MATH_WEB API", version="9.1") |
|
|
|
|
|
app.add_middleware( |
|
|
CORSMiddleware, |
|
|
allow_origins=["*"], |
|
|
allow_credentials=True, |
|
|
allow_methods=["*"], |
|
|
allow_headers=["*"], |
|
|
) |
|
|
|
|
|
|
|
|
os.makedirs("uploads", exist_ok=True) |
|
|
app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads") |
|
|
|
|
|
@app.exception_handler(404) |
|
|
async def not_found_handler(request, exc): |
|
|
return JSONResponse( |
|
|
status_code=404, |
|
|
content={ |
|
|
"detail": f"Route not found: {request.url.path}", |
|
|
"available_routes": ["/", "/api/models", "/api/convert", "/api/export-docx", "/api/login", "/api/check-session", "/api/upload-image"] |
|
|
} |
|
|
) |
|
|
|
|
|
|
|
|
class ApiKeyManager: |
|
|
def __init__(self, keys: List[str]): |
|
|
self.api_keys = [k.strip() for k in keys if k.strip()] |
|
|
self.current_index = 0 |
|
|
|
|
|
def get_next_key(self) -> Optional[str]: |
|
|
if not self.api_keys: return None |
|
|
key = self.api_keys[self.current_index] |
|
|
self.current_index = (self.current_index + 1) % len(self.api_keys) |
|
|
return key |
|
|
|
|
|
def get_key_count(self) -> int: |
|
|
return len(self.api_keys) |
|
|
|
|
|
key_manager = ApiKeyManager(GEMINI_API_KEYS) |
|
|
|
|
|
ip_rate_limits = {} |
|
|
RATE_LIMIT_DURATION = 7 |
|
|
|
|
|
def check_rate_limit(request: Request): |
|
|
forwarded = request.headers.get("X-Forwarded-For") |
|
|
client_ip = forwarded.split(",")[0].strip() if forwarded else request.client.host |
|
|
now = time.time() |
|
|
if client_ip in ip_rate_limits: |
|
|
elapsed = now - ip_rate_limits[client_ip] |
|
|
if elapsed < RATE_LIMIT_DURATION: |
|
|
print(f"[RateLimit] IP {client_ip} requesting too fast.") |
|
|
ip_rate_limits[client_ip] = now |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DIRECT_GEMINI_PROMPT_TEXT_ONLY = r"""Đóng vai một CÔNG CỤ OCR CHUYÊN DỤNG cho văn bản hành chính Việt Nam. |
|
|
NHIỆM VỤ: Trích xuất nguyên văn (Verbatim) nội dung trong ảnh. |
|
|
|
|
|
QUY TẮC BẤT DI BẤT DỊCH (CẤM VI PHẠM): |
|
|
1. TUYỆT ĐỐI KHÔNG thêm lời dẫn đầu (như "Dưới đây là...", "Chào bạn...", "Văn bản này nói về..."). |
|
|
2. TUYỆT ĐỐI KHÔNG thêm lời kết luận hay nhận xét, đánh giá. |
|
|
3. KHÔNG thay đổi từ ngữ, KHÔNG diễn giải (paraphrase), KHÔNG tóm tắt. Phải giữ nguyên văn kể cả lỗi chính tả của bản gốc. |
|
|
4. Giữ nguyên cấu trúc dòng, đoạn, số hiệu, ngày tháng, quốc hiệu, tiêu ngữ. |
|
|
5. Nếu chữ quá mờ không đọc được, hãy điền ký hiệu `[...]`. KHÔNG ĐƯỢC TỰ Ý ĐOÁN MÒ HAY BỊA ĐẶT NỘI DUNG. |
|
|
|
|
|
ĐẦU RA: Chỉ trả về văn bản thô. Không Markdown dư thừa nếu không cần thiết.""" |
|
|
|
|
|
DIRECT_GEMINI_PROMPT_LATEX = r"""Đóng vai công cụ số hóa tài liệu Toán học chính xác tuyệt đối. |
|
|
NHIỆM VỤ: Chuyển đổi ảnh sang Markdown + LaTeX. |
|
|
|
|
|
QUY TẮC: |
|
|
1. Công thức toán học PHẢI nằm trong dấu `$`. Ví dụ: $x^2 + 1 = 0$. |
|
|
2. KHÔNG thêm lời dẫn thừa (như "Kết quả là:", "Đây là bài giải:"). |
|
|
3. GIỮ NGUYÊN VĂN đề bài và lời giải. KHÔNG tóm tắt, KHÔNG viết lại theo văn phong khác. |
|
|
4. Nếu ảnh bị cắt hoặc mờ: |
|
|
- Nếu là công thức toán có thể suy luận logic: Hãy phục hồi. |
|
|
- Nếu là văn bản/tên riêng: Điền `[...]`. |
|
|
|
|
|
CHỈ TRẢ VỀ NỘI DUNG MARKDOWN.""" |
|
|
|
|
|
|
|
|
def stitch_text(text_a: str, text_b: str, min_overlap_chars: int = 20) -> str: |
|
|
if not text_a: return text_b |
|
|
if not text_b: return text_a |
|
|
|
|
|
a_lines = text_a.splitlines() |
|
|
b_lines = text_b.splitlines() |
|
|
|
|
|
scan_window = min(len(a_lines), len(b_lines), 30) |
|
|
best_overlap_idx = 0 |
|
|
|
|
|
for i in range(scan_window, 0, -1): |
|
|
tail_a = "\n".join(a_lines[-i:]).strip() |
|
|
head_b = "\n".join(b_lines[:i]).strip() |
|
|
|
|
|
if len(tail_a) >= min_overlap_chars and tail_a == head_b: |
|
|
best_overlap_idx = i |
|
|
break |
|
|
|
|
|
if best_overlap_idx > 0: |
|
|
return text_a + "\n" + "\n".join(b_lines[best_overlap_idx:]) |
|
|
else: |
|
|
return text_a + "\n\n" + text_b |
|
|
|
|
|
|
|
|
def clean_latex_formulas(text: str) -> str: |
|
|
|
|
|
text = re.sub(r'\$\s+(.*?)\s+\$', lambda m: f'${m.group(1).strip()}$', text) |
|
|
|
|
|
return text |
|
|
|
|
|
def hash_password(password: str) -> str: |
|
|
return hashlib.sha256(password.encode()).hexdigest() |
|
|
|
|
|
def verify_password(password: str, hashed: str) -> bool: |
|
|
return hash_password(password) == hashed |
|
|
|
|
|
def safe_get_text(response) -> str: |
|
|
""" |
|
|
Trích xuất text an toàn từ Gemini Response. |
|
|
Xử lý trường hợp bị chặn bản quyền (finish_reason=4). |
|
|
""" |
|
|
if not response.candidates: |
|
|
return "" |
|
|
|
|
|
candidate = response.candidates[0] |
|
|
|
|
|
|
|
|
|
|
|
if candidate.finish_reason == 4: |
|
|
return "[BLOCKED_BY_COPYRIGHT]" |
|
|
|
|
|
if candidate.finish_reason != 1: |
|
|
|
|
|
return "" |
|
|
|
|
|
|
|
|
parts = candidate.content.parts |
|
|
texts = [p.text for p in parts if hasattr(p, "text")] |
|
|
return "\n".join(texts) |
|
|
|
|
|
async def fallback_ocr_tesseract(image: Image.Image) -> str: |
|
|
""" |
|
|
Fallback dùng Tesseract OCR khi Gemini từ chối phục vụ |
|
|
""" |
|
|
if pytesseract is None: |
|
|
return "**[Lỗi] Gemini từ chối xử lý (Bản quyền) và Tesseract chưa được cài đặt.**" |
|
|
|
|
|
print("[Fallback] Đang chạy Tesseract OCR...") |
|
|
try: |
|
|
|
|
|
loop = asyncio.get_running_loop() |
|
|
|
|
|
text = await loop.run_in_executor(None, lambda: pytesseract.image_to_string(image, lang='vie+eng')) |
|
|
return f"**[Lưu ý: Nội dung này được trích xuất bằng OCR dự phòng do Gemini chặn bản quyền]**\n\n{text}" |
|
|
except Exception as e: |
|
|
print(f"Tesseract Error: {e}") |
|
|
return "**[Lỗi] Cả Gemini và Tesseract đều thất bại.**" |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
@app.get("/health") |
|
|
async def root(): |
|
|
pandoc_status = "Not Found" |
|
|
tesseract_status = "Not Found" |
|
|
try: |
|
|
pandoc_status = pypandoc.get_pandoc_version() |
|
|
except: pass |
|
|
|
|
|
try: |
|
|
if pytesseract: tesseract_status = "Ready" |
|
|
except: pass |
|
|
|
|
|
return { |
|
|
"status": "ok", |
|
|
"service": "HT_MATH_WEB API v9.1 (Strict Mode)", |
|
|
"keys_loaded": key_manager.get_key_count(), |
|
|
"pandoc": pandoc_status, |
|
|
"tesseract": tesseract_status |
|
|
} |
|
|
|
|
|
@app.get("/api/models") |
|
|
async def get_models(): |
|
|
return {"models": GEMINI_MODELS} |
|
|
|
|
|
|
|
|
@app.post("/api/register") |
|
|
async def register(email: str = Form(...), password: str = Form(...)): |
|
|
if not db: |
|
|
raise HTTPException(status_code=500, detail="DB Error") |
|
|
|
|
|
|
|
|
users_ref = db.collection('users') |
|
|
existing = list(users_ref.where('email', '==', email).limit(1).stream()) |
|
|
|
|
|
if existing: |
|
|
raise HTTPException(status_code=400, detail="Email tồn tại") |
|
|
|
|
|
|
|
|
user_data = { |
|
|
"email": email, |
|
|
"password": hash_password(password), |
|
|
"status": "pending", |
|
|
"created_at": time.strftime("%Y-%m-%d %H:%M:%S") |
|
|
} |
|
|
users_ref.add(user_data) |
|
|
|
|
|
return {"success": True, "message": "Đăng ký thành công, chờ duyệt."} |
|
|
|
|
|
@app.post("/api/login") |
|
|
async def login(request: Request, email: str = Form(...), password: str = Form(...)): |
|
|
if not db: |
|
|
raise HTTPException(status_code=500, detail="DB Error") |
|
|
|
|
|
|
|
|
users_ref = db.collection('users') |
|
|
user_docs = list(users_ref.where('email', '==', email).limit(1).stream()) |
|
|
|
|
|
if not user_docs: |
|
|
raise HTTPException(status_code=401, detail="Sai email/pass") |
|
|
|
|
|
user_data = user_docs[0].to_dict() |
|
|
|
|
|
if not verify_password(password, user_data["password"]): |
|
|
raise HTTPException(status_code=401, detail="Sai email/pass") |
|
|
|
|
|
if user_data.get("status") != "active": |
|
|
raise HTTPException(status_code=403, detail="Tài khoản chưa kích hoạt") |
|
|
|
|
|
|
|
|
token = secrets.token_urlsafe(32) |
|
|
|
|
|
|
|
|
sessions_ref = db.collection('sessions') |
|
|
old_sessions = sessions_ref.where('email', '==', email).stream() |
|
|
for session_doc in old_sessions: |
|
|
session_doc.reference.delete() |
|
|
|
|
|
|
|
|
sessions_ref.add({ |
|
|
"email": email, |
|
|
"token": token, |
|
|
"last_seen": time.strftime("%Y-%m-%d %H:%M:%S") |
|
|
}) |
|
|
|
|
|
return {"success": True, "token": token, "email": email} |
|
|
|
|
|
@app.post("/api/check-session") |
|
|
async def check_session(email: str = Form(...), token: str = Form(...)): |
|
|
if not db: |
|
|
raise HTTPException(status_code=500, detail="DB Error") |
|
|
|
|
|
|
|
|
sessions_ref = db.collection('sessions') |
|
|
session_docs = list(sessions_ref.where('email', '==', email).where('token', '==', token).limit(1).stream()) |
|
|
|
|
|
if not session_docs: |
|
|
raise HTTPException(status_code=401, detail="Session expired") |
|
|
|
|
|
|
|
|
session_doc = session_docs[0] |
|
|
session_doc.reference.update({ |
|
|
"last_seen": time.strftime("%Y-%m-%d %H:%M:%S") |
|
|
}) |
|
|
|
|
|
return {"status": "valid"} |
|
|
|
|
|
@app.post("/api/logout") |
|
|
async def logout(request: Request): |
|
|
try: |
|
|
data = await request.json() |
|
|
email = data.get("email") |
|
|
|
|
|
if email and db: |
|
|
sessions_ref = db.collection('sessions') |
|
|
old_sessions = sessions_ref.where('email', '==', email).stream() |
|
|
for session_doc in old_sessions: |
|
|
session_doc.reference.delete() |
|
|
except Exception as e: |
|
|
print(f"Logout error: {e}") |
|
|
|
|
|
return {"status": "success"} |
|
|
|
|
|
@app.post("/api/upload-image") |
|
|
async def upload_image(file: UploadFile = File(...)): |
|
|
try: |
|
|
file_ext = os.path.splitext(file.filename)[1] or ".png" |
|
|
file_name = f"{uuid.uuid4().hex}{file_ext}" |
|
|
file_path = f"uploads/{file_name}" |
|
|
with open(file_path, "wb") as f: f.write(await file.read()) |
|
|
return {"url": file_path} |
|
|
except Exception as e: raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
|
|
|
|
|
|
async def process_image_with_gemini(image: Image.Image, model_id: str, prompt: str, max_retries: int = 3) -> str: |
|
|
"""Gửi 1 ảnh (hoặc mảnh ảnh) lên Gemini và nhận text, có fallback""" |
|
|
for attempt in range(max_retries): |
|
|
try: |
|
|
api_key = key_manager.get_next_key() |
|
|
if not api_key: raise ValueError("No API Key") |
|
|
genai.configure(api_key=api_key) |
|
|
|
|
|
|
|
|
generation_config = {"temperature": 0.0, "top_p": 1.0, "max_output_tokens": 8192} |
|
|
model = genai.GenerativeModel(model_id, generation_config=generation_config) |
|
|
|
|
|
|
|
|
response = model.generate_content([prompt, image]) |
|
|
|
|
|
|
|
|
text = safe_get_text(response) |
|
|
|
|
|
|
|
|
if text == "[BLOCKED_BY_COPYRIGHT]": |
|
|
print(f"Warning: Gemini blocked copyright content. Switching to fallback OCR...") |
|
|
return await fallback_ocr_tesseract(image) |
|
|
|
|
|
if text: |
|
|
return text.strip() |
|
|
|
|
|
except Exception as e: |
|
|
if "429" in str(e) and attempt < max_retries - 1: |
|
|
time.sleep(2) |
|
|
continue |
|
|
if attempt == max_retries - 1: |
|
|
print(f"Error Gemini: {e}") |
|
|
return await fallback_ocr_tesseract(image) |
|
|
|
|
|
return "" |
|
|
|
|
|
async def process_large_image(image: Image.Image, model: str, prompt: str, semaphore: asyncio.Semaphore) -> str: |
|
|
""" |
|
|
Xử lý ảnh lớn: Cắt -> Gửi (kèm fallback) -> Ghép |
|
|
""" |
|
|
CHUNK_HEIGHT = 2048 |
|
|
OVERLAP_HEIGHT = 512 |
|
|
|
|
|
width, height = image.size |
|
|
|
|
|
if height <= CHUNK_HEIGHT: |
|
|
async with semaphore: |
|
|
return await process_image_with_gemini(image, model, prompt) |
|
|
|
|
|
|
|
|
chunks = [] |
|
|
y = 0 |
|
|
while y < height: |
|
|
bottom = min(y + CHUNK_HEIGHT, height) |
|
|
box = (0, y, width, bottom) |
|
|
chunk = image.crop(box) |
|
|
chunks.append(chunk) |
|
|
if bottom == height: break |
|
|
y += (CHUNK_HEIGHT - OVERLAP_HEIGHT) |
|
|
|
|
|
print(f"[Split] Image height {height}px -> {len(chunks)} chunks.") |
|
|
|
|
|
async def process_chunk(chunk_img, index): |
|
|
async with semaphore: |
|
|
text = await process_image_with_gemini(chunk_img, model, prompt) |
|
|
return index, text |
|
|
|
|
|
tasks = [process_chunk(chunk, i) for i, chunk in enumerate(chunks)] |
|
|
chunk_results = await asyncio.gather(*tasks) |
|
|
|
|
|
chunk_results.sort(key=lambda x: x[0]) |
|
|
ordered_texts = [text for _, text in chunk_results] |
|
|
|
|
|
final_text = ordered_texts[0] |
|
|
for i in range(1, len(ordered_texts)): |
|
|
final_text = stitch_text(final_text, ordered_texts[i], min_overlap_chars=20) |
|
|
|
|
|
return final_text |
|
|
|
|
|
@app.post("/api/convert") |
|
|
async def convert_file( |
|
|
request: Request, |
|
|
file: UploadFile = File(...), |
|
|
model: str = Form("gemini-2.5-flash"), |
|
|
mode: str = Form("latex") |
|
|
): |
|
|
check_rate_limit(request) |
|
|
if key_manager.get_key_count() == 0: |
|
|
raise HTTPException(status_code=500, detail="Chưa cấu hình API Key") |
|
|
|
|
|
prompt = DIRECT_GEMINI_PROMPT_LATEX if mode == "latex" else DIRECT_GEMINI_PROMPT_TEXT_ONLY |
|
|
|
|
|
try: |
|
|
file_content = await file.read() |
|
|
file_ext = os.path.splitext(file.filename)[1].lower() |
|
|
|
|
|
global_semaphore = asyncio.Semaphore(MAX_THREADS) |
|
|
|
|
|
results = [] |
|
|
|
|
|
if file_ext == ".pdf": |
|
|
doc = fitz.open(stream=file_content, filetype="pdf") |
|
|
async def process_page_wrapper(page, idx): |
|
|
pix = page.get_pixmap(dpi=300) |
|
|
img = Image.open(io.BytesIO(pix.tobytes("png"))) |
|
|
text = await process_large_image(img, model, prompt, global_semaphore) |
|
|
return idx, text |
|
|
|
|
|
tasks = [process_page_wrapper(doc[i], i) for i in range(len(doc))] |
|
|
page_results = await asyncio.gather(*tasks) |
|
|
results = [text for _, text in sorted(page_results, key=lambda x: x[0])] |
|
|
doc.close() |
|
|
|
|
|
elif file_ext in [".png", ".jpg", ".jpeg", ".bmp"]: |
|
|
img = Image.open(io.BytesIO(file_content)) |
|
|
text = await process_large_image(img, model, prompt, global_semaphore) |
|
|
results.append(text) |
|
|
else: |
|
|
raise HTTPException(status_code=400, detail="Định dạng file không hỗ trợ") |
|
|
|
|
|
final_text = "\n\n".join(results) |
|
|
return {"success": True, "result": clean_latex_formulas(final_text)} |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
@app.post("/api/export-docx") |
|
|
async def export_docx(markdown_text: str = Form(...)): |
|
|
try: |
|
|
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp_file: |
|
|
output_filename = tmp_file.name |
|
|
|
|
|
pypandoc.convert_text( |
|
|
markdown_text, |
|
|
to='docx', |
|
|
format='markdown', |
|
|
outputfile=output_filename, |
|
|
extra_args=['--standalone'] |
|
|
) |
|
|
|
|
|
return FileResponse( |
|
|
output_filename, |
|
|
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", |
|
|
filename="HT_MATH_OUTPUT.docx" |
|
|
) |
|
|
except Exception as e: |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
raise HTTPException(status_code=500, detail=f"Lỗi xuất Word: {str(e)}") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860"))) |