perceptrontechnology commited on
Commit
712cd5b
·
verified ·
1 Parent(s): f0d4638

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ # モデルID
6
+ model_id = "tencent/HY-MT1.5-1.8B"
7
+
8
+ # 環境に合わせてデバイスと精度を自動選択
9
+ # Freeスペース(CPU)の場合はfloat32、GPUがある場合はfloat16を使用
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ dtype = torch.float16
13
+ else:
14
+ device = "cpu"
15
+ dtype = torch.float32
16
+
17
+ print(f"Loading model on {device} with {dtype}...")
18
+
19
+ # トークナイザーとモデルの読み込み
20
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ model_id,
23
+ device_map=device, # autoではなく明示的に指定
24
+ torch_dtype=dtype
25
+ )
26
+
27
+ def translate_text(source_text, target_lang):
28
+ # プロンプトの切り替えロジック
29
+ if "Chinese" in target_lang or "中文" in target_lang:
30
+ prompt = f"将以下文本翻译为{target_lang},注意只需要输出翻译后的结果,不要额外解释:\n{source_text}"
31
+ else:
32
+ prompt = f"Translate the following segment into {target_lang}, without additional explanation.\n{source_text}"
33
+
34
+ messages = [{"role": "user", "content": prompt}]
35
+
36
+ # 入力処理
37
+ text_input = tokenizer.apply_chat_template(
38
+ messages,
39
+ tokenize=True,
40
+ add_generation_prompt=False,
41
+ return_tensors="pt"
42
+ ).to(device)
43
+
44
+ # 生成実行
45
+ with torch.no_grad():
46
+ generated_ids = model.generate(
47
+ text_input,
48
+ max_new_tokens=1024,
49
+ temperature=0.7,
50
+ top_p=0.6,
51
+ repetition_penalty=1.05
52
+ )
53
+
54
+ # 出力処理
55
+ input_length = text_input.shape[1]
56
+ response = generated_ids[0][input_length:]
57
+ decoded_output = tokenizer.decode(response, skip_special_tokens=True)
58
+
59
+ return decoded_output
60
+
61
+ # UIの構築
62
+ langs = ["Japanese", "English", "Chinese", "Korean", "French", "German", "Spanish"]
63
+
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("# 🚀 HY-MT1.5-1.8B Translator (Spaces)")
66
+ gr.Markdown("Tencentの1.8Bモデルを使用した翻訳デモです。")
67
+
68
+ with gr.Row():
69
+ with gr.Column():
70
+ input_text = gr.Textbox(label="原文 (Source Text)", lines=5, placeholder="ここに入力...")
71
+ target_lang = gr.Dropdown(choices=langs, value="English", label="翻訳先 (Target Language)")
72
+ submit_btn = gr.Button("翻訳 (Translate)", variant="primary")
73
+
74
+ with gr.Column():
75
+ output_text = gr.Textbox(label="結果 (Result)", lines=5, interactive=False)
76
+
77
+ submit_btn.click(
78
+ fn=translate_text,
79
+ inputs=[input_text, target_lang],
80
+ outputs=output_text
81
+ )
82
+
83
+ # 起動
84
+ demo.launch()