| import gradio as gr |
| import torch |
| import gc, os |
| os.environ["RWKV_V7_ON"] = '1' |
| os.environ["RWKV_JIT_ON"] = '1' |
| os.environ["RWKV_CUDA_ON"] = '1' |
| from rwkv_rocm.model import RWKV |
| from rwkv_rocm.utils import PIPELINE, PIPELINE_ARGS |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| ctx_limit = 4096 |
| gen_limit = 4096 |
|
|
| |
|
|
| model_path_v6 = "./RWKV_v7_G1_0.4B_Translate_JpZh_ctx1024_20251206.pth" |
| model_v6 = RWKV(model=model_path_v6.replace('.pth',''), strategy='cuda fp16') |
| pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424") |
|
|
| args = model_v6.args |
|
|
| penalty_decay = 0.996 |
|
|
| def evaluate( |
| ctx, |
| token_count=200, |
| temperature=1.0, |
| top_p=0.7, |
| presencePenalty = 0.1, |
| countPenalty = 0.1, |
| ): |
| args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), |
| alpha_frequency = countPenalty, |
| alpha_presence = presencePenalty, |
| token_ban = [], |
| token_stop = [0]) |
| ctx = ctx.strip() |
| all_tokens = [] |
| out_last = 0 |
| out_str = '' |
| occurrence = {} |
| state = None |
| for i in range(int(token_count)): |
|
|
| input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token] |
| out, state = model_v6.forward(input_ids, state) |
| for n in occurrence: |
| out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) |
|
|
| token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p) |
| if token in args.token_stop: |
| break |
| all_tokens += [token] |
| for xxx in occurrence: |
| occurrence[xxx] *= penalty_decay |
| |
| ttt = pipeline_v6.decode([token]) |
| www = 1 |
| if ttt in ' \t0123456789': |
| www = 0 |
| |
| |
| if token not in occurrence: |
| occurrence[token] = www |
| else: |
| occurrence[token] += www |
| |
| tmp = pipeline_v6.decode(all_tokens[out_last:]) |
| if '\ufffd' not in tmp: |
| out_str += tmp |
| yield out_str.strip() |
| out_last = i + 1 |
| del out |
| del state |
| gc.collect() |
| torch.cuda.empty_cache() |
| yield out_str.strip() |
|
|
| def translate_Japanese_to_chinese(Japanese_text, token_count, temperature, top_p, presence_penalty, count_penalty): |
| if not Japanese_text.strip(): |
| return "Chinese:\n请输入日文内容。" |
| |
| full_prompt = f"Japanese: {Japanese_text}\n\nChinese:" |
| for output in evaluate(full_prompt, token_count, temperature, top_p, presence_penalty, count_penalty): |
| yield output |
|
|
| def translate_chinese_to_chinses(Chinese_text, token_count, temperature, top_p, presence_penalty, count_penalty): |
| if not Chinese_text.strip(): |
| return "Chinses:\n请输入中文内容。" |
| |
| full_prompt = f"Chinese: {Chinese_text}\n\nJapanese:" |
| for output in evaluate(full_prompt, token_count, temperature, top_p, presence_penalty, count_penalty): |
| yield output |
|
|
|
|
| with gr.Blocks(title="RWKV_v7_G1_1.5B_Translate_ctx4096 Japanese -> Chinese") as demo: |
| with gr.Tab("Japanese To Chinses"): |
| gr.HTML(f"<div style='text-align:center;'><h1>RWKV_v7_G1_1.5B_Translate_ctx4096_2025062 Japanese -> Chinese</h1></div>") |
| with gr.Row(): |
| with gr.Column(): |
| Japanese_input = gr.Textbox( |
| label="日文输入(注意不能有空行)", |
| lines=20, |
| placeholder="请输入日文内容...", |
| value="ROCmはオープンソーススタックであり、主にグラフィックス・プロセッシング・ユニット(GPU)コンピューティング向けに設計されたオープンソースソフトウェアで構成されています。ROCmは、低レベルカーネルからエンドユーザーアプリケーションまで、GPUプログラミングを可能にするドライバー、開発ツール、APIスイートで構成されています。" |
| "ROCmを使用すると、GPUソフトウェアを特定のニーズに合わせてカスタマイズできます。無料、オープンソース、統合型、かつ安全なソフトウェアエコシステム内で、アプリケーションの開発、共同作業、テスト、導入が可能です。ROCmは、GPUアクセラレーションを活用したハイパフォーマンスコンピューティング(HPC)、人工知能(AI)、科学技術計算、コンピュータ支援設計(CAD)に特に適しています。" |
| "ROCmは、AMDのPortable Graphics Interface(HIP)を搭載しています。これは、オープンソースのC++ GPUプログラミング環境と、それに対応するランタイムです。HIPにより、ROCm開発者は、専用ゲーミングGPUからエクサスケールHPCクラスターまで、幅広いプラットフォームにコードをデプロイすることで、異なるプラットフォーム間で移植可能なアプリケーションを作成できます。" |
| "ROCmはOpenMPやOpenCLなどのプログラミングモデルをサポートし、必要なオープンソースソフトウェアコンパイラ、デバッガー、ライブラリをすべて備えています。ROCmは、PyTorchやTensorFlowなどの機械学習(ML)フレームワークに完全に統合されています。" |
| ) |
|
|
| with gr.Column(): |
| chinese_output = gr.Textbox( |
| label="中文输出", |
| lines=20, |
| placeholder="翻译结果将显示在此处", |
| value="" |
| ) |
|
|
| with gr.Row(): |
| translate_btn = gr.Button("Translate", variant="primary") |
| clear_btn = gr.Button("Clear", variant="secondary") |
| stop_btn = gr.Button("Stop", variant="stop") |
|
|
| with gr.Accordion("Advanced Settings", open=False): |
| token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit) |
| temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0) |
| top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0) |
| presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0) |
| count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0) |
|
|
| translate_event = translate_btn.click( |
| fn=translate_Japanese_to_chinese, |
| inputs=[Japanese_input, token_count, temperature, top_p, presence_penalty, count_penalty], |
| outputs=[chinese_output] |
| ) |
|
|
| clear_btn.click( |
| fn=lambda: ("", ""), |
| inputs=[], |
| outputs=[Japanese_input, chinese_output] |
| ) |
|
|
| stop_btn.click( |
| fn=None, |
| inputs=None, |
| outputs=None, |
| cancels=[translate_event] |
| ) |
| with gr.Tab("Chinses To Japanese"): |
| gr.HTML(f"<div style='text-align:center;'><h1>RWKV_v7_G1_1.5B_Translate_ctx4096 Chinses -> Japanese</h1></div>") |
| with gr.Row(): |
| with gr.Column(): |
| chinese_input = gr.Textbox( |
| label="中文输入(注意不能有空行)", |
| lines=20, |
| placeholder="请输入中文内容...", |
| value="ROCm是一个开源栈,主要由开源软件组成,旨在用于图形处理单元(GPU)计算。ROCm由一系列驱动程序、开发工具和API组成,这些工具和API允许从低级内核到最终用户应用程序对GPU进行编程。" |
| "使用ROCm,您可以根据您的特定需求定制GPU软件。您可以在一个免费、开源、集成和安全的软件生态系统中开发、协作、测试和部署应用程序。ROCm特别适合GPU加速的高性能计算(HPC)、人工智能(AI)、科学计算和计算机辅助设计(CAD)。" |
| "ROCm由AMD的可移植性图形处理接口(HIP)驱动,这是一个开源的C++ GPU编程环境及其相应的运行时。HIP允许ROCm开发者在不同平台上创建可移植应用程序,通过在从专用游戏GPU到exascale HPC集群的各种平台上部署代码来实现这一目标。" |
| "ROCm支持编程模型,如OpenMP和OpenCL,并包含所有必要的开源软件编译器、调试器和库。ROCm完全集成到机器学习(ML)框架中,如PyTorch和TensorFlow。" |
| ) |
|
|
| with gr.Column(): |
| Japanese_output = gr.Textbox( |
| label="日文输出", |
| lines=20, |
| placeholder="翻译结果将显示在此处", |
| value="" |
| ) |
|
|
| with gr.Row(): |
| translate_btn = gr.Button("Translate", variant="primary") |
| clear_btn = gr.Button("Clear", variant="secondary") |
| stop_btn = gr.Button("Stop", variant="stop") |
|
|
| with gr.Accordion("Advanced Settings", open=False): |
| token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit) |
| temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0) |
| top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0) |
| presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0) |
| count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0) |
|
|
| translate_event = translate_btn.click( |
| fn=translate_chinese_to_chinses, |
| inputs=[chinese_input, token_count, temperature, top_p, presence_penalty, count_penalty], |
| outputs=[Japanese_output] |
| ) |
|
|
| clear_btn.click( |
| fn=lambda: ("", ""), |
| inputs=[], |
| outputs=[chinese_input, Japanese_output] |
| ) |
|
|
| stop_btn.click( |
| fn=None, |
| inputs=None, |
| outputs=None, |
| cancels=[translate_event] |
| ) |
|
|
| demo.queue(max_size=10, default_concurrency_limit=1) |
| demo.launch(share=False) |
|
|