Spaces:
Running
Running
| import os | |
| import re | |
| import importlib | |
| from pathlib import Path | |
| import httpx | |
| import uvicorn | |
| from fastapi import HTTPException | |
| from pydantic import BaseModel | |
| from typing import Literal, Optional, List | |
| import filetype | |
| from starlette.staticfiles import StaticFiles | |
| # Import order matters. We rely on config.toml written by space_boot.sh | |
| from meme_generator.config import meme_config | |
| import meme_generator.utils as _utils | |
| from meme_generator import load_meme, load_memes | |
| from meme_generator.app import app, register_routers | |
| from meme_generator.manager import get_meme, get_memes | |
| from meme_generator.exception import NoSuchMeme, MemeFeedback | |
| from meme_generator.utils import MemeProperties, render_meme_list | |
# --- Bootstrap diagnostics: report where config.toml is expected and what
# translator settings are visible from config vs. environment. ---
cfg_home = os.environ.get('XDG_CONFIG_HOME')
cfg_dir = f"{cfg_home}/meme_generator" if cfg_home else '<unset>'
print(f"[bootstrap] XDG_CONFIG_HOME={cfg_home} CONFIG_DIR={cfg_dir}", flush=True)

# Snapshot translator settings: the config.toml section (if any) and the
# environment overrides that _openai_translate() will prefer at call time.
tc = getattr(meme_config, 'translate', None)
env_provider = os.getenv('TRANSLATOR_PROVIDER', '').strip().lower()
env_base_url = os.getenv('OPENAI_BASE_URL', '').strip()
env_api_key = os.getenv('OPENAI_API_KEY', '').strip()
env_model = os.getenv('OPENAI_MODEL', '').strip()

if tc:
    report = [
        f"[bootstrap] meme_config.translate provider={getattr(tc,'provider',None)} base_url={getattr(tc,'openai_base_url',None)} "
        f"model={getattr(tc,'openai_model',None)} api_key_present={bool(getattr(tc,'openai_api_key',None))}",
        f"[bootstrap] env override provider={env_provider or '<empty>'} base_url={env_base_url or '<empty>'} "
        f"model={env_model or '<empty>'} api_key_present={bool(env_api_key)}",
    ]
    print("\n".join(report), flush=True)
else:
    print(
        f"[bootstrap] meme_config.translate missing; env provider={env_provider or '<empty>'} base_url={env_base_url or '<empty>'} model={env_model or '<empty>'} api_key_present={bool(env_api_key)}",
        flush=True,
    )

# Keep a handle on the stock translate() so the patched version can fall back to it.
_orig_translate = _utils.translate
# Human-readable target-language names for the translation prompt; hoisted to
# module level so they are built once, not per call.
_LANG_NAMES = {
    "zh": "Chinese",
    "zh-cn": "Chinese",
    "zh-hans": "Chinese",
    "en": "English",
    "jp": "Japanese",
    "ja": "Japanese",
    "ko": "Korean",
    "fr": "French",
    "de": "German",
    "ru": "Russian",
    "es": "Spanish",
}
# Hiragana + katakana range, used to detect whether a "Japanese" result
# actually contains kana.
_KANA_RE = re.compile(r"[\u3040-\u30FF]")


def _chat_completion(url: str, headers: dict, model: str, system_prompt: str,
                     text: str, log_prefix: str = "response") -> str:
    """POST one chat-completions request; return the stripped reply ('' if empty).

    Raises httpx errors on transport/HTTP failures; the caller wraps them.
    """
    payload = {
        "model": model,
        "temperature": 0,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
    }
    r = httpx.post(url, headers=headers, json=payload, timeout=60)
    print(f"[translate/openai] {log_prefix} status={r.status_code}", flush=True)
    r.raise_for_status()
    choices = r.json().get("choices") or []
    content = choices[0].get("message", {}).get("content") if choices else None
    return str(content).strip() if content else ""


