# translate / app.py
# (Hugging Face Spaces page header residue: "mistpe's picture - Update app.py - 8949f84 verified")
# main.py
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Body
from fastapi.staticfiles import StaticFiles
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import asyncio
import aiohttp
import json
import webbrowser
from typing import List, Dict, Optional
import os
from pathlib import Path
import pypdf
from docx import Document
from docx.shared import Inches, Pt
import markdown
import base64
from pydantic import BaseModel
import threading
from typing import List, Optional
import time
import hashlib
import re
from io import BytesIO
# Create the directories the app relies on (idempotent across restarts).
os.makedirs("static", exist_ok=True)
os.makedirs("temp", exist_ok=True)
os.makedirs("translation_memory", exist_ok=True)
app = FastAPI()
# Configure CORS: wide-open policy (any origin/method/header).
# NOTE(review): fine for a local tool; tighten allow_origins before exposing publicly.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount the static assets directory (serves the front-end under /static).
app.mount("/static", StaticFiles(directory="static"), name="static")
class DocumentSegment(BaseModel):
    """One translatable unit of a document plus its translation state.

    Carries the source text, formatting metadata captured at extraction
    time, and the (possibly empty) machine translation with review
    bookkeeping. Pydantic deep-copies the mutable defaults per instance.
    """

    text: str                         # source-language text of the segment
    type: str                         # 'paragraph' | 'heading' | 'markdown'
    format: Dict                      # formatting metadata (font, style, page, ...)
    position: Dict                    # location info (page/index) for re-assembly
    translated: str = ""              # target-language text, empty until translated
    alternatives: List[str] = []      # alternative translations offered by the API
    terminology: Dict[str, str] = {}  # term substitutions applied to this segment
    confidence: float = 0.0           # heuristic confidence of the translation
    review_status: str = "pending"    # 'pending' | 'machine_translated' | 'from_memory'
class TranslationRequest(BaseModel):
    """Payload for the /translate_text endpoint: one snippet plus options."""

    text: str                     # text to translate
    source_lang: str              # source language code (e.g. "AUTO", "EN")
    target_lang: str              # target language code (e.g. "ZH")
    use_memory: bool = True       # consult the translation-memory cache
    use_terminology: bool = True  # apply glossary substitutions
class ExportRequest(BaseModel):
    """Payload for the /export endpoint."""

    segments: List[DocumentSegment]  # translated segments to render
    format: str                      # 'auto' | 'txt' | 'docx' | 'md' | 'html'
    mode: str                        # 'translated' or side-by-side mode
    source_file_type: str            # original upload's extension (drives 'auto')
