# core/interviewer.py
import torch
import threading
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, TextIteratorStreamer

QG_MODEL = "f3nsmart/ft-flan-t5-base-qgen_v2"

tokenizer = T5Tokenizer.from_pretrained(QG_MODEL)  # T5Tokenizer is the slow (SentencePiece) tokenizer; a use_fast flag does not apply here
model = AutoModelForSeq2SeqLM.from_pretrained(QG_MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

print(f"✅ Loaded interviewer model (streaming ready): {QG_MODEL}")

# plain (non-streaming) version, kept as a fallback; uses greedy decoding
def generate_question(prompt: str = "Generate one thoughtful question.") -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=80)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# streaming version: sampled decoding, yields partial text as tokens arrive
def stream_question(prompt: str = "Generate one thoughtful question."):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=80,
        do_sample=True,
        top_p=0.9,
        temperature=1.1,
        top_k=60,
        repetition_penalty=1.3,
    )

    # run generation in a background thread so this generator can consume the streamer
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # yield the cumulative text so far, not just the new chunk
    thread.join()  # ensure the generation thread has finished before returning
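
# --- Usage sketch (not part of the original file; illustrative only) ---
# Shows how a caller might consume both entry points. The prompt string below
# is a hypothetical example; any text prompt works.
if __name__ == "__main__":
    # One-shot generation: blocks until the full question is ready.
    print(generate_question("Generate one thoughtful question about databases."))

    # Streaming generation: each yielded value is the full text so far,
    # so a UI can simply re-render it on every update.
    last = ""
    for partial in stream_question("Generate one thoughtful question about databases."):
        last = partial
    print(last)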