def _openai_translate(text: str, lang_from: str = "auto", lang_to: str = "zh") -> str:
    """Translate *text* via an OpenAI-compatible chat-completions endpoint.

    Settings are resolved env-first (TRANSLATOR_PROVIDER / OPENAI_BASE_URL /
    OPENAI_API_KEY / OPENAI_MODEL), then from meme_config.translate. When the
    resolved provider is not OpenAI, delegates to the original translate().

    Raises:
        MemeFeedback: when the OpenAI config is incomplete, the request fails,
            or the model returns an empty result.
    """
    tc = getattr(meme_config, "translate", None)
    # Prefer env, then config. The trailing `or ""` guards against config
    # fields that exist but are set to None, which previously crashed .strip().
    provider = (os.getenv("TRANSLATOR_PROVIDER", "") or getattr(tc, "provider", "") or "").strip().lower()
    base_url = (os.getenv("OPENAI_BASE_URL", "") or getattr(tc, "openai_base_url", "") or "").strip()
    api_key = (os.getenv("OPENAI_API_KEY", "") or getattr(tc, "openai_api_key", "") or "").strip()
    model = (os.getenv("OPENAI_MODEL", "") or getattr(tc, "openai_model", "") or "").strip()

    # An explicit provider wins; with no provider set, having an API key
    # implies OpenAI (supports upstream configs without openai fields).
    use_openai = provider == "openai" or (not provider and bool(api_key))
    if not use_openai:
        print(f"[translate] fallback to original provider (provider={provider})", flush=True)
        return _orig_translate(text, lang_from, lang_to)
    if not base_url or not api_key or not model:
        raise MemeFeedback("OpenAI 翻译未配置完整:请设置 openai_base_url / openai_api_key / openai_model")

    target_lang = _LANG_NAMES.get((lang_to or "").lower(), lang_to)
    url = base_url.rstrip("/") + "/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    # Build a stronger instruction, especially for Japanese output.
    extra_rules = ""
    if target_lang.lower().startswith("japan"):
        extra_rules = (
            " Use natural Japanese. Prefer including hiragana/katakana (kana) where appropriate; "
            "do not just output Chinese hanzi. If the input is Chinese, do not return the same text."
        )
    system_prompt = (
        "You are a professional translation engine."
        f" Translate the user text to {target_lang}."
        " Only output the translated text without any extra words, quotes, or explanations."
        " Preserve numbers, emoji, and links." + extra_rules
    )

    try:
        print(
            f"[translate/openai] request target={target_lang} model={model} base_url={base_url} text_len={len(text)}",
            flush=True,
        )
        result = _chat_completion(url, headers, model, system_prompt, text)
        if not result:
            raise MemeFeedback("OpenAI 翻译失败:空结果")
        # Japanese target but no kana detected: retry once with a stricter rule.
        if target_lang.lower().startswith("japan") and not _KANA_RE.search(result):
            print("[translate/openai] no kana detected, retrying with stricter Japanese rule", flush=True)
            strict_prompt = (
                system_prompt
                + " Ensure the output contains kana (hiragana or katakana) and is not identical to the input."
            )
            retry_result = _chat_completion(
                url, headers, model, strict_prompt, text, log_prefix="retry response"
            )
            if retry_result:
                print(f"[translate/openai] retry success result_len={len(retry_result)}", flush=True)
                return retry_result
            # Empty retry: deliberately fall through to the first result.
        print(f"[translate/openai] success result_len={len(result)}", flush=True)
        return result
    except MemeFeedback:
        # BUG FIX: the blanket handler below used to re-wrap our own feedback,
        # yielding doubled messages like "OpenAI 翻译失败: OpenAI 翻译失败:空结果".
        raise
    except Exception as e:
        raise MemeFeedback(f"OpenAI 翻译失败: {e}") from e
# Install the OpenAI-backed translate() before any meme is loaded, so memes
# that call translate at load time already see the patched function.
_utils.translate = _openai_translate
print("[bootstrap] translate() monkey-patched for OpenAI provider", flush=True)

# Builtin memes: prefer the source checkout (ensures its assets are present);
# otherwise fall back to the copy installed in site-packages.
src_memes_dir = Path("/app/meme-generator/meme_generator/memes")
if src_memes_dir.exists():
    print(f"[bootstrap] Loading builtin memes from source: {src_memes_dir}", flush=True)
    load_memes(str(src_memes_dir))
else:
    memes_dir = Path(importlib.import_module('meme_generator').__file__).parent / 'memes'
    print(f"[bootstrap] Loading builtin memes from package: {memes_dir}", flush=True)
    if memes_dir.exists():
        for entry in memes_dir.iterdir():
            if entry.is_dir():
                load_meme(f"meme_generator.memes.{entry.name}")
| # Optional override for /memes/render_list (fresh data each time) | |
class MemeKeyWithProperties(BaseModel):
    """One entry of a render_list request: a meme key plus its display properties."""

    meme_key: str  # key of a loaded meme; resolved via get_meme() in render_list
    disabled: bool = False  # forwarded to MemeProperties(disabled=...)
    labels: List[Literal["new", "hot"]] = []  # forwarded to MemeProperties(labels=...)
class RenderMemeListRequest(BaseModel):
    """Request body for the render_list override; all fields are optional."""

    meme_list: Optional[List[MemeKeyWithProperties]] = None  # None -> render every loaded meme
    text_template: str = "{keywords}"  # passed through to render_meme_list
    add_category_icon: bool = True  # passed through to render_meme_list
