QAway-to committed on
Commit
a8bca78
·
1 Parent(s): 8d4e786

f3nsmart/TinyLlama-MBTI-Interviewer-LoRA. v2.0

Browse files
Files changed (6) hide show
  1. app.py +15 -76
  2. core/__init__.py +0 -0
  3. core/interviewer.py +49 -0
  4. core/mbti_analyzer.py +11 -0
  5. core/memory.py +16 -0
  6. core/utils.py +20 -0
app.py CHANGED
@@ -1,96 +1,35 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import (
4
- AutoTokenizer,
5
- AutoModelForCausalLM,
6
- AutoModelForSequenceClassification,
7
- pipeline
8
- )
9
- from peft import PeftModel # 👈 важно для LoRA адаптации
10
 
11
  # ===============================================================
12
- # 1️⃣ Настройки и модели
13
- # ===============================================================
14
- MBTI_MODEL = "f3nsmart/MBTIclassifier"
15
- INTERVIEWER_BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
16
- INTERVIEWER_LORA = "f3nsmart/TinyLlama-MBTI-Interviewer-LoRA"
17
-
18
- # --- MBTI классификатор ---
19
- mbti_pipe = pipeline("text-classification", model=MBTI_MODEL, return_all_scores=True)
20
-
21
- # --- Интервьюер TinyLlama + LoRA ---
22
- print("🔄 Загрузка TinyLlama с адаптером LoRA...")
23
- tokenizer_llama = AutoTokenizer.from_pretrained(INTERVIEWER_LORA)
24
- base_model = AutoModelForCausalLM.from_pretrained(
25
- INTERVIEWER_BASE,
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
- device_map="auto"
28
- )
29
- model_llora = PeftModel.from_pretrained(base_model, INTERVIEWER_LORA)
30
-
31
- llm_pipe = pipeline(
32
- "text-generation",
33
- model=model_llora,
34
- tokenizer=tokenizer_llama,
35
- max_new_tokens=70,
36
- temperature=0.7,
37
- top_p=0.9,
38
- device_map="auto"
39
- )
40
-
41
- # ===============================================================
42
- # 2️⃣ Вспомогательные функции
43
  # ===============================================================
44
- def clean_question(text: str) -> str:
45
- text = text.strip().split("\n")[0].strip('"').strip("'")
46
- bad_tokens = ["user:", "assistant:", "instruction", "interviewer", "system:"]
47
- for bad in bad_tokens:
48
- if bad.lower() in text.lower():
49
- text = text.split(bad)[-1].strip()
50
- if "?" not in text:
51
- text = text.rstrip(".") + "?"
52
- if len(text.split()) < 3:
53
- return "What do you usually enjoy doing in your free time?"
54
- return text.strip()
55
-
56
- def generate_first_question():
57
- return "What do you usually enjoy doing in your free time?"
58
-
59
- def analyze_and_ask(user_text, prev_count):
60
  if not user_text.strip():
61
  return "⚠️ Введите ответ.", "", prev_count
 
62
  try:
63
  n = int(prev_count.split("/")[0]) + 1
64
  except Exception:
65
  n = 1
66
  counter = f"{n}/30"
67
 
68
- res = mbti_pipe(user_text)[0]
69
- res_sorted = sorted(res, key=lambda x: x["score"], reverse=True)
70
- mbti_text = "\n".join([f"{r['label']} → {r['score']:.3f}" for r in res_sorted[:3]])
71
 
72
- prompt = (
73
- f"User said: '{user_text}'. "
74
- "Generate one natural, open-ended question that starts with 'What', 'Why', 'How', or 'When'. "
75
- "Avoid rephrasing or quoting the user's text. "
76
- "Do NOT explain what you are doing or include any instructions. "
77
- "Output only the question itself."
78
- )
79
 
80
- raw = llm_pipe(prompt)[0]["generated_text"]
81
- cleaned = clean_question(raw)
82
- if not cleaned.startswith(("What", "Why", "How", "When")):
83
- cleaned = "What motivates you to do the things you enjoy most?"
84
- return mbti_text, cleaned, counter
85
 
86
- # ===============================================================
87
- # 3️⃣ Интерфейс Gradio
88
- # ===============================================================
89
  with gr.Blocks(theme=gr.themes.Soft(), title="MBTI Personality Interviewer") as demo:
90
  gr.Markdown(
91
  "## 🧠 MBTI Personality Interviewer\n"
92
  "Определи личностный тип и получи следующий вопрос от интервьюера."
93
  )
 
94
  with gr.Row():
