minoD commited on
Commit
377a726
·
verified ·
1 Parent(s): 4698471

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -35
app.py CHANGED
@@ -1,45 +1,33 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- import os
5
- import shutil
6
  import spaces
7
 
8
  model_name = "minoD/JURAN"
9
 
10
- # モデルのロード
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
13
  device_map="cpu",
14
  torch_dtype=torch.float16
15
-
16
  )
17
-
18
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19
 
20
  # プロンプトテンプレートの準備
21
  def generate_prompt(F):
22
- # input キーの代わりに Q F を使用
23
- result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.
24
-
25
- ### 質問:
26
- {F}
27
-
28
- ### 回答:
29
- """ # 回答セクションを追加
30
- # 改行→<NL>
31
  result = result.replace('\n', '<NL>')
32
  return result
33
 
34
  # テキスト生成関数の定義
 
35
  def generate2(F=None, maxTokens=256):
 
36
  model.to("cuda")
 
37
  # 推論
38
  prompt = generate_prompt(F)
39
- input_ids = tokenizer(prompt,
40
- return_tensors="pt",
41
- truncation=True,
42
- add_special_tokens=False).input_ids.to(model.device)
43
  outputs = model.generate(
44
  input_ids=input_ids,
45
  max_new_tokens=maxTokens,
@@ -49,29 +37,31 @@ def generate2(F=None, maxTokens=256):
49
  top_k=40,
50
  no_repeat_ngram_size=2,
51
  )
 
 
 
 
52
  outputs = outputs[0].tolist()
53
  decoded = tokenizer.decode(outputs)
54
-
55
  # EOSトークンにヒットしたらデコード完了
56
  if tokenizer.eos_token_id in outputs:
57
  eos_index = outputs.index(tokenizer.eos_token_id)
58
  decoded = tokenizer.decode(outputs[:eos_index])
59
-
60
- # レスポンス内容のみ抽出
61
- sentinel = "### 回答:"
62
- sentinelLoc = decoded.find(sentinel)
63
- if sentinelLoc >= 0:
64
- result = decoded[sentinelLoc + len(sentinel):]
65
- return result.replace("<NL>", "\n") # <NL>→改行
66
- else:
67
- return 'Warning: Expected prompt template to be emitted. Ignoring output.'
68
- else:
69
- return 'Warning: no <eos> detected ignoring output'
70
 
71
- def inference(input_text):
72
- return generate2(input_text)
 
 
 
 
 
 
73
 
 
 
74
 
 
75
  iface = gr.Interface(
76
  fn=inference,
77
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
@@ -79,8 +69,11 @@ iface = gr.Interface(
79
  title="JURAN🌺",
80
  description="面接官モデルが回答を生成します。",
81
  api_name="ask",
82
- allow_flagging="never"
83
  )
84
 
85
- if __name__ == "__main__":
86
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
4
  import spaces
5
 
6
  model_name = "minoD/JURAN"
7
 
8
+ # モデルのロード(CPUで)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  device_map="cpu",
12
  torch_dtype=torch.float16
 
13
  )
 
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
15
 
16
  # プロンプトテンプレートの準備
17
  def generate_prompt(F):
18
+ result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.### 質問:{F}### 回答:"""
 
 
 
 
 
 
 
 
19
  result = result.replace('\n', '<NL>')
20
  return result
21
 
22
  # テキスト生成関数の定義
23
+ @spaces.GPU
24
  def generate2(F=None, maxTokens=256):
25
+ # モデルをGPUに転送
26
  model.to("cuda")
27
+
28
  # 推論
29
  prompt = generate_prompt(F)
30
+ input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
 
 
 
31
  outputs = model.generate(
32
  input_ids=input_ids,
33
  max_new_tokens=maxTokens,
 
37
  top_k=40,
38
  no_repeat_ngram_size=2,
39
  )
40
+
41
+ # CPUに戻す
42
+ model.to("cpu")
43
+
44
  outputs = outputs[0].tolist()
45
  decoded = tokenizer.decode(outputs)
46
+
47
  # EOSトークンにヒットしたらデコード完了
48
  if tokenizer.eos_token_id in outputs:
49
  eos_index = outputs.index(tokenizer.eos_token_id)
50
  decoded = tokenizer.decode(outputs[:eos_index])
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # レスポンス内容のみ抽出
53
+ sentinel = "### 回答:"
54
+ sentinelLoc = decoded.find(sentinel)
55
+ if sentinelLoc >= 0:
56
+ result = decoded[sentinelLoc + len(sentinel):]
57
+ return result.replace("<NL>", "\n")
58
+ else:
59
+ return 'Warning: Expected prompt template to be emitted. Ignoring output.'
60
 
61
+ def inference(input_text):
62
+ return generate2(input_text)
63
 
64
+ # Gradioインターフェース
65
  iface = gr.Interface(
66
  fn=inference,
67
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
 
69
  title="JURAN🌺",
70
  description="面接官モデルが回答を生成します。",
71
  api_name="ask",
72
+ flagging_mode="never"
73
  )
74
 
75
+ # if __name__ を削除して直接launch
76
+ iface.launch(
77
+ server_name="0.0.0.0",
78
+ server_port=7860
79
+ )