# MBTI/core/interviewer.py — revision "Change tokenizer" (commit dc0b1bc, 2.77 kB)
# core/interviewer.py
"""
🇬🇧 Interviewer logic module (no instructions)
Generates random MBTI-style questions using a fine-tuned model.
🇷🇺 Модуль интервьюера.
Использует fine-tuned модель для генерации вопросов без промптов и инструкций.
"""
import random, torch, re
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer
# --------------------------------------------------------------
# 1️⃣ Model setup
# --------------------------------------------------------------
QG_MODEL = "f3nsmart/ft-flan-t5-base-qgen"  # fine-tuned FLAN-T5 checkpoint for question generation
# ❗ Use the "slow" SentencePiece tokenizer (use_fast=False) — presumably the
# fast tokenizer mis-handles this checkpoint's vocab; confirm before changing.
tokenizer = T5Tokenizer.from_pretrained(QG_MODEL, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(QG_MODEL)
# Prefer GPU when available; inference-only, so switch the model to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print(f"✅ Loaded interviewer model (slow tokenizer): {QG_MODEL}")
# --------------------------------------------------------------
# 2️⃣ Base seed prompts (no instructions)
# --------------------------------------------------------------
# Short topic seeds fed to the fine-tuned model; generate_question() picks
# one at random per call. The model was trained to emit a question from a
# bare topic, so no instruction text is included here.
PROMPTS = [
    "Personality and emotions.",
    "Human motivation and choices.",
    "Self-awareness and reflection.",
    "Personal growth and behavior.",
    "How people make decisions.",
]
# --------------------------------------------------------------
# 3️⃣ Очистка текста
# --------------------------------------------------------------
def _clean_question(text: str) -> str:
"""Берёт первую фразу с '?', обрезает лишнее"""
text = text.strip()
m = re.search(r"(.+?\?)", text)
if m:
text = m.group(1)
text = text.replace("\n", " ").strip()
if len(text.split()) < 3:
text = text.capitalize()
if not text.endswith("?"):
text += "?"
return text
# --------------------------------------------------------------
# 4️⃣ Question generation
# --------------------------------------------------------------
def generate_question(user_id: str = "default_user", **kwargs) -> str:
    """Generate one MBTI-style question with no instruction prompt.

    Picks a random topic seed from PROMPTS, samples a continuation from
    the fine-tuned model, and returns the cleaned question text.

    Args:
        user_id: Caller identifier (unused here; kept for API compatibility).
        **kwargs: Ignored; accepted for forward compatibility.

    Returns:
        A single question string ending in '?'.
    """
    seed = random.choice(PROMPTS)
    encoded = tokenizer(seed, return_tensors="pt", truncation=True).to(device)

    # Sampling decode: nucleus sampling with mild repetition penalty.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            do_sample=True,
            top_p=0.9,
            temperature=0.9,
            repetition_penalty=1.1,
            max_new_tokens=60,
        )

    raw = tokenizer.decode(generated[0], skip_special_tokens=True)
    return _clean_question(raw)