MyRAGModel / app.py
rampagetrew's picture
Update app.py
db20aee verified
import gradio as gr
import torch
from transformers import pipeline
import os
import logging
# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
print("🚀 Starting RAG Model Application...")
print(f"Python version: {os.sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Инициализация модели
generator = None
def load_model():
"""Загрузка модели"""
global generator
try:
# Начнем с легкой модели для тестирования
model_name = "microsoft/DialoGPT-medium"
logger.info(f"Loading model: {model_name}")
generator = pipeline(
"text-generation",
model=model_name,
torch_dtype=torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
logger.info("✅ Model loaded successfully!")
return True
except Exception as e:
logger.error(f"❌ Error loading model: {e}")
return False
def chat_response(message, history):
"""Обработчик сообщений для чата"""
if generator is None:
return "⚠️ Модель еще не загружена. Пожалуйста, подождите..."
try:
# Простая генерация ответа
response = generator(
message,
max_new_tokens=150,
temperature=0.7,
do_sample=True,
pad_token_id=generator.tokenizer.eos_token_id
)
generated_text = response[0]['generated_text']
# Убираем оригинальное сообщение из ответа
if generated_text.startswith(message):
generated_text = generated_text[len(message):].strip()
return generated_text
except Exception as e:
logger.error(f"Generation error: {e}")
return f"❌ Ошибка при генерации ответа: {str(e)}"
# Загружаем модель при старте
model_loaded = load_model()
# Создаем простой интерфейс
if model_loaded:
title = "🧠 RAG Model - Ready!"
description = "Модель успешно загружена и готова к работе!"
else:
title = "🧠 RAG Model - Error"
description = "Не удалось загрузить модель. Проверьте логи."
# Создаем интерфейс
demo = gr.ChatInterface(
fn=chat_response,
title=title,
description=description,
examples=[
"Привет! Как ты работаешь?",
"Расскажи о машинном обучении",
"Напиши короткий пример кода на Python"
]
)
if __name__ == "__main__":
logger.info("🌐 Launching Gradio interface...")
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)