chat_test_z0 / app.py
LevinAleksey's picture
Update app.py
bec8c88 verified
import os
import spaces # Должен быть первым
import torch
import gradio as gr
from qdrant_client import QdrantClient
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import httpx
# Оптимальная модель для ZeroGPU: 32B в 4-битном квантовании
# Она умная, знает русский и не вылетает по памяти
MODEL_ID = "unsloth/Qwen2.5-32B-Instruct-bnb-4bit"
# Инициализация токенайзера (на CPU)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# --- ПОИСК В БАЗЕ (CPU) ---
def get_inventory_context(query):
OR_KEY = os.getenv("OPENROUTER_API_KEY")
Q_URL = os.getenv("QDRANT_URL")
Q_KEY = os.getenv("QDRANT_API_KEY")
try:
# Эмбеддинг через API (чтобы не грузить локально еще и модель векторов)
with httpx.Client(timeout=20.0) as client:
res = client.post(
"https://openrouter.ai/api/v1/embeddings",
headers={"Authorization": f"Bearer {OR_KEY}"},
json={"model": "openai/text-embedding-3-small", "input": query}
)
vector = res.json()["data"][0]["embedding"]
qc = QdrantClient(url=Q_URL, api_key=Q_KEY)
search_results = qc.search(collection_name="equipment_registry", query_vector=vector, limit=5)
return "\n\n".join([r.payload.get("search_text", "") for r in search_results if r.payload])
except Exception as e:
print(f"Search error: {e}")
return "Данные из реестра временно недоступны."
# --- ГЕНЕРАЦИЯ ОТВЕТА (GPU) ---
@spaces.GPU(duration=90) # Резервируем видеокарту на 90 сек
def chat_generate(prompt):
# Загружаем модель прямо на выделенную GPU
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
torch_dtype="auto"
)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
temperature=0.2, # Низкая температура для точности в документах
top_p=0.9
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
yield response
# Основная функция интерфейса
def respond(message, history, system_message):
# 1. Тянем контекст из Qdrant
context = get_inventory_context(message)
# 2. Формируем промпт в формате Qwen
full_prompt = (
f"<|im_start|>system\n{system_message}\n\n"
f"ИСПОЛЬЗУЙ ЭТИ ДАННЫЕ ИЗ РЕЕСТРА:\n{context}<|im_end|>\n"
f"<|im_start|>user\n{message}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
# 3. Стримим ответ с GPU
for part in chat_generate(full_prompt):
yield part
# --- ИНТЕРФЕЙС ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
gr.Markdown("### 🏢 Система поиска по реестру (QWEN-32B + ZeroGPU)")
gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="Ты — ассистент по учету оборудования. Отвечай только на основе предоставленных данных.", label="Инструкция")
]
)
if __name__ == "__main__":
demo.launch()