Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
| from fastapi import FastAPI, Request | |
| import uvicorn | |
| from fastapi.middleware.cors import CORSMiddleware | |
# === Model ===
# Hugging Face Hub repo id; both tokenizer and weights are pulled from it.
model_id = "cody82/unitrip"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Fixed background passage (Russian, about Innopolis University) that
# generate_response() prepends to every prompt. Runtime string — keep as-is.
context = (
    "Университет Иннополис был основан в 2012 году. "
    "Это современный вуз в России, специализирующийся на IT и робототехнике, "
    "расположенный в городе Иннополис, Татарстан.\n"
)
def generate_response(question):
    """Answer *question* using the fixed Innopolis ``context``.

    Builds a Russian QA prompt around the module-level ``context``, samples a
    completion from the model, and returns only the answer portion of the
    decoded text.

    Args:
        question: User question (plain text, expected in Russian).

    Returns:
        The model's answer as a stripped string.
    """
    prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {question}\nОтвет:"
    # BUG FIX: previously only input_ids were passed and the attention mask
    # was discarded; generate() warns and can produce wrong output when
    # pad_token_id == eos_token_id and no mask is supplied.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=150,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # The decoded text echoes the whole prompt; keep only the part after the
    # final "Ответ:" marker, falling back to slicing off the prompt prefix.
    if "Ответ:" in output:
        answer = output.split("Ответ:")[-1].strip()
    else:
        answer = output[len(prompt):].strip()
    return answer
# === Gradio UI ===
# Page styling: center a 900px column and give the chat window a fixed height.
_CHAT_CSS = """
.gradio-container {
    max-width: 900px !important;
    margin: auto;
    font-family: Arial, sans-serif;
}
h1, h2 {
    text-align: center;
}
#chatbox {
    height: 500px;
}
"""

with gr.Blocks(css=_CHAT_CSS) as demo:
    gr.Markdown("## 🤖 Иннополис Бот")
    gr.Markdown("Задавайте вопросы о Университете Иннополис")
    chatbot = gr.Chatbot(label="Диалог", elem_id="chatbox")
    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Введите вопрос...", scale=8)
        send = gr.Button("Отправить", scale=1)
    clear = gr.Button("🗑 Очистить чат")

    # Chat handler: append a (user, bot) pair to the history and clear the
    # input box. Empty/whitespace-only input leaves the history untouched.
    def chat_response(user_message, history):
        history = history or []
        if not user_message.strip():
            return history, ""
        history.append((user_message, generate_response(user_message)))
        return history, ""

    # Wire the same handler to both the button click and Enter in the textbox.
    for trigger in (send.click, msg.submit):
        trigger(chat_response, inputs=[msg, chatbot], outputs=[chatbot, msg])
    # Reset both the conversation and the input field.
    clear.click(lambda: ([], ""), None, [chatbot, msg])
# === FastAPI ===
app = FastAPI()

# Allow cross-origin requests from any frontend embedding this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# BUG FIX: this handler was defined but never registered as a route, so the
# JSON API was unreachable. Register it at POST /ask.
@app.post("/ask")
async def api_ask(request: Request):
    """JSON endpoint: body {"question": "..."} -> {"answer": "..."}."""
    data = await request.json()
    question = data.get("question", "")
    answer = generate_response(question)
    return {"answer": answer}


# Embed the Gradio UI at the root path of the same FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)