Sakalti commited on
Commit
484ac3f
·
verified ·
1 Parent(s): 3f9f2a7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import onnx
4
+ import gradio as gr
5
+ from transformers import AutoModel, AutoTokenizer
6
+ from huggingface_hub import HfApi, HfFolder, Repository
7
+
8
+ # モデルをONNXに変換する関数
9
+ def convert_to_onnx_custom_tensors(model_name, tensor_list, output_path="model.onnx"):
10
+ # モデルとトークナイザをロード
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModel.from_pretrained(model_name)
13
+
14
+ # ユーザー指定のテンソル内容をパースして作成
15
+ try:
16
+ input_tensors = []
17
+ for tensor_str in tensor_list:
18
+ input_tensors.append(torch.tensor(eval(tensor_str)))
19
+
20
+ # ダミーの例として、最初の2つのテンソルをinput_idsとattention_maskとして使用
21
+ input_ids = input_tensors[0] if len(input_tensors) > 0 else torch.zeros((1, 1))
22
+ attention_mask = input_tensors[1] if len(input_tensors) > 1 else torch.ones((1, 1))
23
+
24
+ except Exception as e:
25
+ return f"テンソル作成中にエラーが発生しました: {e}"
26
+
27
+ # モデルをONNXにエクスポート
28
+ try:
29
+ torch.onnx.export(
30
+ model,
31
+ (input_ids, attention_mask),
32
+ output_path,
33
+ input_names=["input_ids", "attention_mask"],
34
+ output_names=["output"],
35
+ dynamic_axes={"input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}},
36
+ opset_version=11
37
+ )
38
+ return f"モデルをONNX形式で{output_path}にエクスポートしました。"
39
+ except Exception as e:
40
+ return f"ONNX形式へのエクスポート中にエラーが発生しました: {e}"
41
+
42
+ # Hugging Face Hubにデプロイする関数
43
+ def deploy_to_huggingface(access_token, repo_name, model_path="model.onnx"):
44
+ # Hugging Face Hubへの認証
45
+ HfFolder.save_token(access_token)
46
+ api = HfApi()
47
+
48
+ # リポジトリを作成 (もし存在しない場合)
49
+ try:
50
+ if not api.model_info(repo_name, use_auth_token=access_token):
51
+ api.create_repo(repo_name, exist_ok=True, token=access_token)
52
+
53
+ # リポジトリにファイルをアップロード
54
+ repo = Repository(local_dir=repo_name, clone_from=repo_name, use_auth_token=True)
55
+ repo.git_pull()
56
+ repo.lfs_track(["*.onnx"])
57
+ os.makedirs(repo_name, exist_ok=True)
58
+ model_output_path = os.path.join(repo_name, "model.onnx")
59
+ os.rename(model_path, model_output_path)
60
+
61
+ repo.git_add()
62
+ repo.git_commit("Add ONNX model")
63
+ repo.git_push()
64
+ return f"{repo_name}にONNXモデルをデプロイしました。"
65
+ except Exception as e:
66
+ return f"デプロイ中にエラーが発生しました: {e}"
67
+
68
+ # AIの基本的な仕組みの解説
69
+ def ai_explanation():
70
+ explanation = """
71
+ AI(人工知能)は、人間の知的作業を模倣する技術です。
72
+ 特にディープラーニングはニューラルネットワークを用いて、大量のデータから特徴を学習し、分類や予測などを行うことができます。
73
+ モデルはトレーニングフェーズでデータを用いて学習し、その後の推論フェーズで新しいデータに対して応答を生成します。
74
+ ONNXは異なるフレームワーク間で互換性を持たせるための形式で、PyTorchやTensorFlowのモデルを統一的に使うことができます。
75
+ """
76
+ return explanation
77
+
78
+ # Gradioインターフェースの作成
79
+ def main():
80
+ with gr.Blocks() as demo:
81
+ gr.Markdown("# ONNXモデル変換・デプロイツールとAIの解説")
82
+
83
+ # AIの基本説明セクション
84
+ with gr.Tab("AIの基本説明"):
85
+ explanation = gr.Markdown(ai_explanation())
86
+
87
+ # ONNX変換セクション
88
+ with gr.Tab("ONNXモデルの変換"):
89
+ model_name = gr.Textbox(label="モデル名", placeholder="例: bert-base-uncased")
90
+ tensor_list = gr.Dataframe(headers=["テンソルの内容"], datatype="str", row_count=3, col_count=1, placeholder="例: [[101, 2057, 2024, 102]]")
91
+ convert_btn = gr.Button("ONNXに変換")
92
+ output = gr.Textbox(label="出力メッセージ")
93
+
94
+ convert_btn.click(
95
+ convert_to_onnx_custom_tensors,
96
+ inputs=[model_name, tensor_list],
97
+ outputs=output
98
+ )
99
+
100
+ # デプロイセクション
101
+ with gr.Tab("Hugging Faceにデプロイ"):
102
+ access_token = gr.Textbox(label="Hugging Faceアクセストークン", type="password")
103
+ repo_name = gr.Textbox(label="リポジトリ名")
104
+ deploy_btn = gr.Button("デプロイ")
105
+ deploy_output = gr.Textbox(label="デプロイ出力")
106
+
107
+ deploy_btn.click(
108
+ deploy_to_huggingface,
109
+ inputs=[access_token, repo_name],
110
+ outputs=deploy_output
111
+ )
112
+
113
+ demo.launch()
114
+
115
+ if __name__ == "__main__":
116
+ main()