File size: 2,071 Bytes
b472457
 
db72f77
c47d9e0
b6fe054
97e1b0e
 
 
 
 
c47d9e0
b6fe054
97e1b0e
b472457
 
97e1b0e
b472457
97e1b0e
b9e73f2
 
268ce6d
97e1b0e
b9e73f2
97e1b0e
 
 
 
 
b472457
 
97e1b0e
 
b472457
 
97e1b0e
b472457
b9e73f2
 
db72f77
b472457
97e1b0e
 
b6fe054
97e1b0e
 
 
b9e73f2
97e1b0e
268ce6d
97e1b0e
db72f77
 
97e1b0e
 
db72f77
97e1b0e
b6fe054
 
 
de48c90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Параметры
model_repo = "cody82/unitrip"
cache_dir = "/data/model"  # Persistent storage путь на HF Spaces

# Создаём каталог, если не существует
os.makedirs(cache_dir, exist_ok=True)

# Загрузка модели и токенизатора
tokenizer = AutoTokenizer.from_pretrained(model_repo, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_repo, cache_dir=cache_dir)
model.to("cpu")  # Используем CPU, так как у нас ZeroGPU

@spaces.gpu
def respond(message, history):
    history = history or []

    # Формируем текст истории
    full_input = ""
    for turn in history:
        if turn["role"] == "user":
            full_input += f"User: {turn['content']}\n"
        elif turn["role"] == "assistant":
            full_input += f"Assistant: {turn['content']}\n"
    full_input += f"User: {message}\nAssistant:"

    # Токенизация и генерация
    inputs = tokenizer(full_input, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = decoded.split("Assistant:")[-1].strip()

    # Обновление истории
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})

    return history

# Интерфейс
chat = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(label="Unitrip Assistant", type="messages"),
    title="Unitrip Travel Assistant",
    theme="soft",
    examples=["Какие города ты рекомендуешь посетить в Италии?", "Лучшее время для поездки в Японию?"],
)

if __name__ == "__main__":
    chat.launch()