minoD commited on
Commit
8d094d7
·
verified ·
1 Parent(s): bae6a7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -46
app.py CHANGED
@@ -1,74 +1,79 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  import os
5
- import shutil
 
 
6
 
7
  model_name = "minoD/JURAN"
8
 
9
- # モデルのロード
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_name,
12
- device_map="auto",
13
- torch_dtype=torch.float16
 
14
  )
15
-
16
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
17
 
18
  # プロンプトテンプレートの準備
19
  def generate_prompt(F):
20
- # input キーの代わりに Q F を使用
21
- result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.
22
-
23
- ### 質問:
24
- {F}
25
-
26
- ### 回答:
27
- """ # 回答セクションを追加
28
- # 改行→<NL>
29
  result = result.replace('\n', '<NL>')
30
  return result
31
 
32
  # テキスト生成関数の定義
 
33
  def generate2(F=None, maxTokens=256):
34
- # 推論
35
- prompt = generate_prompt(F)
36
- input_ids = tokenizer(prompt,
37
- return_tensors="pt",
38
- truncation=True,
39
- add_special_tokens=False).input_ids.to(model.device)
40
- outputs = model.generate(
41
- input_ids=input_ids,
42
- max_new_tokens=maxTokens,
43
- do_sample=True,
44
- temperature=0.7,
45
- top_p=0.75,
46
- top_k=40,
47
- no_repeat_ngram_size=2,
48
- )
49
- outputs = outputs[0].tolist()
50
- decoded = tokenizer.decode(outputs)
51
-
52
- # EOSトークンにヒットしたらデコード完了
53
- if tokenizer.eos_token_id in outputs:
54
- eos_index = outputs.index(tokenizer.eos_token_id)
55
- decoded = tokenizer.decode(outputs[:eos_index])
56
-
 
 
 
 
 
 
 
 
57
  # レスポンス内容のみ抽出
58
  sentinel = "### 回答:"
59
  sentinelLoc = decoded.find(sentinel)
60
  if sentinelLoc >= 0:
61
  result = decoded[sentinelLoc + len(sentinel):]
62
- return result.replace("<NL>", "\n") # <NL>→改行
63
  else:
64
- return 'Warning: Expected prompt template to be emitted. Ignoring output.'
65
- else:
66
- return 'Warning: no <eos> detected ignoring output'
67
 
68
- def inference(input_text):
69
- return generate2(input_text)
70
 
 
 
71
 
 
72
  iface = gr.Interface(
73
  fn=inference,
74
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
@@ -76,8 +81,10 @@ iface = gr.Interface(
76
  title="JURAN🌺",
77
  description="面接官モデルが回答を生成します。",
78
  api_name="ask",
79
- allow_flagging="never"
80
  )
81
 
82
- if __name__ == "__main__":
83
- iface.launch(share=True)
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import spaces
5
  import os
6
+
7
+ # bitsandbytesを無効化
8
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
9
 
10
  model_name = "minoD/JURAN"
11
 
12
+ # モデルのロード(CPUで、bitsandbytesを使わない)
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
+ device_map="cpu",
16
+ torch_dtype=torch.float16,
17
+ low_cpu_mem_usage=True, # メモリ効率を改善
18
  )
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
20
 
21
  # プロンプトテンプレートの準備
22
  def generate_prompt(F):
23
+ result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.### 質問:{F}### 回答:"""
 
 
 
 
 
 
 
 
24
  result = result.replace('\n', '<NL>')
25
  return result
26
 
27
  # テキスト生成関数の定義
28
+ @spaces.GPU(duration=60) # タイムアウトを60秒に設定
29
  def generate2(F=None, maxTokens=256):
30
+ try:
31
+ # モデルをGPUに転送
32
+ model.to("cuda")
33
+
34
+ # 推論
35
+ prompt = generate_prompt(F)
36
+ input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
37
+
38
+ with torch.no_grad(): # 勾配計算を無効化してメモリ節約
39
+ outputs = model.generate(
40
+ input_ids=input_ids,
41
+ max_new_tokens=maxTokens,
42
+ do_sample=True,
43
+ temperature=0.7,
44
+ top_p=0.75,
45
+ top_k=40,
46
+ no_repeat_ngram_size=2,
47
+ )
48
+
49
+ # CPUに戻す
50
+ model.to("cpu")
51
+ torch.cuda.empty_cache() # GPUメモリをクリア
52
+
53
+ outputs = outputs[0].tolist()
54
+ decoded = tokenizer.decode(outputs)
55
+
56
+ # EOSトークンにヒットしたらデコード完了
57
+ if tokenizer.eos_token_id in outputs:
58
+ eos_index = outputs.index(tokenizer.eos_token_id)
59
+ decoded = tokenizer.decode(outputs[:eos_index])
60
+
61
  # レスポンス内容のみ抽出
62
  sentinel = "### 回答:"
63
  sentinelLoc = decoded.find(sentinel)
64
  if sentinelLoc >= 0:
65
  result = decoded[sentinelLoc + len(sentinel):]
66
+ return result.replace("<NL>", "\n")
67
  else:
68
+ return 'Warning: Expected prompt template to be emitted. Ignoring output.'
 
 
69
 
70
+ except Exception as e:
71
+ return f"エラーが発生しました: {str(e)}"
72
 
73
+ def inference(input_text):
74
+ return generate2(input_text)
75
 
76
+ # Gradioインターフェース
77
  iface = gr.Interface(
78
  fn=inference,
79
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
 
81
  title="JURAN🌺",
82
  description="面接官モデルが回答を生成します。",
83
  api_name="ask",
84
+ flagging_mode="never"
85
  )
86
 
87
+ iface.launch(
88
+ server_name="0.0.0.0",
89
+ server_port=7860
90
+ )