File size: 2,723 Bytes
712cd5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model identifier on the Hugging Face Hub.
model_id = "tencent/HY-MT1.5-1.8B"

# Choose device and precision from the environment:
# float16 on GPU, float32 on CPU (e.g. free Spaces hardware).
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
dtype = torch.float16 if has_cuda else torch.float32

print(f"Loading model on {device} with {dtype}...")

# Load the tokenizer and the causal-LM weights.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,  # explicit device rather than "auto"
    torch_dtype=dtype,
)

def translate_text(source_text, target_lang):
    """Translate *source_text* into *target_lang* using the HY-MT model.

    Args:
        source_text: The text to translate.
        target_lang: Target language name (e.g. "English", "Chinese").

    Returns:
        The decoded translation as a plain string.
    """
    # Prompt-selection logic: the model expects a Chinese instruction
    # when the target is Chinese, an English instruction otherwise.
    if "Chinese" in target_lang or "中文" in target_lang:
        prompt = f"将以下文本翻译为{target_lang},注意只需要输出翻译后的结果,不要额外解释:\n{source_text}"
    else:
        prompt = f"Translate the following segment into {target_lang}, without additional explanation.\n{source_text}"

    messages = [{"role": "user", "content": prompt}]

    # Tokenize via the model's chat template.
    text_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt",
    ).to(device)

    # Run generation. BUG FIX: do_sample=True is required for
    # temperature/top_p to take effect — without it, generate()
    # performs greedy decoding and silently ignores these sampling
    # parameters (transformers only warns).
    with torch.no_grad():
        generated_ids = model.generate(
            text_input,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            top_p=0.6,
            repetition_penalty=1.05,
        )

    # Decode only the newly generated tokens (skip the prompt portion).
    input_length = text_input.shape[1]
    response = generated_ids[0][input_length:]
    decoded_output = tokenizer.decode(response, skip_special_tokens=True)

    return decoded_output

# --- Build the Gradio UI ---
langs = ["Japanese", "English", "Chinese", "Korean", "French", "German", "Spanish"]

with gr.Blocks() as demo:
    gr.Markdown("# 🚀 HY-MT1.5-1.8B Translator (Spaces)")
    gr.Markdown("Tencentの1.8Bモデルを使用した翻訳デモです。")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="原文 (Source Text)",
                lines=5,
                placeholder="ここに入力...",
            )
            target_lang = gr.Dropdown(
                choices=langs,
                value="English",
                label="翻訳先 (Target Language)",
            )
            submit_btn = gr.Button("翻訳 (Translate)", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="結果 (Result)", lines=5, interactive=False)

    # Wire the submit button to the translation function.
    submit_btn.click(fn=translate_text, inputs=[input_text, target_lang], outputs=output_text)

# Start the app.
demo.launch()