seawolf2357 commited on
Commit
75330da
·
verified ·
1 Parent(s): 712cd5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -23
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # モデルID
6
  model_id = "tencent/HY-MT1.5-1.8B"
7
 
8
  # 環境に合わせてデバイスと精度を自動選択
9
- # Freeスペース(CPU)の場合はfloat32、GPUがある場合はfloat16を使用
10
  if torch.cuda.is_available():
11
  device = "cuda"
12
  dtype = torch.float16
@@ -20,17 +20,41 @@ print(f"Loading model on {device} with {dtype}...")
20
  tokenizer = AutoTokenizer.from_pretrained(model_id)
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_id,
23
- device_map=device, # autoではなく明示的に指定
24
  torch_dtype=dtype
25
  )
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def translate_text(source_text, target_lang):
 
 
 
 
28
  # プロンプトの切り替えロジック
29
  if "Chinese" in target_lang or "中文" in target_lang:
30
  prompt = f"将以下文本翻译为{target_lang},注意只需要输出翻译后的结果,不要额外解释:\n{source_text}"
31
  else:
32
  prompt = f"Translate the following segment into {target_lang}, without additional explanation.\n{source_text}"
33
-
34
  messages = [{"role": "user", "content": prompt}]
35
 
36
  # 入力処理
@@ -40,7 +64,7 @@ def translate_text(source_text, target_lang):
40
  add_generation_prompt=False,
41
  return_tensors="pt"
42
  ).to(device)
43
-
44
  # 生成実行
45
  with torch.no_grad():
46
  generated_ids = model.generate(
@@ -50,7 +74,7 @@ def translate_text(source_text, target_lang):
50
  top_p=0.6,
51
  repetition_penalty=1.05
52
  )
53
-
54
  # 出力処理
55
  input_length = text_input.shape[1]
56
  response = generated_ids[0][input_length:]
@@ -58,27 +82,137 @@ def translate_text(source_text, target_lang):
58
 
59
  return decoded_output
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # UIの構築
62
- langs = ["Japanese", "English", "Chinese", "Korean", "French", "German", "Spanish"]
63
 
64
- with gr.Blocks() as demo:
65
- gr.Markdown("# 🚀 HY-MT1.5-1.8B Translator (Spaces)")
66
- gr.Markdown("Tencent1.8Bモデルを使用した翻訳デモです。")
67
-
68
- with gr.Row():
69
- with gr.Column():
70
- input_text = gr.Textbox(label="原文 (Source Text)", lines=5, placeholder="ここに入力...")
71
- target_lang = gr.Dropdown(choices=langs, value="English", label="翻���先 (Target Language)")
72
- submit_btn = gr.Button("翻訳 (Translate)", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- with gr.Column():
75
- output_text = gr.Textbox(label="結果 (Result)", lines=5, interactive=False)
76
-
77
- submit_btn.click(
78
- fn=translate_text,
79
- inputs=[input_text, target_lang],
80
- outputs=output_text
81
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # 起動
84
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import fitz # PyMuPDF
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # モデルID
7
  model_id = "tencent/HY-MT1.5-1.8B"
8
 
9
  # 環境に合わせてデバイスと精度を自動選択
 
10
  if torch.cuda.is_available():
11
  device = "cuda"
12
  dtype = torch.float16
 
20
  tokenizer = AutoTokenizer.from_pretrained(model_id)
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_id,
23
+ device_map=device,
24
  torch_dtype=dtype
25
  )
26
 
27
+ def extract_text_from_pdf(pdf_file):
28
+ """PDF에서 텍스트 추출"""
29
+ if pdf_file is None:
30
+ return ""
31
+
32
+ try:
33
+ doc = fitz.open(pdf_file.name)
34
+ full_text = ""
35
+
36
+ for page_num, page in enumerate(doc, 1):
37
+ text = page.get_text("text")
38
+ if text.strip():
39
+ full_text += f"\n--- Page {page_num} ---\n{text.strip()}\n"
40
+
41
+ doc.close()
42
+ return full_text.strip()
43
+
44
+ except Exception as e:
45
+ return f"❌ PDF 추출 오류: {str(e)}"
46
+
47
  def translate_text(source_text, target_lang):
48
+ """텍스트 번역"""
49
+ if not source_text or not source_text.strip():
50
+ return "입력 텍스트가 없습니다."
51
+
52
  # プロンプトの切り替えロジック
53
  if "Chinese" in target_lang or "中文" in target_lang:
54
  prompt = f"将以下文本翻译为{target_lang},注意只需要输出翻译后的结果,不要额外解释:\n{source_text}"
55
  else:
56
  prompt = f"Translate the following segment into {target_lang}, without additional explanation.\n{source_text}"
57
+
58
  messages = [{"role": "user", "content": prompt}]
59
 
60
  # 入力処理
 
64
  add_generation_prompt=False,
65
  return_tensors="pt"
66
  ).to(device)
67
+
68
  # 生成実行
69
  with torch.no_grad():
70
  generated_ids = model.generate(
 
74
  top_p=0.6,
75
  repetition_penalty=1.05
76
  )
77
+
78
  # 出力処理
79
  input_length = text_input.shape[1]
80
  response = generated_ids[0][input_length:]
 
82
 
83
  return decoded_output
84
 
