Neweret commited on
Commit
1f23750
·
verified ·
1 Parent(s): 519e6e8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +96 -1
README.md CHANGED
@@ -12,4 +12,99 @@ tags:
12
  - Prompt Classes
13
  - Classificator
14
  - Prompt Classification
15
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  - Prompt Classes
13
  - Classificator
14
  - Prompt Classification
15
+ ---
16
+
17
+ # 🔁 SimplePromptRouter — классификатор промптов (русский)
18
+
19
+ **Кратко:** модель классифицирует входные промпты/вопросы на три действия:
20
+ - **0 — Поиск в локальной базе знаний (RAG)**: сначала ищем релевантные документы в локальном индексе и формируем контекст для генерации.
21
+ - **1 — Поиск в сети**: триггер запуска обхода внешних поисковых систем/скрейпинга.
22
+ - **2 — Прямой запрос**: сразу посылаем промпт в генеративную модель (например, LLM) для синтеза ответа.
23
+
24
+ ---
25
+
26
+ ## Где используется
27
+ Подходит для систем, где нужно автоматически решать стратегию обработки пользовательского промпта:
28
+ - чат-боты со связкой Retrieval-Augmented Generation (RAG),
29
+ - голосовые ассистенты,
30
+ - интерфейсы поддержки, где часть запросов решается поиском, часть — генерацией.
31
+
32
+ ---
33
+
34
+ ## Файлы в репозитории
35
+ - `pytorch_model.bin` — веса модели (state_dict).
36
+ - `config.json` — конфигурация (input_dim, num_classes, p_dropout, classes).
37
+ - `modeling_simple_classifier.py` или `model.py` — определение архитектуры (если кастом).
38
+ - `vectorizer.pkl` — sklearn-векторизатор (TF-IDF/Count).
39
+ - `svd.pkl` — TruncatedSVD (опционально).
40
+ - `label_encoder.pkl` — sklearn.LabelEncoder (для декодирования метки).
41
+ - `README.md` — эта карточка.
42
+
43
+ ---
44
+
45
+ ## Пример загрузки и инференса (без AutoModel)
46
+
47
+ ```python
48
+ # Пример: загрузка напрямую из репозитория HF (не требует локальной копии)
49
+ from huggingface_hub import hf_hub_download
50
+ import json, pickle, torch
51
+ import numpy as np
52
+ from types import SimpleNamespace
53
+
54
+ REPO = "username/SimplePromptRouter"
55
+
56
+ config_path = hf_hub_download(REPO, "config.json")
57
+ weights_path = hf_hub_download(REPO, "pytorch_model.bin")
58
+ vec_path = hf_hub_download(REPO, "vectorizer.pkl")
59
+ svd_path = None
60
+ try:
61
+ svd_path = hf_hub_download(REPO, "svd.pkl")
62
+ except Exception:
63
+ svd_path = None
64
+ le_path = hf_hub_download(REPO, "label_encoder.pkl")
65
+
66
+ cfg = SimpleNamespace(**json.load(open(config_path, "r", encoding="utf-8")))
67
+
68
+ # --- Динамическая модель (вставь ту же архитектуру, что использовал при обучении) ---
69
+ class SimpleClassifier(torch.nn.Module):
70
+ def __init__(self, input_dim, num_classes, p_dropout=0.3):
71
+ super().__init__()
72
+ self.linear1 = torch.nn.Linear(input_dim, 256)
73
+ self.ln1 = torch.nn.LayerNorm(256)
74
+ self.dropout = torch.nn.Dropout(p_dropout)
75
+ self.linear2 = torch.nn.Linear(256, 128)
76
+ self.ln2 = torch.nn.LayerNorm(128)
77
+ self.linear_out = torch.nn.Linear(128, num_classes)
78
+ def forward(self, x):
79
+ x = torch.nn.functional.gelu(self.ln1(self.linear1(x)))
80
+ x = self.dropout(x)
81
+ x = torch.nn.functional.gelu(self.ln2(self.linear2(x)))
82
+ x = self.dropout(x)
83
+ return self.linear_out(x)
84
+
85
+ model = SimpleClassifier(cfg.input_dim, cfg.num_classes, cfg.p_dropout)
86
+ state = torch.load(weights_path, map_location="cpu")
87
+ model.load_state_dict(state)
88
+ model.eval()
89
+
90
+ # препроцессинг
91
+ vectorizer = pickle.load(open(vec_path, "rb"))
92
+ svd = pickle.load(open(svd_path, "rb")) if svd_path else None
93
+ le = pickle.load(open(le_path, "rb"))
94
+
95
+ def preprocess(text):
96
+ X = vectorizer.transform([text])
97
+ if svd is not None:
98
+ X = svd.transform(X)
99
+ return X.astype(np.float32)
100
+
101
+ def predict(text):
102
+ x = preprocess(text)
103
+ xb = torch.from_numpy(x).float()
104
+ with torch.inference_mode():
105
+ logits = model(xb)
106
+ pred = int(torch.argmax(logits, dim=1).cpu().numpy()[0])
107
+ return pred, le.inverse_transform([pred])[0]
108
+
109
+ # пример
110
+ print(predict("Как мне найти документацию по OpenAI?"))