class TranslationMemory:
    """Persistent exact-match translation cache backed by a JSON file.

    Cache keys combine the language pair with an MD5 digest of the source
    text, so lookups only hit for byte-identical source segments.
    """

    def __init__(self):
        self.memory_file = "translation_memory/memory.json"
        self.load_memory()

    @staticmethod
    def _key(text: str, source_lang: str, target_lang: str) -> str:
        # The digest keeps keys short and safe for arbitrary text.
        digest = hashlib.md5(text.encode()).hexdigest()
        return f"{source_lang}_{target_lang}_{digest}"

    def load_memory(self):
        """Load the cache from disk, creating an empty file if absent."""
        if not os.path.exists(self.memory_file):
            self.memory = {}
            self.save_memory()
            return
        with open(self.memory_file, 'r', encoding='utf-8') as f:
            self.memory = json.load(f)

    def save_memory(self):
        """Write the whole cache back to disk (called after every insert)."""
        with open(self.memory_file, 'w', encoding='utf-8') as f:
            json.dump(self.memory, f, ensure_ascii=False, indent=2)

    def get_translation(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
        """Return the cached translation for *text*, or None on a miss."""
        entry = self.memory.get(self._key(text, source_lang, target_lang))
        return entry.get('translation') if entry else None

    def add_translation(self, text: str, translation: str, source_lang: str, target_lang: str):
        """Insert or overwrite one cache entry and persist immediately."""
        self.memory[self._key(text, source_lang, target_lang)] = {
            'text': text,
            'translation': translation,
            'timestamp': time.time(),
        }
        self.save_memory()
class TerminologyManager:
    """Persistent per-language-pair glossary stored as a JSON file.

    The store maps "SRC_TGT" language-pair keys to
    {source_term: target_term} dictionaries.
    """

    def __init__(self):
        self.terminology_file = "translation_memory/terminology.json"
        self.load_terminology()

    def load_terminology(self):
        """Load the glossary from disk, creating an empty file if absent."""
        if not os.path.exists(self.terminology_file):
            self.terminology = {}
            self.save_terminology()
            return
        with open(self.terminology_file, 'r', encoding='utf-8') as f:
            self.terminology = json.load(f)

    def save_terminology(self):
        """Persist the whole glossary (called after every insertion)."""
        with open(self.terminology_file, 'w', encoding='utf-8') as f:
            json.dump(self.terminology, f, ensure_ascii=False, indent=2)

    def get_terminology(self, source_lang: str, target_lang: str) -> Dict[str, str]:
        """Return the {source_term: target_term} map for a language pair."""
        return self.terminology.get(f"{source_lang}_{target_lang}", {})

    def add_term(self, source_term: str, target_term: str, source_lang: str, target_lang: str):
        """Add or overwrite one glossary entry and persist immediately."""
        pair_key = f"{source_lang}_{target_lang}"
        self.terminology.setdefault(pair_key, {})[source_term] = target_term
        self.save_terminology()
class DocumentProcessor:
    """Splits an uploaded document (txt / pdf / docx / md) into segments."""

    @staticmethod
    async def extract_text(file: UploadFile) -> List[DocumentSegment]:
        """Read *file* and return its content as a list of DocumentSegment.

        Dispatches on the filename extension; unknown extensions yield an
        empty list. Parsing happens entirely in memory (pypdf and python-docx
        both accept binary streams), which fixes the original's temp-file
        leak: a parse exception used to leave temp/<name> on disk.
        """
        content = await file.read()
        file_ext = file.filename.split('.')[-1].lower()
        if file_ext == 'txt':
            return DocumentProcessor._from_txt(content)
        if file_ext == 'pdf':
            return DocumentProcessor._from_pdf(content)
        if file_ext == 'docx':
            return DocumentProcessor._from_docx(content)
        if file_ext == 'md':
            return DocumentProcessor._from_md(content)
        return []

    @staticmethod
    def _from_txt(content: bytes) -> List[DocumentSegment]:
        """Split plain text on blank lines (both \\n and \\r\\n conventions)."""
        text = content.decode('utf-8')
        segments = []
        for i, para in enumerate(re.split(r'\n\s*\n|\r\n\s*\r\n', text)):
            if para.strip():
                segments.append(DocumentSegment(
                    text=para.strip(),
                    type='paragraph',
                    format={'font': 'default', 'style': 'normal'},
                    position={'index': i},
                ))
        return segments

    @staticmethod
    def _from_pdf(content: bytes) -> List[DocumentSegment]:
        """Extract paragraphs page by page from a PDF."""
        reader = pypdf.PdfReader(BytesIO(content))
        segments = []
        for page_no, page in enumerate(reader.pages):
            paragraphs = page.extract_text().split('\n\n')
            for j, para in enumerate(paragraphs):
                if para.strip():
                    segments.append(DocumentSegment(
                        text=para.strip(),
                        type='paragraph',
                        # pypdf's plain text extraction exposes no per-run
                        # font data (the original tracked variables that were
                        # never assigned), so record defaults plus the
                        # 1-based page number.
                        format={'font': 'default', 'size': 12, 'page': page_no + 1},
                        position={'page': page_no, 'index': j},
                    ))
        return segments

    @staticmethod
    def _from_docx(content: bytes) -> List[DocumentSegment]:
        """Extract non-empty paragraphs with style metadata from a .docx."""
        doc = Document(BytesIO(content))
        segments = []
        for i, para in enumerate(doc.paragraphs):
            if not para.text.strip():
                continue
            format_info = {
                'style': para.style.name,
                'alignment': str(para.alignment),
                'font': para.style.font.name if para.style.font else 'default',
                'size': para.style.font.size if para.style.font else 12,
                'bold': any(run.bold for run in para.runs),
                'italic': any(run.italic for run in para.runs),
            }
            segments.append(DocumentSegment(
                text=para.text.strip(),
                type='heading' if para.style.name.startswith('Heading') else 'paragraph',
                format=format_info,
                position={'index': i},
            ))
        return segments

    @staticmethod
    def _from_md(content: bytes) -> List[DocumentSegment]:
        """Group markdown lines into blank-line-separated segments."""
        segments: List[DocumentSegment] = []
        current: List[str] = []

        def flush():
            # Emit the accumulated lines as one markdown segment.
            # (The original dropped the headings/lists/code flags on the
            # trailing segment; every segment now gets the same metadata.)
            if not current:
                return
            segment_text = '\n'.join(current)
            segments.append(DocumentSegment(
                text=segment_text,
                type='markdown',
                format={
                    'type': 'markdown',
                    'headings': bool(re.match(r'^#+\s', segment_text)),
                    'lists': bool(re.match(r'^[-*+]\s', segment_text)),
                    'code': bool(re.match(r'^```', segment_text)),
                },
                position={'index': len(segments)},
            ))
            current.clear()

        for line in content.decode('utf-8').split('\n'):
            if line.strip():
                current.append(line)
            else:
                flush()
        flush()  # don't lose a trailing segment with no final blank line
        return segments
class DocumentExporter:
    """Renders translated segments to txt / docx / md / html.

    *segments* are plain dicts (DocumentSegment.dict()); *mode* is
    'translated' for target-text-only output, anything else produces a
    side-by-side source/target layout.
    """

    def __init__(self, segments, source_file_type):
        self.segments = segments
        self.source_file_type = source_file_type

    def export_txt(self, mode='translated'):
        """Plain-text export; untranslated segments fall back to source text."""
        if mode == 'translated':
            content = '\n\n'.join(seg['translated'] or seg['text'] for seg in self.segments)
        else:  # side-by-side mode
            content = ''
            for seg in self.segments:
                content += f"原文:{seg['text']}\n"
                content += f"译文:{seg['translated']}\n"
                content += f"{'=' * 50}\n\n"
        return content.encode('utf-8')

    def export_docx(self, mode='translated'):
        """Build a .docx in memory and return its bytes (A4 landscape page)."""
        doc = Document()
        section = doc.sections[0]
        section.page_width = Inches(11.69)   # A4 width (landscape)
        section.page_height = Inches(8.27)   # A4 height (landscape)
        if mode == 'translated':
            for seg in self.segments:
                p = doc.add_paragraph()
                if seg['format'].get('style'):
                    try:
                        p.style = seg['format']['style']
                    except Exception:
                        # A style name copied from the source document may
                        # not exist in a fresh one; keep the default style.
                        # (Was a bare `except:` — now scoped to Exception.)
                        pass
                p.add_run(seg['translated'] or seg['text'])
        else:  # side-by-side mode: two-column table, source | target
            table = doc.add_table(rows=1, cols=2)
            table.style = 'Table Grid'
            header_cells = table.rows[0].cells
            header_cells[0].text = '原文'
            header_cells[1].text = '译文'
            for seg in self.segments:
                row_cells = table.add_row().cells
                row_cells[0].text = seg['text']
                row_cells[1].text = seg['translated'] or ''
                if seg['format'].get('style'):
                    try:
                        for cell in row_cells:
                            cell.paragraphs[0].style = seg['format']['style']
                    except Exception:
                        pass
        # Serialize to an in-memory buffer.
        buf = BytesIO()
        doc.save(buf)
        return buf.getvalue()

    def export_markdown(self, mode='translated'):
        """Markdown export; side-by-side mode uses ###-headed pairs."""
        if mode == 'translated':
            # The original branched on format['type'] == 'markdown' but both
            # branches were identical — collapsed to one expression.
            parts = [seg['translated'] or seg['text'] for seg in self.segments]
            return '\n\n'.join(parts).encode('utf-8')
        parts = []
        for seg in self.segments:
            parts.append('### 原文\n')
            parts.append(seg['text'])
            parts.append('\n### 译文\n')
            parts.append(seg['translated'] or '')
            parts.append('\n---\n')
        return '\n'.join(parts).encode('utf-8')

    def export_html(self, mode='translated'):
        """Standalone HTML export with inline CSS.

        NOTE(review): segment text is interpolated into the HTML unescaped;
        if documents may contain '<' or '&' this should go through
        html.escape() — left as-is to preserve current output.
        """
        css = """
<style>
.translation-wrapper { max-width: 1200px; margin: 0 auto; padding: 20px; }
.segment { margin-bottom: 20px; }
.parallel { display: flex; gap: 20px; }
.source, .target { flex: 1; padding: 10px; background: #f9f9f9; border-radius: 4px; }
h3 { color: #666; font-size: 0.9em; margin-bottom: 5px; }
</style>
"""
        content = [
            '<!DOCTYPE html><html><head><meta charset="UTF-8">',
            css,
            '</head><body><div class="translation-wrapper">',
        ]
        if mode == 'translated':
            for seg in self.segments:
                content.append(f'<div class="segment">{seg["translated"] or seg["text"]}</div>')
        else:  # side-by-side mode
            for seg in self.segments:
                content.append('<div class="segment parallel">')
                content.append(f'<div class="source"><h3>原文</h3>{seg["text"]}</div>')
                content.append(
                    f'<div class="target"><h3>译文</h3>{seg["translated"] or ""}</div>'
                )
                content.append('</div>')
        content.append('</div></body></html>')
        return '\n'.join(content).encode('utf-8')

    def export(self, format='auto', mode='translated'):
        """Dispatch to the right renderer.

        Returns {'content': bytes, 'mimetype': str, 'extension': str}.
        'auto' falls back to the original file type (or 'txt'). Raises
        ValueError for unsupported formats.
        """
        if format == 'auto':
            format = self.source_file_type or 'txt'
        handlers = {
            'txt': (self.export_txt, 'text/plain', 'txt'),
            'docx': (self.export_docx,
                     'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
                     'docx'),
            'md': (self.export_markdown, 'text/markdown', 'md'),
            'html': (self.export_html, 'text/html', 'html'),
        }
        if format not in handlers:
            raise ValueError(f'Unsupported format: {format}')
        render, mimetype, extension = handlers[format]
        return {'content': render(mode), 'mimetype': mimetype, 'extension': extension}
class TranslationManager:
    """Coordinates translation memory, terminology and the DeepLX-style API."""

    def __init__(self):
        self.memory = TranslationMemory()
        self.terminology = TerminologyManager()

    async def translate_text(self, text: str, source_lang: str = "AUTO", target_lang: str = "ZH") -> Dict:
        """Translate *text*, consulting the memory cache and glossary first.

        Returns a dict with 'translated', 'alternatives', 'from_memory' and
        (for fresh API results) 'confidence'. Raises HTTPException(500) on
        API failure or missing DEEPL_API_URL configuration.

        (A ~30-line commented-out duplicate of the API call was removed.)
        """
        # 1) Exact-match cache hit: no API call needed.
        cached = self.memory.get_translation(text, source_lang, target_lang)
        if cached:
            return {'translated': cached, 'alternatives': [], 'from_memory': True}

        # 2) Protect known terms with placeholders so the API leaves them
        #    alone; they are substituted back after translation.
        #    NOTE(review): the API may still reflow or alter "__TERM_n__"
        #    placeholders — verify against the actual backend.
        terms = self.terminology.get_terminology(source_lang, target_lang)
        text_to_translate = text
        replacements = {}
        for source_term, target_term in terms.items():
            if source_term in text_to_translate:
                placeholder = f"__TERM_{len(replacements)}__"
                replacements[placeholder] = target_term
                text_to_translate = text_to_translate.replace(source_term, placeholder)

        # 3) Call the DeepLX-compatible endpoint configured via env var.
        async with aiohttp.ClientSession() as session:
            try:
                deepl_api_url = os.environ.get('DEEPL_API_URL')
                if not deepl_api_url:
                    raise ValueError("DEEPL_API_URL environment variable is not set.")
                async with session.post(
                    deepl_api_url,
                    json={
                        "text": text_to_translate,
                        "source_lang": source_lang,
                        "target_lang": target_lang,
                    },
                ) as response:
                    result = await response.json()
                    if result.get('code') != 200:
                        raise HTTPException(status_code=500, detail="Translation API error")
                    translated_text = result['data']
                    # Restore the protected terminology.
                    for placeholder, term in replacements.items():
                        translated_text = translated_text.replace(placeholder, term)
                    # Cache for future exact-match hits.
                    self.memory.add_translation(text, translated_text, source_lang, target_lang)
                    return {
                        'translated': translated_text,
                        'alternatives': result.get('alternatives', []),
                        'from_memory': False,
                        # Slightly higher confidence when glossary terms were pinned.
                        'confidence': 0.8 if replacements else 0.7,
                    }
            except HTTPException:
                # Bug fix: the broad handler below used to re-wrap our own
                # HTTPException into a generic 500 with a mangled message.
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))

    async def translate_segments(self, segments: List[DocumentSegment], source_lang: str, target_lang: str) -> List[DocumentSegment]:
        """Fill in the 'translated' field of every untranslated segment.

        Already-translated segments are skipped. Bug fix: the 1-second
        rate-limit pause now applies only after real API calls — the
        original also slept after instant memory-cache hits.
        """
        for segment in segments:
            if not segment.translated:
                result = await self.translate_text(segment.text, source_lang, target_lang)
                segment.translated = result['translated']
                segment.alternatives = result['alternatives']
                segment.confidence = result.get('confidence', 0.7)
                segment.review_status = 'from_memory' if result.get('from_memory') else 'machine_translated'
                if not result.get('from_memory'):
                    await asyncio.sleep(1)  # throttle API requests
        return segments
