File size: 2,454 Bytes
ba34bf3
 
 
d5b51f1
ba34bf3
 
 
377a726
ba34bf3
 
d5b51f1
ce080b8
ba34bf3
 
 
 
 
377a726
ba34bf3
 
 
 
377a726
ba34bf3
377a726
d5b51f1
377a726
ba34bf3
 
377a726
ba34bf3
 
 
 
 
 
 
 
 
377a726
 
 
 
ba34bf3
 
377a726
ba34bf3
 
 
 
 
377a726
 
 
 
 
 
 
 
ba34bf3
377a726
 
ba34bf3
377a726
ba34bf3
 
 
 
 
3e839ef
6f624b8
377a726
ba34bf3
 
377a726
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

model_name = "minoD/JURAN"

# モデルのロード(CPUで)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# プロンプトテンプレートの準備
def generate_prompt(F):
    result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.### 質問:{F}### 回答:"""
    result = result.replace('\n', '<NL>')
    return result

# テキスト生成関数の定義
@spaces.GPU
def generate2(F=None, maxTokens=256):
    # モデルをGPUに転送
    model.to("cuda")
    
    # 推論
    prompt = generate_prompt(F)
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    
    # CPUに戻す
    model.to("cpu")
    
    outputs = outputs[0].tolist()
    decoded = tokenizer.decode(outputs)
    
    # EOSトークンにヒットしたらデコード完了
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[:eos_index])
    
    # レスポンス内容のみ抽出
    sentinel = "### 回答:"
    sentinelLoc = decoded.find(sentinel)
    if sentinelLoc >= 0:
        result = decoded[sentinelLoc + len(sentinel):]
        return result.replace("<NL>", "\n")
    else:
        return 'Warning: Expected prompt template to be emitted. Ignoring output.'

def inference(input_text):
    return generate2(input_text)

# Gradioインターフェース
iface = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
    outputs=gr.Textbox(label="想定される質問"),
    title="JURAN🌺",
    description="面接官モデルが回答を生成します。",
    api_name="ask",
    flagging_mode="never"
)

# if __name__ を削除して直接launch
iface.launch(
    server_name="0.0.0.0",
    server_port=7860
)