95
  with gr.Column(scale=1):
96
  inp = gr.Textbox(
@@ -104,7 +43,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MBTI Personality Interviewer") as
104
  interviewer_out = gr.Textbox(label="💬 Следующий вопрос от интервьюера", lines=3)
105
  progress = gr.Textbox(label="⏳ Прогресс", value="0/30")
106
 
107
- btn.click(analyze_and_ask, inputs=[inp, progress], outputs=[mbti_out, interviewer_out, progress])
108
- demo.load(lambda: ("", generate_first_question(), "0/30"), inputs=None, outputs=[mbti_out, interviewer_out, progress])
109
 
110
- demo.launch()
 
1
  import gradio as gr
2
+ import asyncio
3
+ from core.utils import generate_first_question
4
+ from core.mbti_analyzer import analyze_mbti
5
+ from core.interviewer import generate_question
 
 
 
 
6
 
7
  # ===============================================================
8
+ # 3️⃣ Интерфейс Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # ===============================================================
10
async def analyze_and_ask_async(user_text, prev_count, user_id="default_user"):
    """Analyze the user's answer and produce the next interview question.

    Runs the MBTI classification and the LLM question generation
    concurrently and advances the N/30 progress counter.

    Returns:
        (mbti_text, next_question, counter) — strings for the three
        Gradio output components.

    NOTE(review): user_id defaults to one shared key, so every web user
    shares the same conversation memory — confirm whether per-session
    ids should be wired in from Gradio state.
    """
    if not user_text.strip():
        return "⚠️ Введите ответ.", "", prev_count

    # Parse "N/30"; any malformed value restarts the counter at 1.
    try:
        n = int(prev_count.split("/")[0]) + 1
    except Exception:
        n = 1
    counter = f"{n}/30"

    # gather() schedules both coroutines concurrently; explicit
    # create_task() wrappers are unnecessary.
    mbti_text, next_question = await asyncio.gather(
        analyze_mbti(user_text),
        generate_question(user_id, user_text),
    )
    return mbti_text, next_question, counter
 
 
 
 
 
25
 
 
 
 
 
 
26
 
 
 
 
27
  with gr.Blocks(theme=gr.themes.Soft(), title="MBTI Personality Interviewer") as demo:
28
  gr.Markdown(
29
  "## 🧠 MBTI Personality Interviewer\n"
30
  "Определи личностный тип и получи следующий вопрос от интервьюера."
31
  )
32
+
33
  with gr.Row():
34
  with gr.Column(scale=1):
35
  inp = gr.Textbox(
 
43
  interviewer_out = gr.Textbox(label="💬 Следующий вопрос от интервьюера", lines=3)
44
  progress = gr.Textbox(label="⏳ Прогресс", value="0/30")
45
 
46
+ btn.click(analyze_and_ask_async, inputs=[inp, progress], outputs=[mbti_out, interviewer_out, progress])
47
+ demo.load(lambda: ("", generate_first_question(), "0/30"), None, [mbti_out, interviewer_out, progress])
48
 
49
+ demo.queue(concurrency_count=2).launch()
core/__init__.py ADDED
File without changes
core/interviewer.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel

from core.utils import clean_question
from core.memory import update_user_context, get_user_context, was_asked

# Interviewer model: TinyLlama chat base + LoRA adapter fine-tuned
# for MBTI-style interview questions.
INTERVIEWER_BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
INTERVIEWER_LORA = "f3nsmart/TinyLlama-MBTI-Interviewer-LoRA"

print("🔄 Loading interviewer (TinyLlama + LoRA)...")

# Tokenizer is loaded from the LoRA repo so any added/special tokens
# match what the adapter was trained with.
tokenizer = AutoTokenizer.from_pretrained(INTERVIEWER_LORA)
base_model = AutoModelForCausalLM.from_pretrained(
    INTERVIEWER_BASE,
    # fp16 on GPU halves memory; fp32 on CPU where fp16 is slow/unsupported.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, INTERVIEWER_LORA)

llm_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=70,
    # Fix: do_sample=True is required for temperature/top_p to take
    # effect — without it generation is greedy and these knobs are
    # silently ignored (transformers emits a warning).
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    device_map="auto",
)
29
+
30
+
31
async def generate_question(user_id: str, user_text: str) -> str:
    """Generate the next interview question asynchronously.

    Builds a prompt from the user's latest answer plus the last five
    questions already asked (to discourage repetition), runs the
    blocking LLM pipeline in the default thread-pool executor so the
    event loop stays responsive, and records the question in the
    user's conversation memory.

    Args:
        user_id: key into the in-memory conversation store.
        user_text: the user's latest answer.

    Returns:
        A cleaned, open-ended question string.
    """
    history = get_user_context(user_id)
    prev_qs = " | ".join(history["questions"][-5:])  # last 5 questions

    prompt = (
        f"User said: '{user_text}'. Previous questions: {prev_qs or 'None'}. "
        "Generate one natural, open-ended question starting with 'What', 'Why', 'How', or 'When'. "
        "Avoid repeating or rephrasing previous questions. "
        "Output only the question itself."
    )
    # Fix: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here and can bind the wrong loop.
    loop = asyncio.get_running_loop()
    raw = await loop.run_in_executor(None, lambda: llm_pipe(prompt)[0]["generated_text"])
    # NOTE(review): "generated_text" normally includes the prompt itself;
    # clean_question keeps only the first line, which may still be the
    # prompt — confirm, or pass return_full_text=False to the pipeline.
    cleaned = clean_question(raw)

    # Fall back to a stock question rather than repeating one already asked.
    if was_asked(user_id, cleaned):
        cleaned = "What new challenges have you faced recently?"
    update_user_context(user_id, cleaned, user_text)
    return cleaned
core/mbti_analyzer.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio

from transformers import pipeline

from core.utils import format_mbti_output

# MBTI text classifier; return_all_scores=True yields a score for every
# label so the formatter can pick the top three.
MBTI_MODEL = "f3nsmart/MBTIclassifier"
mbti_pipe = pipeline("text-classification", model=MBTI_MODEL, return_all_scores=True)


async def analyze_mbti(text: str) -> str:
    """Classify *text* for MBTI traits asynchronously.

    The blocking pipeline call runs in the default thread-pool executor
    so the event loop is not stalled.
    """
    # Fix: a proper top-level import + get_running_loop() replace the
    # original __import__("asyncio").get_event_loop() hack, which is both
    # obscure and deprecated inside a coroutine.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, mbti_pipe, text)
    return format_mbti_output(result[0])
core/memory.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# In-process conversation store: user_id -> {"questions": [...], "answers": [...]}.
user_memory = {}


def get_user_context(user_id: str):
    """Return the question/answer history for *user_id* (empty if unseen)."""
    empty = {"questions": [], "answers": []}
    return user_memory.get(user_id, empty)


def update_user_context(user_id: str, question: str, answer: str):
    """Append one Q/A pair to the user's history, creating it on first use."""
    ctx = user_memory.setdefault(user_id, {"questions": [], "answers": []})
    for key, value in (("questions", question), ("answers", answer)):
        ctx[key].append(value)
    return ctx


def was_asked(user_id: str, new_question: str) -> bool:
    """Return True when *new_question* was already asked (case-insensitive)."""
    wanted = new_question.strip().lower()
    asked = get_user_context(user_id)["questions"]
    return any(q.lower() == wanted for q in asked)
core/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def clean_question(text: str) -> str:
    """Normalize raw LLM output into a single usable question.

    Keeps only the first line, strips wrapping quotes, drops any
    chat-role/instruction markers, guarantees a trailing '?', and falls
    back to a stock question when the result is too short (< 3 words).
    """
    text = text.strip().split("\n")[0].strip('"').strip("'")
    bad_tokens = ["user:", "assistant:", "instruction", "interviewer", "system:"]
    for bad in bad_tokens:
        # Bug fix: the original checked for the marker case-insensitively
        # but split with the case-sensitive str.split(), so "User:" or
        # "Assistant:" prefixes were never removed. Locate the LAST
        # occurrence case-insensitively (same segment split()[-1] chose).
        idx = text.lower().rfind(bad)
        if idx != -1:
            text = text[idx + len(bad):].strip()
    if "?" not in text:
        text = text.rstrip(".") + "?"
    if len(text.split()) < 3:
        return "What do you usually enjoy doing in your free time?"
    return text.strip()
12
+
13
+
14
def generate_first_question():
    """Return the fixed opening question used to start every interview."""
    return "What do you usually enjoy doing in your free time?"
16
+
17
+
18
def format_mbti_output(res):
    """Render the top-3 classifier labels as 'LABEL → score' lines.

    *res* is a list of {"label": str, "score": float} dicts; output is
    one line per label, highest score first, scores to 3 decimals.
    """
    top3 = sorted(res, key=lambda item: item["score"], reverse=True)[:3]
    lines = [f"{item['label']} → {item['score']:.3f}" for item in top3]
    return "\n".join(lines)