# Global translation-manager singleton shared by all endpoints.
translation_manager = TranslationManager()
@app.post("/upload")
async def upload_file(
    file: UploadFile = File(...),
    source_lang: str = Form("AUTO"),
    target_lang: str = Form("ZH")
):
    """Parse an uploaded document into segments the client can translate.

    The language fields are accepted for form compatibility but are not
    used during extraction.
    """
    segments = await DocumentProcessor().extract_text(file)
    # The original extension drives the default export format later on.
    file_type = file.filename.split('.')[-1].lower()
    return {
        "segments": [seg.dict() for seg in segments],
        "source_file_type": file_type,
    }
@app.post("/translate")
async def translate(
    segments: List[DocumentSegment],
    source_lang: str = Body("AUTO"),
    target_lang: str = Body("ZH")
):
    """Machine-translate every pending segment and return the updated list."""
    done = await translation_manager.translate_segments(segments, source_lang, target_lang)
    return {"segments": [seg.dict() for seg in done]}
@app.post("/translate_text")
async def translate_text(request: TranslationRequest):
    """Translate a single free-form text snippet."""
    return await translation_manager.translate_text(
        request.text, request.source_lang, request.target_lang
    )
@app.post("/add_term")
async def add_term(
    source_term: str = Form(...),
    target_term: str = Form(...),
    source_lang: str = Form(...),
    target_lang: str = Form(...)
):
    """Add one glossary entry for the given language pair."""
    glossary = translation_manager.terminology
    glossary.add_term(source_term, target_term, source_lang, target_lang)
    return {"status": "success"}
