File size: 1,597 Bytes
56954f5 1c31761 4c6f761 9458365 a31dc30 dc0b1bc 87b7e98 34fcc83 893dddd 4c6f761 893dddd 4c6f761 34fcc83 4c6f761 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# core/interviewer.py
import torch
import threading
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, TextIteratorStreamer
# Fine-tuned FLAN-T5 question-generation checkpoint (Hugging Face hub id).
QG_MODEL = "f3nsmart/ft-flan-t5-base-qgen_v2"
# use_fast=False forces the SentencePiece-based slow tokenizer for this T5 checkpoint.
tokenizer = T5Tokenizer.from_pretrained(QG_MODEL, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(QG_MODEL)
# Prefer GPU when available; all tokenized inputs below are moved to the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# eval() disables dropout etc. — this module only does inference.
model.to(device).eval()
print(f"✅ Loaded interviewer model (streaming ready): {QG_MODEL}")
# plain (non-streaming) version, kept as a fallback
def generate_question(
    prompt: str = "Generate one thoughtful question.",
    max_new_tokens: int = 80,
) -> str:
    """Generate a single question for *prompt* and return it as one string.

    Args:
        prompt: Instruction text fed to the seq2seq model.
        max_new_tokens: Generation length cap (default 80, matching the
            streaming variant below).

    Returns:
        The decoded model output with special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    # Inference only — no_grad avoids building the autograd graph.
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0], skip_special_tokens=True)
# streaming version
def stream_question(prompt: str = "Generate one thoughtful question."):
    """Stream a generated question, yielding the cumulative text so far.

    Runs ``model.generate`` in a background thread while this generator
    consumes decoded fragments from a ``TextIteratorStreamer``. Each yield
    is the full prefix generated so far (not just the new delta), which
    lets callers redraw the whole question on every update.

    Args:
        prompt: Instruction text fed to the seq2seq model.

    Yields:
        str: The accumulated question text after each new fragment.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=80,
        do_sample=True,
        top_p=0.9,
        temperature=1.1,
        top_k=60,
        repetition_penalty=1.3,
    )
    # generate() blocks until done, so it runs in a worker thread while
    # this generator drains the streamer concurrently
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    try:
        partial = ""
        for new_text in streamer:
            partial += new_text
            yield partial
    finally:
        # Always reap the worker — otherwise the thread leaks when the
        # consumer abandons the generator early, and generate() errors
        # would never surface as a finished thread.
        thread.join()
|