import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer import spaces import os os.environ["BITSANDBYTES_NOWELCOME"] = "1" model_name = "minoD/JURAN" # モデルのロード model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", torch_dtype=torch.float16, low_cpu_mem_usage=True, ) tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # ウォームアップフラグ warmup_done = False def generate_prompt(F): result = f"""### 指示: あなたは企業の面接官です。以下の就活生のエントリーシート内容を読んで、深掘りする質問を1つ考えてください。 ### エントリーシート: {F} ### 面接官の質問:""" result = result.replace('\n', '') return result @spaces.GPU(duration=60) def warmup_model(): """モデルのウォームアップ処理""" global warmup_done if not warmup_done: print("ウォームアップ中...") model.to("cuda") # ダミー推論を実行 dummy_input = tokenizer("テスト", return_tensors="pt").input_ids.to("cuda") with torch.no_grad(): _ = model.generate( dummy_input, max_new_tokens=10, do_sample=False ) model.to("cpu") torch.cuda.empty_cache() warmup_done = True print("ウォームアップ完了") @spaces.GPU(duration=60) def generate2(F=None, maxTokens=256): try: # ウォームアップ(初回のみ) if not warmup_done: warmup_model() # 乱数シードを固定(オプション) torch.manual_seed(42) model.to("cuda") prompt = generate_prompt(F) input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda") with torch.no_grad(): outputs = model.generate( input_ids=input_ids, max_new_tokens=maxTokens, do_sample=True, temperature=0.7, top_p=0.75, top_k=40, no_repeat_ngram_size=2, ) model.to("cpu") torch.cuda.empty_cache() outputs = outputs[0].tolist() decoded = tokenizer.decode(outputs) if tokenizer.eos_token_id in outputs: eos_index = outputs.index(tokenizer.eos_token_id) decoded = tokenizer.decode(outputs[:eos_index]) sentinel = "### 面接官の質問:" sentinelLoc = decoded.find(sentinel) if sentinelLoc >= 0: result = decoded[sentinelLoc + len(sentinel):] result = result.split('\n')[0] if '\n' in result else result return result.replace("", "\n").strip() else: return 'Warning: Expected prompt template to be emitted. Ignoring output.' except Exception as e: return f"エラーが発生しました: {str(e)}" def inference(input_text): return generate2(input_text) iface = gr.Interface( fn=inference, inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"), outputs=gr.Textbox(label="想定される質問"), title="JURAN🌺", description="面接官モデルが回答を生成します。", api_name="ask", flagging_mode="never" ) iface.launch( server_name="0.0.0.0", server_port=7860 )