@app.get("/get_terminology")
async def get_terminology(source_lang: str, target_lang: str):
    """Return the stored glossary for a language pair."""
    return {
        "terminology": translation_manager.terminology.get_terminology(source_lang, target_lang)
    }
@app.post("/export")
async def export_document(request: ExportRequest):
    """Render segments to the requested format and stream it as a download.

    An unsupported format now yields 400 (client error) instead of being
    swallowed into a generic 500 by the broad handler.
    """
    try:
        exporter = DocumentExporter(
            [seg.dict() for seg in request.segments],
            request.source_file_type,
        )
        result = exporter.export(request.format, request.mode)
    except ValueError as e:
        # Unsupported export format is the caller's mistake, not ours.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return StreamingResponse(
        BytesIO(result['content']),
        media_type=result['mimetype'],
        headers={
            'Content-Disposition': f'attachment; filename=translated_document.{result["extension"]}'
        },
    )
@app.get("/")
async def read_root():
    """Serve the single-page front-end."""
    return FileResponse('static/index.html')
def open_browser():
    """Open the app's local URL in the system's default browser."""
    webbrowser.open('http://localhost:7860')
if __name__ == "__main__":
    # Pop the browser shortly after the server starts listening.
    threading.Timer(1.5, open_browser).start()
    # Serve on all interfaces, port 7860 (the Hugging Face Spaces default).
    uvicorn.run(app, host="0.0.0.0", port=7860)