File size: 3,384 Bytes
6275495
e6961b0
6275495
 
 
983eb46
 
af4b77b
983eb46
db606bb
983eb46
b69b644
b10ba12
80ceb8c
6275495
9c601ea
 
9d9c29a
db606bb
 
 
 
 
af5c917
983eb46
 
db606bb
af4b77b
c263659
db606bb
 
84956d9
db606bb
5e09a54
 
24fb973
c263659
af4b77b
52ae47c
af4b77b
52ae47c
 
5e09a54
52ae47c
af4b77b
983eb46
 
4edcafa
84956d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4edcafa
84956d9
4edcafa
84956d9
 
144f920
84956d9
 
4edcafa
144f920
4edcafa
 
144f920
 
4edcafa
 
144f920
4edcafa
144f920
 
4edcafa
983eb46
 
84956d9
983eb46
 
 
 
 
 
 
 
 
 
 
 
 
 
000988a
3bb8597
144f920
983eb46
 
3bb8597
4edcafa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from fastapi import FastAPI, Request
import uvicorn
from fastapi.middleware.cors import CORSMiddleware

# === Модель ===
model_id = "cody82/unitrip"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

context = (
    "Университет Иннополис был основан в 2012 году. "
    "Это современный вуз в России, специализирующийся на IT и робототехнике, "
    "расположенный в городе Иннополис, Татарстан.\n"
)

def generate_response(question):
    prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {question}\nОтвет:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=150,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if "Ответ:" in output:
        answer = output.split("Ответ:")[-1].strip()
    else:
        answer = output[len(prompt):].strip()

    return answer


# === Gradio UI ===
custom_css = """
.gradio-container {
    max-width: 900px !important;
    margin: auto;
    font-family: Arial, sans-serif;
}
h1, h2 {
    text-align: center;
}
#chatbox {
    height: 500px;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## 🤖 Иннополис Бот")
    gr.Markdown("Задавайте вопросы о Университете Иннополис")

    chatbot = gr.Chatbot(label="Диалог", elem_id="chatbox")
    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Введите вопрос...", scale=8)
        send = gr.Button("Отправить", scale=1)
    clear = gr.Button("🗑 Очистить чат")

    # Логика чата
    def chat_response(user_message, history):
        history = history or []
        if not user_message.strip():
            return history, ""  # если пустой ввод
        bot_reply = generate_response(user_message)
        history.append((user_message, bot_reply))
        return history, ""  # вернём чат и очистим поле

    send.click(chat_response, inputs=[msg, chatbot], outputs=[chatbot, msg])
    msg.submit(chat_response, inputs=[msg, chatbot], outputs=[chatbot, msg])
    clear.click(lambda: ([], ""), None, [chatbot, msg])


# === FastAPI ===
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/api/ask")
async def api_ask(request: Request):
    data = await request.json()
    question = data.get("question", "")
    answer = generate_response(question)
    return {"answer": answer}

# Встраиваем Gradio в FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)