| import os |
| |
# Point the Hugging Face and PaddleOCR caches at directories under /tmp.
# This runs ahead of the third-party imports further down -- presumably so
# those libraries see the paths when they are imported (TODO confirm).
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/huggingface",
        "HF_HOME": "/tmp/huggingface",
        "PADDLEOCR_HOME": "/tmp/paddleocr_models",
    }
)
os.makedirs(os.environ["PADDLEOCR_HOME"], exist_ok=True)
|
|
| import uuid |
| import time |
| from PIL import Image, ImageDraw, ImageFont |
| from flask_cors import CORS |
| from paddleocr import PaddleOCR |
| from flask import jsonify, Flask, request |
| from transformers import MarianMTModel, MarianTokenizer |
| import base64 |
| from io import BytesIO |
| import logging |
| |
| from collections import defaultdict |
| import traceback |
|
|
| |
| |
|
|
# Flask application; flask_cors is applied with its default options.
app = Flask(__name__)
CORS(app)

# Root logger: INFO and above, formatted as "[timestamp] [LEVEL] message".
logging.basicConfig(
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

# In-process statistics, kept in plain module globals.
# Text translations, keyed by the "<src>-<tgt>" language pair.
request_counter = defaultdict(int)
# Number of calls made to translate_image().
translate_image_counter = 0

# Single OCR engine shared by all requests.  The three optional pre-processing
# stages are switched off and the model directories live under /tmp.
ocr = PaddleOCR(
    det_model_dir='/tmp/paddleocr_models/det',
    rec_model_dir='/tmp/paddleocr_models/rec',
    cls_model_dir='/tmp/paddleocr_models/cls',
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_textline_orientation=False,
)
|
|
def image_to_base64(img_path):
    """Return the image at *img_path* re-encoded as a base64 JPEG string.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        str: Base64 text of the image after conversion to RGB and JPEG
        encoding.
    """
    # Image.open keeps the underlying file open; the context manager closes
    # it as soon as the RGB copy exists instead of waiting for garbage
    # collection.
    with Image.open(img_path) as source:
        rgb_image = source.convert('RGB')

    buffered = BytesIO()
    rgb_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
|
| |
def load_chinese_font(size=24):
    """Return the SimFang TrueType font at *size*, downloading it on first use.

    The font file is cached at ``/tmp/font/simfang.ttf``.  If it is not there
    yet it is fetched from ``font_url``.  Any failure (network error, HTTP
    error status, unreadable font file) is reported on stdout and PIL's
    built-in default font is returned instead, so callers always get a usable
    font object.

    Args:
        size: Point size passed to ``ImageFont.truetype``.

    Returns:
        The loaded TrueType font, or ``ImageFont.load_default()`` on failure.
    """
    # Function-scope stdlib imports keep the file free of a new dependency.
    # The previous code called ``request.get`` -- Flask's request proxy, not
    # an HTTP client -- so the download could never succeed.
    import shutil
    import urllib.request

    font_url = "https://download944.mediafire.com/3mi7ach6g0lg9aqEAzgtEzOOfk2nsbzstr7nld3IHMhXHM8a9hutdmxLMFNaCZe9viFybL5OzR28Py4bS8ObvWID4TT57RWJygUJUJjzo4M9tVFLIORZZrSLLGuLQob6eWTI6DRwbQI-pJCepqQVihMDJYsjrtSHP0wV681303kROA/1ttntwtmzwwdaiv/simfang.ttf"
    font_dir = "/tmp/font"
    font_path = os.path.join(font_dir, "simfang.ttf")
    print(font_path)

    try:
        if not os.path.exists(font_path):
            os.makedirs(font_dir, exist_ok=True)
            print(f"正在从网络下载字体: {font_url}")
            # Download to a side file and rename afterwards so an interrupted
            # transfer never leaves a truncated file at the cached path.
            partial_path = font_path + ".part"
            with urllib.request.urlopen(font_url, timeout=30) as response:
                with open(partial_path, "wb") as f:
                    shutil.copyfileobj(response, f)
            os.replace(partial_path, font_path)
            print("✅ 字体下载完成")

        # Load the downloaded font at the requested size (previously the
        # default font was returned here and ``size`` was ignored).
        return ImageFont.truetype(font_path, size)
    except Exception as e:
        # Deliberate best-effort: rendering falls back to the default font.
        print(f"⚠️ 加载中文字体失败: {e}")
        return ImageFont.load_default()
|
|
|
|
| |
# Supported translation directions: "<src>-<tgt>" key -> model identifier
# handed to MarianTokenizer / MarianMTModel ``from_pretrained``.
MODEL_MAP = dict(
    [
        ("zh-en", "Helsinki-NLP/opus-mt-zh-en"),
        ("en-zh", "Helsinki-NLP/opus-mt-en-zh"),
        ("zh-de", "Helsinki-NLP/opus-mt-zh-de"),
        ("de-zh", "Helsinki-NLP/opus-mt-de-zh"),
        ("ja-en", "Helsinki-NLP/opus-mt-ja-en"),
        ("en-ja", "Helsinki-NLP/opus-mt-en-jap"),
    ]
)

# Caches of loaded objects, keyed by model identifier.
models = dict()
tokenizers = dict()
|
|
|
|
| |
def preload_models():
    """Eagerly load every translation model referenced by ``MODEL_MAP``.

    Each distinct model identifier is loaded once and stored, together with
    its tokenizer, in the module-level ``tokenizers`` / ``models`` caches.
    A model that fails to load is reported on stdout and skipped; the
    remaining ones are still attempted.
    """
    print("🚀 正在预加载所有翻译模型...")
    unique_ids = set(MODEL_MAP.values())
    for model_id in unique_ids:
        try:
            loaded_tokenizer = MarianTokenizer.from_pretrained(model_id)
            loaded_model = MarianMTModel.from_pretrained(model_id)
            # Cache only after both halves loaded successfully.
            tokenizers[model_id] = loaded_tokenizer
            models[model_id] = loaded_model
            print(f"✅ 模型已加载: {model_id}")
        except Exception as e:
            print(f"❌ 加载失败: {model_id} -> {e}")
|
|
|
|
def load_model(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, loading them if needed.

    When either half is missing from the module-level caches, a warning is
    logged, both are loaded with ``from_pretrained`` and then stored.

    Args:
        model_name: Model identifier (a value of ``MODEL_MAP``).

    Returns:
        tuple: ``(tokenizer, model)`` taken from the caches.
    """
    already_cached = model_name in models and model_name in tokenizers
    if not already_cached:
        logging.warning(f"⚠️ 模型 {model_name} 未预加载,正在动态加载!")
        fresh_tokenizer = MarianTokenizer.from_pretrained(model_name)
        fresh_model = MarianMTModel.from_pretrained(model_name)
        # Store both only after both loads succeeded.
        tokenizers[model_name], models[model_name] = fresh_tokenizer, fresh_model
    return tokenizers[model_name], models[model_name]
|
|
|
|
def translate_text(text: str, src_lang: str, tgt_lang: str):
    """Translate *text* and return the outcome as a Flask JSON response.

    Uses ``jsonify`` and ``request.remote_addr``, so it has to run while a
    Flask request is being handled.

    Args:
        text: Text to translate; tokenised with truncation at 512 tokens.
        src_lang: Source language code (first half of a ``MODEL_MAP`` key).
        tgt_lang: Target language code (second half of a ``MODEL_MAP`` key).

    Returns:
        On success, a JSON response with ``translated_text``,
        ``source_language`` and ``target_language``.  For an unsupported
        language pair, ``(response, 400)``.  If translation raises,
        ``(response, 500)`` carrying the error text.
    """
    start_time = time.time()

    lang_key = f"{src_lang}-{tgt_lang}"
    model_name = MODEL_MAP.get(lang_key)
    if not model_name:
        return jsonify({"error": f"Translation from {src_lang} to {tgt_lang} is not supported."}), 400

    # Count only supported pairs.  lang_key is built from client input, so
    # counting before validation let arbitrary requests add an unbounded
    # number of keys to request_counter (and to the stats endpoint).
    request_counter[lang_key] += 1

    try:
        tokenizer, model = load_model(model_name)
        inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=512)
        translated = model.generate(**inputs, max_length=512)
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

        elapsed = round((time.time() - start_time) * 1000, 2)
        client_ip = request.remote_addr or "unknown"

        logging.info(
            f"[{client_ip}] {src_lang}->{tgt_lang} | \"{text[:30]}\" -> \"{translated_text[:30]}\" | {elapsed}ms | Total: {request_counter[lang_key]}"
        )
    except Exception as e:
        # Log with traceback through the configured logger instead of a bare
        # print, then report the failure to the client.
        logging.exception("Text translation failed for %s", lang_key)
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "translated_text": translated_text,
        "source_language": src_lang,
        "target_language": tgt_lang
    })
|
|
|
|
def _iter_ocr_lines(page):
    """Yield ``(box, text)`` pairs from one page of OCR output."""
    if page is None:
        # Nothing was recognised on the page.
        return

    # NOTE(review): newer PaddleOCR releases appear to return a mapping with
    # parallel "rec_polys" / "rec_texts" lists -- confirm against the
    # installed version.
    if hasattr(page, "get") and page.get("rec_texts") is not None:
        polys = page.get("rec_polys")
        if polys is None:
            polys = page.get("dt_polys")
        if polys is not None:
            for box, text in zip(polys, page["rec_texts"]):
                yield box, text
        return

    # Legacy shape: a list of [box, (text, score)] entries.
    for box, (text, _score) in page:
        yield box, text


def translate_image(image_path: str, src_lang: str, tgt_lang: str):
    """OCR an image, translate each text line and draw the result onto it.

    Every recognised line is translated with the model for the requested
    language pair; its bounding polygon is outlined in red and the
    translation is written in blue just above the polygon's first point.
    A line that fails to translate or draw is logged and skipped.  Uses
    ``jsonify``, so it has to run while a Flask request is being handled.

    Args:
        image_path: Path of the image file to process.
        src_lang: Source language code (first half of a ``MODEL_MAP`` key).
        tgt_lang: Target language code (second half of a ``MODEL_MAP`` key).

    Returns:
        On success, a JSON response whose ``image_base64`` holds the
        annotated image as a base64 JPEG.  For an unsupported language pair,
        ``(response, 400)``.  If OCR or image handling raises,
        ``(response, 500)``.
    """
    global translate_image_counter
    translate_image_counter += 1
    request_id = uuid.uuid4().hex[:8]

    lang_key = f"{src_lang}-{tgt_lang}"
    print(lang_key)
    model_name = MODEL_MAP.get(lang_key)
    print(model_name)
    if not model_name:
        logging.warning(f"[{request_id}] ❌ 不支持的语言对: {lang_key}")
        return jsonify({"error": f"Unsupported language pair: {lang_key}"}), 400

    try:
        print('ocr 识别 start')
        result = ocr.ocr(image_path)

        # Close the file handle as soon as the RGB copy exists.
        with Image.open(image_path) as source:
            original_img = source.convert("RGB")
        draw = ImageDraw.Draw(original_img)

        print(f"[DEBUG] tgt_lang = '{tgt_lang}'")
        font = ImageFont.load_default()

        # An image without detectable text can produce an empty result or a
        # None first page; treat both as "no lines" instead of failing with
        # a 500 while iterating None.
        page = result[0] if result else None

        for box, orig_text in _iter_ocr_lines(page):
            logging.info(f"orig_text: {orig_text}")
            try:
                # load_model caches, so after the first line this is only a
                # dictionary lookup.
                tokenizer, model = load_model(model_name)
                inputs = tokenizer([orig_text], return_tensors="pt", padding=True, truncation=True, max_length=512)
                translated = model.generate(**inputs, max_length=512)
                translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
                logging.info(f"[{request_id}] 🔁 翻译: {translated_text}")

                # Normalise the points to plain float tuples, the documented
                # coordinate form for ImageDraw, whatever container the OCR
                # engine used for them.
                outline = [(float(x), float(y)) for x, y in box]
                draw.polygon(outline, outline="red", width=1)
                draw.text((outline[0][0], outline[0][1] - 15), translated_text, fill="blue", font=font)
            except Exception as e:
                # Best-effort per line: keep annotating the remaining lines.
                logging.error(f"[{request_id}] ❗ 翻译失败: {e}")

        buffered = BytesIO()
        original_img.save(buffered, format="JPEG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print(f"[{request_id}] 🖼️ 图像翻译完成,成功生成 base64 图像")
        return jsonify({
            "image_base64": img_base64
        })

    except Exception as e:
        logging.error(f"[{request_id}] ❗ 整体翻译失败: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": f"OCR error: {str(e)}"}), 500
|
|
|
|
| |
@app.route("/ai/api_translate_text", methods=["POST"])
def api_translate_text():
    """POST endpoint: translate the ``text`` field of a JSON body.

    Expects a JSON object with non-empty string fields ``text``,
    ``src_lang`` and ``tgt_lang``.

    Returns:
        Whatever ``translate_text`` returns, or ``(response, 400)`` when the
        body is not a JSON object or a field is missing, empty or not a
        string, or ``(response, 500)`` if ``translate_text`` raises.
    """
    # silent=True yields None for a missing, malformed or non-JSON body
    # instead of aborting, so the client always gets a JSON error from here.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Covers None as well as valid JSON that is not an object (list,
        # number, ...), on which .get() would raise AttributeError.
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    src_lang = data.get("src_lang", "")
    tgt_lang = data.get("tgt_lang", "")
    print(text,src_lang,tgt_lang)

    # Reject missing/empty values and non-string values (e.g. numbers), which
    # would otherwise only fail later inside the tokenizer.
    fields = (text, src_lang, tgt_lang)
    if not all(isinstance(value, str) and value for value in fields):
        return jsonify({"error": "text, src_lang, tgt_lang are required"}), 400

    try:
        results = translate_text(text, src_lang, tgt_lang)
        print(results)
        return results
    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
| |
| |
@app.route("/ai/api_translate_text_g", methods=["GET"])
def api_translate_text_g():
    """GET endpoint: translate the ``text`` query parameter.

    Reads ``text``, ``src_lang`` and ``tgt_lang`` from the query string.

    Returns:
        Whatever ``translate_text`` returns, ``(response, 400)`` when any of
        the three parameters is missing or empty, or ``(response, 500)`` if
        ``translate_text`` raises.
    """
    query = request.args
    text = query.get("text", "")
    src_lang = query.get("src_lang", "")
    tgt_lang = query.get("tgt_lang", "")
    print(text,src_lang,tgt_lang)

    # All three values must be non-empty.
    if not (text and src_lang and tgt_lang):
        return jsonify({"error": "text, src_lang, tgt_lang are required"}), 400

    try:
        results = translate_text(text, src_lang, tgt_lang)
        print(results)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    return results
| |
|
|
| |
| |
@app.route("/ai/api_translate_image", methods=["POST"])
def api_translate_image():
    """POST endpoint: translate the text found in an uploaded image.

    Expects multipart form data with an ``image`` file plus ``src_lang`` and
    ``tgt_lang`` form fields.  The upload is written to a uniquely named
    file under /tmp, passed to ``translate_image`` and removed afterwards.

    Returns:
        Whatever ``translate_image`` returns, or ``(response, 400)`` when
        the image or a language field is missing.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    uid = uuid.uuid4().hex
    logging.info(uid)
    filename = f"{uid}.jpg"
    filepath = os.path.join("/tmp", filename)
    logging.info(filepath)

    src_lang = request.form.get("src_lang")
    tgt_lang = request.form.get("tgt_lang")
    logging.info(tgt_lang)

    if not src_lang or not tgt_lang:
        return jsonify({"error": "src_lang, tgt_lang are required"}), 400

    try:
        file.save(filepath)
        return translate_image(filepath, src_lang, tgt_lang)
    finally:
        # Runs on every exit path, so a failed save or an unexpected error in
        # translate_image no longer leaves the upload behind in /tmp.
        if os.path.exists(filepath):
            os.remove(filepath)
|
|
|
|
| |
| |
@app.route("/ai/api_test_g", methods=["GET"])
def api_test_g():
    """GET endpoint that always answers with the fixed JSON ``{"test": "ok!"}``."""
    payload = {"test": "ok!"}
    return jsonify(payload)
|
|
|
|
| |
| |
@app.route("/ai/api_translate_stats", methods=["GET"])
def api_translate_stats():
    """GET endpoint: return ``request_counter`` as a JSON object.

    The counter is copied into a plain ``dict`` before serialisation.
    """
    snapshot = dict(request_counter)
    return jsonify(snapshot)
|
|
|
|
| |
| |
@app.route("/ai/api_translate_image_status", methods=["GET"])
def api_translate_image_status():
    """GET endpoint: report the current value of ``translate_image_counter``."""
    payload = {"image_translate_requests": translate_image_counter}
    return jsonify(payload)
|
|
|
|
| |
# Runs at module level, i.e. whenever this module is imported or executed,
# so the translation models are loaded before any request is served.
preload_models()


if __name__ == "__main__":
    # Direct execution: serve with Flask's built-in server on all interfaces.
    app.run(port=7860, host="0.0.0.0")
|
|
|
|
|
|