minoD commited on
Commit
4449604
·
verified ·
1 Parent(s): 5d20ce5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -31
app.py CHANGED
@@ -1,33 +1,42 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- import spaces
 
5
 
6
  model_name = "minoD/JURAN"
7
 
8
- # モデルのロード(CPUで)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
- device_map="cpu",
12
  torch_dtype=torch.float16
13
  )
 
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
15
 
16
  # プロンプトテンプレートの準備
17
  def generate_prompt(F):
18
- result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.### 質問:{F}### 回答:"""
 
 
 
 
 
 
 
 
19
  result = result.replace('\n', '<NL>')
20
  return result
21
 
22
  # テキスト生成関数の定義
23
- @spaces.GPU
24
  def generate2(F=None, maxTokens=256):
25
- # モデルをGPUに転送
26
- model.to("cuda")
27
-
28
  # 推論
29
  prompt = generate_prompt(F)
30
- input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
 
 
 
31
  outputs = model.generate(
32
  input_ids=input_ids,
33
  max_new_tokens=maxTokens,
@@ -37,31 +46,29 @@ def generate2(F=None, maxTokens=256):
37
  top_k=40,
38
  no_repeat_ngram_size=2,
39
  )
40
-
41
- # CPUに戻す
42
- model.to("cpu")
43
-
44
  outputs = outputs[0].tolist()
45
  decoded = tokenizer.decode(outputs)
46
-
47
  # EOSトークンにヒットしたらデコード完了
48
  if tokenizer.eos_token_id in outputs:
49
  eos_index = outputs.index(tokenizer.eos_token_id)
50
  decoded = tokenizer.decode(outputs[:eos_index])
51
-
52
- # レスポンス内容のみ抽出
53
- sentinel = "### 回答:"
54
- sentinelLoc = decoded.find(sentinel)
55
- if sentinelLoc >= 0:
56
- result = decoded[sentinelLoc + len(sentinel):]
57
- return result.replace("<NL>", "\n")
58
- else:
59
- return 'Warning: Expected prompt template to be emitted. Ignoring output.'
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  def inference(input_text):
62
- return generate2(input_text)
 
63
 
64
- # Gradioインターフェース
65
  iface = gr.Interface(
66
  fn=inference,
67
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
@@ -69,11 +76,8 @@ iface = gr.Interface(
69
  title="JURAN🌺",
70
  description="面接官モデルが回答を生成します。",
71
  api_name="ask",
72
- flagging_mode="never"
73
  )
74
 
75
- # if __name__ を削除して直接launch
76
- iface.launch(
77
- server_name="0.0.0.0",
78
- server_port=7860
79
- )
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import os
5
+ import shutil
6
 
7
  model_name = "minoD/JURAN"
8
 
9
+ # モデルのロード
10
  model = AutoModelForCausalLM.from_pretrained(
11
  model_name,
12
+ device_map="auto",
13
  torch_dtype=torch.float16
14
  )
15
+
16
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
17
 
18
  # プロンプトテンプレートの準備
19
  def generate_prompt(F):
20
+ # input キーの代わりに Q F を使用
21
+ result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.
22
+
23
+ ### 質問:
24
+ {F}
25
+
26
+ ### 回答:
27
+ """ # 回答セクションを追加
28
+ # 改行→<NL>
29
  result = result.replace('\n', '<NL>')
30
  return result
31
 
32
  # テキスト生成関数の定義
 
33
  def generate2(F=None, maxTokens=256):
 
 
 
34
  # 推論
35
  prompt = generate_prompt(F)
36
+ input_ids = tokenizer(prompt,
37
+ return_tensors="pt",
38
+ truncation=True,
39
+ add_special_tokens=False).input_ids.to(model.device)
40
  outputs = model.generate(
41
  input_ids=input_ids,
42
  max_new_tokens=maxTokens,
 
46
  top_k=40,
47
  no_repeat_ngram_size=2,
48
  )
 
 
 
 
49
  outputs = outputs[0].tolist()
50
  decoded = tokenizer.decode(outputs)
51
+
52
  # EOSトークンにヒットしたらデコード完了
53
  if tokenizer.eos_token_id in outputs:
54
  eos_index = outputs.index(tokenizer.eos_token_id)
55
  decoded = tokenizer.decode(outputs[:eos_index])
 
 
 
 
 
 
 
 
 
56
 
57
+ # レスポンス内容のみ抽出
58
+ sentinel = "### 回答:"
59
+ sentinelLoc = decoded.find(sentinel)
60
+ if sentinelLoc >= 0:
61
+ result = decoded[sentinelLoc + len(sentinel):]
62
+ return result.replace("<NL>", "\n") # <NL>→改行
63
+ else:
64
+ return 'Warning: Expected prompt template to be emitted. Ignoring output.'
65
+ else:
66
+ return 'Warning: no <eos> detected ignoring output'
67
+
68
  def inference(input_text):
69
+ return generate2(input_text)
70
+
71
 
 
72
  iface = gr.Interface(
73
  fn=inference,
74
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
 
76
  title="JURAN🌺",
77
  description="面接官モデルが回答を生成します。",
78
  api_name="ask",
79
+ allow_flagging="never"
80
  )
81
 
82
+ if __name__ == "__main__":
83
+ iface.launch(share=True)