File size: 2,454 Bytes
ba34bf3 d5b51f1 ba34bf3 377a726 ba34bf3 d5b51f1 ce080b8 ba34bf3 377a726 ba34bf3 377a726 ba34bf3 377a726 d5b51f1 377a726 ba34bf3 377a726 ba34bf3 377a726 ba34bf3 377a726 ba34bf3 377a726 ba34bf3 377a726 ba34bf3 377a726 ba34bf3 3e839ef 6f624b8 377a726 ba34bf3 377a726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
model_name = "minoD/JURAN"
# モデルのロード(CPUで)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu",
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# プロンプトテンプレートの準備
def generate_prompt(F):
result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.### 質問:{F}### 回答:"""
result = result.replace('\n', '<NL>')
return result
# テキスト生成関数の定義
@spaces.GPU
def generate2(F=None, maxTokens=256):
# モデルをGPUに転送
model.to("cuda")
# 推論
prompt = generate_prompt(F)
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=maxTokens,
do_sample=True,
temperature=0.7,
top_p=0.75,
top_k=40,
no_repeat_ngram_size=2,
)
# CPUに戻す
model.to("cpu")
outputs = outputs[0].tolist()
decoded = tokenizer.decode(outputs)
# EOSトークンにヒットしたらデコード完了
if tokenizer.eos_token_id in outputs:
eos_index = outputs.index(tokenizer.eos_token_id)
decoded = tokenizer.decode(outputs[:eos_index])
# レスポンス内容のみ抽出
sentinel = "### 回答:"
sentinelLoc = decoded.find(sentinel)
if sentinelLoc >= 0:
result = decoded[sentinelLoc + len(sentinel):]
return result.replace("<NL>", "\n")
else:
return 'Warning: Expected prompt template to be emitted. Ignoring output.'
def inference(input_text):
return generate2(input_text)
# Gradioインターフェース
iface = gr.Interface(
fn=inference,
inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
outputs=gr.Textbox(label="想定される質問"),
title="JURAN🌺",
description="面接官モデルが回答を生成します。",
api_name="ask",
flagging_mode="never"
)
# if __name__ を削除して直接launch
iface.launch(
server_name="0.0.0.0",
server_port=7860
) |