File size: 6,250 Bytes
3297f8d 7327516 3297f8d 2f2eda6 3297f8d 2f2eda6 3297f8d 166f868 2f2eda6 3297f8d 2f2eda6 3297f8d 2f2eda6 3297f8d 2f2eda6 166f868 2f2eda6 7327516 3297f8d 7327516 3297f8d 7327516 3297f8d 7327516 2f2eda6 3297f8d 2f2eda6 3297f8d 7327516 2f2eda6 7327516 2f2eda6 7327516 fcf71d0 7327516 2f2eda6 7327516 2f2eda6 15de1d7 2f2eda6 15de1d7 4ab7e58 3297f8d 15de1d7 2f2eda6 7327516 2f2eda6 15de1d7 3297f8d 2f2eda6 fcf71d0 2f2eda6 3297f8d 2f2eda6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import torch
import gradio as gr
import requests
from typing import List, Dict, Iterator
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# --- 1) Конфигурация и загрузка модели ---
BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
ADAPTER_ID = os.getenv("ADAPTER_ID")
YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID")
if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.05
SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә дә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү. Кулланучыга, фактлардан һәм саннардан качып, һәрвакыт кыска җавап бирергә кирәк"
)
print("Загрузка модели с 4-битной квантизацией...")
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
quantization_config=quantization_config,
device_map="auto"
)
print("Применяем LoRA адаптер...")
model = PeftModel.from_pretrained(base, ADAPTER_ID)
model.config.use_cache = False
model.eval()
print("✅ Модель успешно загружена!")
YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
def detect_language(text: str) -> str:
headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
payload = {"folderId": YANDEX_FOLDER_ID, "text": text}
try:
resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10)
resp.raise_for_status()
return resp.json().get("languageCode", "ru")
except requests.exceptions.RequestException:
return "ru"
def ru2tt(text: str) -> str:
headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
payload = {"folderId": YANDEX_FOLDER_ID, "texts": [text], "sourceLanguageCode": "ru", "targetLanguageCode": "tt"}
try:
resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
resp.raise_for_status()
return resp.json()["translations"][0]["text"]
except requests.exceptions.RequestException:
return text
def render_prompt(messages: List[Dict[str, str]]) -> str:
return tok.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# --- 4) Стриминговая генерация (без тримминга) ---
@torch.inference_mode()
def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
prompt = render_prompt(messages)
enc = tok(prompt, return_tensors="pt")
enc = {k: v.to(model.device) for k, v in enc.items()}
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(
**enc,
streamer=streamer,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
temperature=TEMPERATURE,
top_p=TOP_P,
repetition_penalty=REPETITION_PENALTY,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
)
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
acc = ""
for chunk in streamer:
acc += chunk
yield acc
def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]):
if not messages_state or messages_state[0].get("role") != "system":
messages_state = [{"role": "system", "content": SYS_PROMPT_TT}]
detected = detect_language(message)
user_tt = ru2tt(message) if detected != "tt" else message
messages = messages_state + [{"role": "user", "content": user_tt}]
ui_history = ui_history + [[user_tt, ""]]
for partial in generate_tt_reply_stream(messages):
ui_history[-1][1] = partial
yield ui_history, messages_state + [
{"role": "user", "content": user_tt},
{"role": "assistant", "content": partial},
]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Татарский чат-бот от команды Сбера")
messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}])
chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False)
msg = gr.Textbox(
label="Хәбәрегезне рус яки татар телендә языгыз",
placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?"
)
clear = gr.Button("🗑️ Чистарту")
msg.submit(
chat_fn,
inputs=[msg, chatbot, messages_state],
outputs=[chatbot, messages_state],
)
msg.submit(lambda: "", None, msg)
def _reset():
return [], [{"role": "system", "content": SYS_PROMPT_TT}]
clear.click(_reset, inputs=None, outputs=[chatbot, messages_state], queue=False)
print("Messages_state: " + messages_state)
clear.click(lambda: "", None, msg, queue=False)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
|