def render_list(params: RenderMemeListRequest = RenderMemeListRequest()):
    """Render the meme-list image from fresh in-memory data on every call.

    When params.meme_list is given, renders exactly those memes with their
    requested properties; otherwise renders all loaded memes sorted by key.

    Returns:
        A FastAPI Response carrying the rendered image bytes.

    Raises:
        HTTPException: when a requested meme key does not exist.
    """
    try:
        if params.meme_list:
            meme_list = [
                (get_meme(p.meme_key), MemeProperties(disabled=p.disabled, labels=p.labels))
                for p in params.meme_list
            ]
        else:
            meme_list = [
                (m, MemeProperties()) for m in sorted(get_memes(), key=lambda m: m.key)
            ]
    except NoSuchMeme as e:
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    result = render_meme_list(
        meme_list,
        text_template=params.text_template,
        add_category_icon=params.add_category_icon,
    )
    content = result.getvalue()
    # BUG FIX: the original `str(filetype.guess_mime(content)) or "text/plain"`
    # never fell back — str(None) is the truthy string "None" — so an
    # unrecognized payload got media type "None". Guard before stringifying.
    mime = filetype.guess_mime(content)
    media_type = str(mime) if mime else "text/plain"
    from fastapi import Response  # local import mirrors the original's style
    return Response(content=content, media_type=media_type)
| # Dynamic infos.json and keyMap.json derived from loaded memes (not aggregated assets) | |
| from collections import OrderedDict | |
| from meme_generator.app import MemeInfoResponse, MemeParamsResponse | |
def build_infos_and_keymap():
    """Build (infos, keymap) fresh from the currently loaded memes.

    Returns:
        infos: dict mapping meme key -> info dict (params, keywords, shortcuts,
            tags, creation/modification dates), sorted by meme key.
        keymap: OrderedDict mapping keyword -> meme key, inserted longest
            keyword first so greedy prefix matching prefers longer keywords;
            on duplicate keywords the first (longest) insertion wins.
    """
    infos = {}
    pairs = []  # (keyword, meme_key)
    for meme in sorted(get_memes(), key=lambda m: m.key):
        args_type_response = None
        if meme.params_type.args_type:
            args_model = meme.params_type.args_type.args_model
            args_type_response = {
                # pydantic v2 models expose model_json_schema(); anything else
                # degrades to an empty schema.
                "args_model": args_model.model_json_schema() if hasattr(args_model, "model_json_schema") else {},
                # Serialize pydantic example objects; pass plain values through.
                # (Simplified from a redundant getattr-with-default whose
                # fallback was dead code behind the same hasattr check.)
                "args_examples": [
                    x.model_dump() if hasattr(x, "model_dump") else x
                    for x in meme.params_type.args_type.args_examples
                ],
                "parser_options": meme.params_type.args_type.parser_options,
            }
        infos[meme.key] = {
            "key": meme.key,
            "params_type": {
                "min_images": meme.params_type.min_images,
                "max_images": meme.params_type.max_images,
                "min_texts": meme.params_type.min_texts,
                "max_texts": meme.params_type.max_texts,
                "default_texts": meme.params_type.default_texts,
                "args_type": args_type_response,
            },
            "keywords": meme.keywords,
            "shortcuts": meme.shortcuts,
            "tags": list(meme.tags),
            "date_created": meme.date_created,
            "date_modified": meme.date_modified,
        }
        for kw in meme.keywords:
            pairs.append((kw, meme.key))
    keymap = OrderedDict()
    # Longest keyword first; setdefault keeps the first writer on duplicates.
    for kw, key in sorted(pairs, key=lambda x: len(x[0]), reverse=True):
        keymap.setdefault(kw, key)
    return infos, keymap
def infos_json():
    """Return the freshly rebuilt meme-info mapping (key -> info dict)."""
    return build_infos_and_keymap()[0]
def keymap_json():
    """Return the freshly rebuilt keyword -> meme-key mapping."""
    return build_infos_and_keymap()[1]
# Load additional meme packs. The original loaded contrib/emoji unconditionally
# (crashing at boot if a directory was absent) while guarding NSFW/JJ; all four
# now use the same guarded, logged pattern for consistency and robustness.
for _pack_name, _pack_dir in [
    ("contrib", Path("/app/meme-generator-contrib/memes")),
    ("emoji", Path("/app/meme_emoji/emoji")),
    ("NSFW", Path("/app/meme_emoji_nsfw/emoji")),
    ("JJ", Path("/app/meme-generator-jj/memes")),
]:
    if _pack_dir.exists():
        load_memes(str(_pack_dir))
        print(f"[bootstrap] Loaded {_pack_name} meme pack", flush=True)
    else:
        print(f"[bootstrap] {_pack_name} meme pack not found; skipping", flush=True)
register_routers()

# Static mount for aggregated data, only when the directory actually exists:
# StaticFiles raises at construction for a missing directory, which would
# abort the whole bootstrap instead of just skipping the optional mount.
data_dir = os.environ.get("MEME_DATA_DIR", "/app/data")
if os.path.isdir(data_dir):
    app.mount("/memes/static", StaticFiles(directory=data_dir), name="static")
else:
    print(f"[bootstrap] data dir {data_dir} missing; /memes/static not mounted", flush=True)

port = int(os.environ.get("PORT", "7860"))
print(f"[bootstrap] Starting uvicorn on 0.0.0.0:{port}", flush=True)
uvicorn.run(app, host="0.0.0.0", port=port)