hoangthiencm's picture
Update app.py
762aa13 verified
"""
Backend API cho HT_MATH_WEB - Chạy trên Hugging Face Spaces (Docker Version)
Phiên bản: 9.1 (Strict OCR - Anti-Hallucination)
Tác giả: Hoàng Tấn Thiên
"""
import os
import io
import time
import asyncio
import re
import tempfile
import hashlib
import secrets
import uuid
import math
from typing import List, Optional
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import fitz # PyMuPDF
import google.generativeai as genai
import json
# --- PANDOC IMPORT ---
try:
import pypandoc
print(f"INFO: Pandoc version detected: {pypandoc.get_pandoc_version()}")
except ImportError:
print("CRITICAL WARNING: pypandoc module not found.")
except OSError:
print("CRITICAL WARNING: pandoc binary not found in system path.")
# --- TESSERACT IMPORT (FALLBACK OCR) ---
try:
import pytesseract
# Kiểm tra xem binary có tồn tại không (trong Docker thường ở /usr/bin/tesseract)
# pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
print("INFO: Tesseract OCR module loaded.")
except ImportError:
print("WARNING: pytesseract not found. Fallback OCR will not work.")
pytesseract = None
# --- FIREBASE ---
try:
import firebase_admin
from firebase_admin import credentials, firestore
FIREBASE_AVAILABLE = True
except ImportError:
print("WARNING: firebase-admin not found. Authentication will not work.")
FIREBASE_AVAILABLE = False
firebase_admin = None
credentials = None
firestore = None
# ===== CẤU HÌNH =====
GEMINI_API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",")
GEMINI_MODELS = os.getenv("GEMINI_MODELS", "gemini-2.5-flash,gemini-1.5-pro").split(",")
FIREBASE_CREDENTIALS = os.getenv("FIREBASE_CREDENTIALS", "")
MAX_THREADS = int(os.getenv("MAX_THREADS", "5"))
ADMIN_SECRET_KEY = os.getenv("ADMIN_SECRET_KEY", "admin123")
# Setup Firebase Firestore
db = None
if FIREBASE_AVAILABLE and FIREBASE_CREDENTIALS:
try:
# Parse JSON credentials from environment variable
cred_dict = json.loads(FIREBASE_CREDENTIALS)
cred = credentials.Certificate(cred_dict)
firebase_admin.initialize_app(cred)
db = firestore.client()
print("INFO: Firebase Firestore connected successfully")
except Exception as e:
print(f"Warning: Không thể kết nối Firebase: {e}")
else:
if not FIREBASE_AVAILABLE:
print("Warning: firebase-admin package not installed")
if not FIREBASE_CREDENTIALS:
print("Warning: FIREBASE_CREDENTIALS environment variable not set")
app = FastAPI(title="HT_MATH_WEB API", version="9.1")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- SETUP STATIC FILES ---
os.makedirs("uploads", exist_ok=True)
app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
@app.exception_handler(404)
async def not_found_handler(request, exc):
return JSONResponse(
status_code=404,
content={
"detail": f"Route not found: {request.url.path}",
"available_routes": ["/", "/api/models", "/api/convert", "/api/export-docx", "/api/login", "/api/check-session", "/api/upload-image"]
}
)
# ===== KEY MANAGER & RATE LIMIT =====
class ApiKeyManager:
def __init__(self, keys: List[str]):
self.api_keys = [k.strip() for k in keys if k.strip()]
self.current_index = 0
def get_next_key(self) -> Optional[str]:
if not self.api_keys: return None
key = self.api_keys[self.current_index]
self.current_index = (self.current_index + 1) % len(self.api_keys)
return key
def get_key_count(self) -> int:
return len(self.api_keys)
key_manager = ApiKeyManager(GEMINI_API_KEYS)
ip_rate_limits = {}
RATE_LIMIT_DURATION = 7
def check_rate_limit(request: Request):
forwarded = request.headers.get("X-Forwarded-For")
client_ip = forwarded.split(",")[0].strip() if forwarded else request.client.host
now = time.time()
if client_ip in ip_rate_limits:
elapsed = now - ip_rate_limits[client_ip]
if elapsed < RATE_LIMIT_DURATION:
print(f"[RateLimit] IP {client_ip} requesting too fast.")
ip_rate_limits[client_ip] = now
# ===== SAFE PROMPTS (STRICT OCR - NO CHATTER) =====
# Đã thay đổi toàn bộ Prompt theo hướng "Máy OCR lạnh lùng"
# Loại bỏ hoàn toàn vai trò "Trợ lý ảo" để tránh nói nhảm.
DIRECT_GEMINI_PROMPT_TEXT_ONLY = r"""Đóng vai một CÔNG CỤ OCR CHUYÊN DỤNG cho văn bản hành chính Việt Nam.
NHIỆM VỤ: Trích xuất nguyên văn (Verbatim) nội dung trong ảnh.
QUY TẮC BẤT DI BẤT DỊCH (CẤM VI PHẠM):
1. TUYỆT ĐỐI KHÔNG thêm lời dẫn đầu (như "Dưới đây là...", "Chào bạn...", "Văn bản này nói về...").
2. TUYỆT ĐỐI KHÔNG thêm lời kết luận hay nhận xét, đánh giá.
3. KHÔNG thay đổi từ ngữ, KHÔNG diễn giải (paraphrase), KHÔNG tóm tắt. Phải giữ nguyên văn kể cả lỗi chính tả của bản gốc.
4. Giữ nguyên cấu trúc dòng, đoạn, số hiệu, ngày tháng, quốc hiệu, tiêu ngữ.
5. Nếu chữ quá mờ không đọc được, hãy điền ký hiệu `[...]`. KHÔNG ĐƯỢC TỰ Ý ĐOÁN MÒ HAY BỊA ĐẶT NỘI DUNG.
ĐẦU RA: Chỉ trả về văn bản thô. Không Markdown dư thừa nếu không cần thiết."""
DIRECT_GEMINI_PROMPT_LATEX = r"""Đóng vai công cụ số hóa tài liệu Toán học chính xác tuyệt đối.
NHIỆM VỤ: Chuyển đổi ảnh sang Markdown + LaTeX.
QUY TẮC:
1. Công thức toán học PHẢI nằm trong dấu `$`. Ví dụ: $x^2 + 1 = 0$.
2. KHÔNG thêm lời dẫn thừa (như "Kết quả là:", "Đây là bài giải:").
3. GIỮ NGUYÊN VĂN đề bài và lời giải. KHÔNG tóm tắt, KHÔNG viết lại theo văn phong khác.
4. Nếu ảnh bị cắt hoặc mờ:
- Nếu là công thức toán có thể suy luận logic: Hãy phục hồi.
- Nếu là văn bản/tên riêng: Điền `[...]`.
CHỈ TRẢ VỀ NỘI DUNG MARKDOWN."""
# ===== STITCHING ALGORITHM =====
def stitch_text(text_a: str, text_b: str, min_overlap_chars: int = 20) -> str:
if not text_a: return text_b
if not text_b: return text_a
a_lines = text_a.splitlines()
b_lines = text_b.splitlines()
scan_window = min(len(a_lines), len(b_lines), 30)
best_overlap_idx = 0
for i in range(scan_window, 0, -1):
tail_a = "\n".join(a_lines[-i:]).strip()
head_b = "\n".join(b_lines[:i]).strip()
if len(tail_a) >= min_overlap_chars and tail_a == head_b:
best_overlap_idx = i
break
if best_overlap_idx > 0:
return text_a + "\n" + "\n".join(b_lines[best_overlap_idx:])
else:
return text_a + "\n\n" + text_b
# ===== HELPER FUNCTIONS =====
def clean_latex_formulas(text: str) -> str:
# Chuẩn hóa khoảng trắng Latex
text = re.sub(r'\$\s+(.*?)\s+\$', lambda m: f'${m.group(1).strip()}$', text)
# Fix lỗi phổ biến khi OCR tiếng Việt bị dính ký tự
return text
def hash_password(password: str) -> str:
return hashlib.sha256(password.encode()).hexdigest()
def verify_password(password: str, hashed: str) -> bool:
return hash_password(password) == hashed
def safe_get_text(response) -> str:
"""
Trích xuất text an toàn từ Gemini Response.
Xử lý trường hợp bị chặn bản quyền (finish_reason=4).
"""
if not response.candidates:
return ""
candidate = response.candidates[0]
# Kiểm tra lý do kết thúc
# 1 = STOP (OK), 4 = RECITING_FROM_COPYRIGHTED_MATERIAL (Blocked)
if candidate.finish_reason == 4:
return "[BLOCKED_BY_COPYRIGHT]"
if candidate.finish_reason != 1:
# Có thể log warning ở đây với các reason khác (như SAFETY)
return ""
# Nếu an toàn, trích xuất text
parts = candidate.content.parts
texts = [p.text for p in parts if hasattr(p, "text")]
return "\n".join(texts)
async def fallback_ocr_tesseract(image: Image.Image) -> str:
"""
Fallback dùng Tesseract OCR khi Gemini từ chối phục vụ
"""
if pytesseract is None:
return "**[Lỗi] Gemini từ chối xử lý (Bản quyền) và Tesseract chưa được cài đặt.**"
print("[Fallback] Đang chạy Tesseract OCR...")
try:
# Chạy trong executor để không block event loop
loop = asyncio.get_running_loop()
# Chế độ: tiếng Việt + tiếng Anh + công thức toán (nếu train, ở đây dùng cơ bản)
text = await loop.run_in_executor(None, lambda: pytesseract.image_to_string(image, lang='vie+eng'))
return f"**[Lưu ý: Nội dung này được trích xuất bằng OCR dự phòng do Gemini chặn bản quyền]**\n\n{text}"
except Exception as e:
print(f"Tesseract Error: {e}")
return "**[Lỗi] Cả Gemini và Tesseract đều thất bại.**"
# ===== API ENDPOINTS =====
@app.get("/")
@app.get("/health")
async def root():
pandoc_status = "Not Found"
tesseract_status = "Not Found"
try:
pandoc_status = pypandoc.get_pandoc_version()
except: pass
try:
if pytesseract: tesseract_status = "Ready"
except: pass
return {
"status": "ok",
"service": "HT_MATH_WEB API v9.1 (Strict Mode)",
"keys_loaded": key_manager.get_key_count(),
"pandoc": pandoc_status,
"tesseract": tesseract_status
}
@app.get("/api/models")
async def get_models():
return {"models": GEMINI_MODELS}
# --- AUTH API ---
@app.post("/api/register")
async def register(email: str = Form(...), password: str = Form(...)):
if not db:
raise HTTPException(status_code=500, detail="DB Error")
# Check if user exists
users_ref = db.collection('users')
existing = list(users_ref.where('email', '==', email).limit(1).stream())
if existing:
raise HTTPException(status_code=400, detail="Email tồn tại")
# Create new user
user_data = {
"email": email,
"password": hash_password(password),
"status": "pending",
"created_at": time.strftime("%Y-%m-%d %H:%M:%S")
}
users_ref.add(user_data)
return {"success": True, "message": "Đăng ký thành công, chờ duyệt."}
@app.post("/api/login")
async def login(request: Request, email: str = Form(...), password: str = Form(...)):
if not db:
raise HTTPException(status_code=500, detail="DB Error")
# Find user
users_ref = db.collection('users')
user_docs = list(users_ref.where('email', '==', email).limit(1).stream())
if not user_docs:
raise HTTPException(status_code=401, detail="Sai email/pass")
user_data = user_docs[0].to_dict()
if not verify_password(password, user_data["password"]):
raise HTTPException(status_code=401, detail="Sai email/pass")
if user_data.get("status") != "active":
raise HTTPException(status_code=403, detail="Tài khoản chưa kích hoạt")
# Create session token
token = secrets.token_urlsafe(32)
# Delete old sessions for this email
sessions_ref = db.collection('sessions')
old_sessions = sessions_ref.where('email', '==', email).stream()
for session_doc in old_sessions:
session_doc.reference.delete()
# Create new session
sessions_ref.add({
"email": email,
"token": token,
"last_seen": time.strftime("%Y-%m-%d %H:%M:%S")
})
return {"success": True, "token": token, "email": email}
@app.post("/api/check-session")
async def check_session(email: str = Form(...), token: str = Form(...)):
if not db:
raise HTTPException(status_code=500, detail="DB Error")
# Find session by email and token
sessions_ref = db.collection('sessions')
session_docs = list(sessions_ref.where('email', '==', email).where('token', '==', token).limit(1).stream())
if not session_docs:
raise HTTPException(status_code=401, detail="Session expired")
# Update last_seen timestamp
session_doc = session_docs[0]
session_doc.reference.update({
"last_seen": time.strftime("%Y-%m-%d %H:%M:%S")
})
return {"status": "valid"}
@app.post("/api/logout")
async def logout(request: Request):
try:
data = await request.json()
email = data.get("email")
if email and db:
sessions_ref = db.collection('sessions')
old_sessions = sessions_ref.where('email', '==', email).stream()
for session_doc in old_sessions:
session_doc.reference.delete()
except Exception as e:
print(f"Logout error: {e}")
return {"status": "success"}
@app.post("/api/upload-image")
async def upload_image(file: UploadFile = File(...)):
try:
file_ext = os.path.splitext(file.filename)[1] or ".png"
file_name = f"{uuid.uuid4().hex}{file_ext}"
file_path = f"uploads/{file_name}"
with open(file_path, "wb") as f: f.write(await file.read())
return {"url": file_path}
except Exception as e: raise HTTPException(status_code=500, detail=str(e))
# --- CORE CONVERT LOGIC ---
async def process_image_with_gemini(image: Image.Image, model_id: str, prompt: str, max_retries: int = 3) -> str:
"""Gửi 1 ảnh (hoặc mảnh ảnh) lên Gemini và nhận text, có fallback"""
for attempt in range(max_retries):
try:
api_key = key_manager.get_next_key()
if not api_key: raise ValueError("No API Key")
genai.configure(api_key=api_key)
# Giảm temperature xuống 0.0 để model trở nên máy móc nhất có thể (Deterministic)
generation_config = {"temperature": 0.0, "top_p": 1.0, "max_output_tokens": 8192}
model = genai.GenerativeModel(model_id, generation_config=generation_config)
# Gọi API
response = model.generate_content([prompt, image])
# Lấy text an toàn
text = safe_get_text(response)
# XỬ LÝ LỖI BẢN QUYỀN -> FALLBACK
if text == "[BLOCKED_BY_COPYRIGHT]":
print(f"Warning: Gemini blocked copyright content. Switching to fallback OCR...")
return await fallback_ocr_tesseract(image)
if text:
return text.strip()
except Exception as e:
if "429" in str(e) and attempt < max_retries - 1:
time.sleep(2)
continue
if attempt == max_retries - 1:
print(f"Error Gemini: {e}")
return await fallback_ocr_tesseract(image)
return ""
async def process_large_image(image: Image.Image, model: str, prompt: str, semaphore: asyncio.Semaphore) -> str:
"""
Xử lý ảnh lớn: Cắt -> Gửi (kèm fallback) -> Ghép
"""
CHUNK_HEIGHT = 2048
OVERLAP_HEIGHT = 512
width, height = image.size
if height <= CHUNK_HEIGHT:
async with semaphore:
return await process_image_with_gemini(image, model, prompt)
# Cắt ảnh
chunks = []
y = 0
while y < height:
bottom = min(y + CHUNK_HEIGHT, height)
box = (0, y, width, bottom)
chunk = image.crop(box)
chunks.append(chunk)
if bottom == height: break
y += (CHUNK_HEIGHT - OVERLAP_HEIGHT)
print(f"[Split] Image height {height}px -> {len(chunks)} chunks.")
async def process_chunk(chunk_img, index):
async with semaphore:
text = await process_image_with_gemini(chunk_img, model, prompt)
return index, text
tasks = [process_chunk(chunk, i) for i, chunk in enumerate(chunks)]
chunk_results = await asyncio.gather(*tasks)
chunk_results.sort(key=lambda x: x[0])
ordered_texts = [text for _, text in chunk_results]
final_text = ordered_texts[0]
for i in range(1, len(ordered_texts)):
final_text = stitch_text(final_text, ordered_texts[i], min_overlap_chars=20)
return final_text
@app.post("/api/convert")
async def convert_file(
request: Request,
file: UploadFile = File(...),
model: str = Form("gemini-2.5-flash"),
mode: str = Form("latex")
):
check_rate_limit(request)
if key_manager.get_key_count() == 0:
raise HTTPException(status_code=500, detail="Chưa cấu hình API Key")
prompt = DIRECT_GEMINI_PROMPT_LATEX if mode == "latex" else DIRECT_GEMINI_PROMPT_TEXT_ONLY
try:
file_content = await file.read()
file_ext = os.path.splitext(file.filename)[1].lower()
global_semaphore = asyncio.Semaphore(MAX_THREADS)
results = []
if file_ext == ".pdf":
doc = fitz.open(stream=file_content, filetype="pdf")
async def process_page_wrapper(page, idx):
pix = page.get_pixmap(dpi=300)
img = Image.open(io.BytesIO(pix.tobytes("png")))
text = await process_large_image(img, model, prompt, global_semaphore)
return idx, text
tasks = [process_page_wrapper(doc[i], i) for i in range(len(doc))]
page_results = await asyncio.gather(*tasks)
results = [text for _, text in sorted(page_results, key=lambda x: x[0])]
doc.close()
elif file_ext in [".png", ".jpg", ".jpeg", ".bmp"]:
img = Image.open(io.BytesIO(file_content))
text = await process_large_image(img, model, prompt, global_semaphore)
results.append(text)
else:
raise HTTPException(status_code=400, detail="Định dạng file không hỗ trợ")
final_text = "\n\n".join(results)
return {"success": True, "result": clean_latex_formulas(final_text)}
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/export-docx")
async def export_docx(markdown_text: str = Form(...)):
try:
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp_file:
output_filename = tmp_file.name
pypandoc.convert_text(
markdown_text,
to='docx',
format='markdown',
outputfile=output_filename,
extra_args=['--standalone']
)
return FileResponse(
output_filename,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
filename="HT_MATH_OUTPUT.docx"
)
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Lỗi xuất Word: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))