85
+ def translate_long_text(source_text, target_lang, chunk_size=1500):
86
+ """긴 텍스트를 청크로 나눠서 번역"""
87
+ if not source_text or not source_text.strip():
88
+ return "입력 텍스트가 없습니다."
89
+
90
+ # 짧은 텍스트는 바로 번역
91
+ if len(source_text) <= chunk_size:
92
+ return translate_text(source_text, target_lang)
93
+
94
+ # 긴 텍스트는 문단 단위로 분할
95
+ paragraphs = source_text.split('\n\n')
96
+ chunks = []
97
+ current_chunk = ""
98
+
99
+ for para in paragraphs:
100
+ if len(current_chunk) + len(para) < chunk_size:
101
+ current_chunk += para + "\n\n"
102
+ else:
103
+ if current_chunk:
104
+ chunks.append(current_chunk.strip())
105
+ current_chunk = para + "\n\n"
106
+
107
+ if current_chunk:
108
+ chunks.append(current_chunk.strip())
109
+
110
+ # 각 청크 번역
111
+ translated_chunks = []
112
+ for i, chunk in enumerate(chunks):
113
+ print(f"Translating chunk {i+1}/{len(chunks)}...")
114
+ translated = translate_text(chunk, target_lang)
115
+ translated_chunks.append(translated)
116
+
117
+ return "\n\n".join(translated_chunks)
118
+
119
+ def process_pdf_and_translate(pdf_file, target_lang):
120
+ """PDF 업로드 → 텍스트 추출 → 번역"""
121
+ if pdf_file is None:
122
+ return "", "PDF 파일을 업로드해주세요."
123
+
124
+ # 텍스트 추출
125
+ extracted_text = extract_text_from_pdf(pdf_file)
126
+
127
+ if extracted_text.startswith("❌"):
128
+ return "", extracted_text
129
+
130
+ if not extracted_text.strip():
131
+ return "", "PDF에서 텍스트를 추출할 수 없습니다."
132
+
133
+ # 번역
134
+ translated_text = translate_long_text(extracted_text, target_lang)
135
+
136
+ return extracted_text, translated_text
137
+
138
+ def translate_input_text(source_text, target_lang):
139
+ """입력 텍스트 번역"""
140
+ return translate_long_text(source_text, target_lang)
141
+
142
  # UIの構築
143
+ langs = ["Japanese", "English", "Chinese", "Korean", "French", "German", "Spanish", "한국어", "日本語", "中文"]
144
 
145
+ with gr.Blocks(title="HY-MT1.5 Translator") as demo:
146
+ gr.Markdown("# 🚀 HY-MT1.5-1.8B Translator")
147
+ gr.Markdown("Tencent1.8B 번역 모델을 사용한 텍스트/PDF 번역 데모입니다.")
148
+
149
+ with gr.Tabs():
150
+ # ============ Tab 1: 텍스트 번역 ============
151
+ with gr.TabItem("📝 Text Translation"):
152
+ with gr.Row():
153
+ with gr.Column():
154
+ input_text = gr.Textbox(
155
+ label="원문 (Source Text)",
156
+ lines=10,
157
+ placeholder="번역할 텍스트를 입력하세요..."
158
+ )
159
+ target_lang_text = gr.Dropdown(
160
+ choices=langs,
161
+ value="English",
162
+ label="번역 언어 (Target Language)"
163
+ )
164
+ translate_btn = gr.Button("🔄 번역 (Translate)", variant="primary")
165
+
166
+ with gr.Column():
167
+ output_text = gr.Textbox(
168
+ label="번역 결과 (Result)",
169
+ lines=10,
170
+ interactive=False
171
+ )
172
+
173
+ translate_btn.click(
174
+ fn=translate_input_text,
175
+ inputs=[input_text, target_lang_text],
176
+ outputs=output_text
177
+ )
178
 
179
+ # ============ Tab 2: PDF 번역 ============
180
+ with gr.TabItem("📄 PDF Translation"):
181
+ gr.Markdown("### PDF 파일을 업로드하면 텍스트를 추출하고 번역합니다.")
182
+
183
+ with gr.Row():
184
+ with gr.Column():
185
+ pdf_input = gr.File(
186
+ label="📄 PDF 파일 업로드",
187
+ file_types=[".pdf"]
188
+ )
189
+ target_lang_pdf = gr.Dropdown(
190
+ choices=langs,
191
+ value="English",
192
+ label="번역 언어 (Target Language)"
193
+ )
194
+ translate_pdf_btn = gr.Button("🔄 PDF 번역", variant="primary")
195
+
196
+ with gr.Row():
197
+ with gr.Column():
198
+ extracted_text = gr.Textbox(
199
+ label="📋 추출된 원문 (Extracted Text)",
200
+ lines=15,
201
+ interactive=False
202
+ )
203
+
204
+ with gr.Column():
205
+ translated_pdf_text = gr.Textbox(
206
+ label="📋 번역 결과 (Translated Text)",
207
+ lines=15,
208
+ interactive=False
209
+ )
210
+
211
+ translate_pdf_btn.click(
212
+ fn=process_pdf_and_translate,
213
+ inputs=[pdf_input, target_lang_pdf],
214
+ outputs=[extracted_text, translated_pdf_text]
215
+ )
216
 
217
  # 起動
218
  demo.launch()