Spaces:
Running
Running
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
| from datetime import datetime | |
| import torch | |
| import warnings | |
| import json | |
| import os | |
| warnings.filterwarnings('ignore') | |
| # Настройки страницы | |
| st.set_page_config( | |
| page_title="Выбор модели для дообучения", | |
| page_icon="🤖", | |
| layout="wide" | |
| ) | |
| # Стили CSS | |
| st.markdown(""" | |
| <style> | |
| .result-card { | |
| padding: 20px; | |
| border-radius: 10px; | |
| margin: 10px 0; | |
| border-left: 5px solid #28a745; | |
| } | |
| .best-model { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| color: white; | |
| padding: 25px; | |
| border-radius: 15px; | |
| text-align: center; | |
| } | |
| .comparison-table { | |
| width: 100%; | |
| border-collapse: collapse; | |
| } | |
| .comparison-table th { | |
| background-color: #f2f2f2; | |
| padding: 12px; | |
| text-align: left; | |
| } | |
| .comparison-table td { | |
| padding: 10px; | |
| border-bottom: 1px solid #ddd; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| st.title("🎯 Выбор модели для дообучения на русских текстах") | |
| # Загрузка предыдущих результатов | |
| def load_previous_results(): | |
| try: | |
| if os.path.exists("model_evaluation_results.csv"): | |
| df = pd.read_csv("model_evaluation_results.csv") | |
| return df | |
| except: | |
| pass | |
| return None | |
| # Основные русскоязычные модели | |
| RUSSIAN_MODELS = { | |
| "ruGPT-3.5 (1.3B)": { | |
| "id": "ai-forever/rugpt3.5_1.3b", | |
| "size": "1.3B", | |
| "language": "🇷🇺 Русский", | |
| "description": "Крупная русская модель, наилучшее качество", | |
| "recommended": True, | |
| "type": "large", | |
| "tested": False | |
| }, | |
| "ruGPT-3 Medium": { | |
| "id": "sberbank-ai/rugpt3medium_based_on_gpt2", | |
| "size": "355M", | |
| "language": "🇷🇺 Русский", | |
| "description": "Средняя русская модель, хорошее качество", | |
| "recommended": True, | |
| "type": "medium", | |
| "tested": True, | |
| "perplexity": 20.9, | |
| "speed": 21.6, | |
| "score": 8.4 | |
| }, | |
| "ruGPT-3 Small": { | |
| "id": "sberbank-ai/rugpt3small_based_on_gpt2", | |
| "size": "125M", | |
| "language": "🇷🇺 Русский", | |
| "description": "Базовая русскоязычная модель", | |
| "recommended": False, | |
| "type": "small", | |
| "tested": True, | |
| "perplexity": 23.9, | |
| "speed": 8.5, | |
| "score": 8.4 | |
| }, | |
| "ruT5-base": { | |
| "id": "cointegrated/rut5-base", | |
| "size": "220M", | |
| "language": "🇷🇺 Русский", | |
| "description": "T5 архитектура, хороша для переформулирования", | |
| "recommended": True, | |
| "type": "medium", | |
| "tested": False | |
| } | |
| } | |
| def main(): | |
| # Загрузка предыдущих результатов | |
| previous_results = load_previous_results() | |
| # Боковая панель | |
| with st.sidebar: | |
| st.header("📊 Результаты тестирования") | |
| if previous_results is not None: | |
| st.success("Обнаружены сохраненные результаты") | |
| # Лучшая модель из результатов | |
| best_row = previous_results.loc[previous_results['Общая оценка'].idxmax()] | |
| st.metric("Лучшая модель", best_row['Модель']) | |
| st.metric("Оценка", f"{best_row['Общая оценка']}/10") | |
| st.metric("Perplexity", f"{best_row['Perplexity']}") | |
| # Показать все результаты | |
| with st.expander("Все результаты"): | |
| st.dataframe(previous_results) | |
| st.markdown("---") | |
| st.header("⚙️ Новое тестирование") | |
| # Выбор действия | |
| action = st.radio( | |
| "Выберите действие:", | |
| ["Показать рекомендации", "Запустить новое тестирование", "Сравнить модели"] | |
| ) | |
| # Основная область в зависимости от выбранного действия | |
| if action == "Показать рекомендации": | |
| show_recommendations(previous_results) | |
| elif action == "Запустить новое тестирование": | |
| show_testing_interface() | |
| else: | |
| show_comparison() | |
| def show_recommendations(previous_results): | |
| st.header("🏆 Рекомендации по выбору модели") | |
| # Отображение лучшей модели | |
| if previous_results is not None: | |
| best_row = previous_results.loc[previous_results['Общая оценка'].idxmax()] | |
| st.markdown(f""" | |
| <div class="best-model"> | |
| <h2>🚀 Рекомендуемая модель: {best_row['Модель']}</h2> | |
| <h3>Оценка: {best_row['Общая оценка']}/10 • Perplexity: {best_row['Perplexity']}</h3> | |
| <p>Оптимальный выбор для дообучения на датасете mark.csv</p> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Детали рекомендации | |
| st.subheader("📋 Почему эта модель?") | |
| cols = st.columns(3) | |
| with cols[0]: | |
| st.metric("Качество генерации", "Высокое", delta="Low perplexity") | |
| with cols[1]: | |
| st.metric("Размер модели", best_row['Размер'], delta="Оптимальный") | |
| with cols[2]: | |
| st.metric("Для дообучения", "Отлично", delta="Поддерживает fine-tuning") | |
| # Код для дообучения | |
| st.subheader("💻 Код для дообучения") | |
| model_options = list(RUSSIAN_MODELS.keys()) | |
| selected_model = st.selectbox( | |
| "Выберите модель для получения кода:", | |
| model_options, | |
| index=model_options.index("ruGPT-3 Medium") if "ruGPT-3 Medium" in model_options else 0 | |
| ) | |
| model_info = RUSSIAN_MODELS[selected_model] | |
| code = f""" | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from transformers import Trainer, TrainingArguments | |
| import pandas as pd | |
| import torch | |
| # 1. Загрузка модели | |
| model_id = "{model_info['id']}" | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| # Установка pad_token если его нет | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained(model_id) | |
| # 2. Подготовка данных | |
| def prepare_dataset(filepath="mark.csv"): | |
| # Загрузка вашего датасета | |
| df = pd.read_csv(filepath) | |
| # Предполагаем, что текстовые данные в колонке 'text' | |
| texts = df['text'].tolist() | |
| # Токенизация | |
| encodings = tokenizer( | |
| texts, | |
| truncation=True, | |
| padding=True, | |
| max_length=512, | |
| return_tensors="pt" | |
| ) | |
| return encodings | |
| # 3. Создание датасета | |
| class TextDataset(torch.utils.data.Dataset): | |
| def __init__(self, encodings): | |
| self.encodings = encodings | |
| def __len__(self): | |
| return len(self.encodings['input_ids']) | |
| def __getitem__(self, idx): | |
| item = {{key: val[idx] for key, val in self.encodings.items()}} | |
| item['labels'] = item['input_ids'].clone() # Для language modeling | |
| return item | |
| # 4. Настройка обучения | |
| training_args = TrainingArguments( | |
| output_dir="./results", | |
| num_train_epochs=3, | |
| per_device_train_batch_size=4, | |
| per_device_eval_batch_size=4, | |
| warmup_steps=500, | |
| weight_decay=0.01, | |
| logging_dir="./logs", | |
| logging_steps=100, | |
| save_steps=1000, | |
| evaluation_strategy="steps", | |
| ) | |
| # 5. Запуск обучения | |
| dataset = TextDataset(prepare_dataset()) | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=dataset, | |
| ) | |
| trainer.train() | |
| """ | |
| st.code(code, language="python") | |
| # Советы по дообучению | |
| st.subheader("📝 Советы по дообучению") | |
| tips = { | |
| "Batch Size": "Начните с небольшого batch size (2-4), особенно для больших моделей", | |
| "Learning Rate": "Используйте маленький learning rate (1e-5 до 5e-5) для дообучения", | |
| "Epochs": "3-5 эпох обычно достаточно для дообучения", | |
| "Мониторинг": "Следите за perplexity на валидационной выборке", | |
| "Сохранение": "Сохраняйте чекпоинты каждые 1000 шагов" | |
| } | |
| for tip, description in tips.items(): | |
| with st.expander(f"✅ {tip}"): | |
| st.write(description) | |
| def show_testing_interface(): | |
| st.header("🧪 Тестирование моделей") | |
| st.info("Эта функция тестирует модели на ваших данных. Для получения рекомендаций используйте раздел 'Показать рекомендации'") | |
| # Здесь можно разместить код тестирования из предыдущей версии | |
| st.warning("Функция тестирования требует значительных вычислительных ресурсов") | |
| st.write("Для тестирования моделей на Hugging Face используйте оригинальное приложение") | |
| def show_comparison(): | |
| st.header("📊 Сравнение моделей") | |
| # Создаем таблицу сравнения | |
| comparison_data = [] | |
| for name, info in RUSSIAN_MODELS.items(): | |
| row = { | |
| "Модель": name, | |
| "Размер": info["size"], | |
| "Тип": info["type"], | |
| "Рекомендуется": "✅" if info["recommended"] else "❌", | |
| "Описание": info["description"] | |
| } | |
| if info.get("tested", False): | |
| row.update({ | |
| "Perplexity": info.get("perplexity", "N/A"), | |
| "Оценка": info.get("score", "N/A"), | |
| "Статус": "✅ Протестирована" | |
| }) | |
| else: | |
| row.update({ | |
| "Perplexity": "Не тестировалась", | |
| "Оценка": "Не тестировалась", | |
| "Статус": "⏳ Ожидает тестирования" | |
| }) | |
| comparison_data.append(row) | |
| df_comparison = pd.DataFrame(comparison_data) | |
| # Отображение таблицы | |
| st.dataframe( | |
| df_comparison, | |
| column_config={ | |
| "Рекомендуется": st.column_config.TextColumn("Рекомендация"), | |
| "Статус": st.column_config.TextColumn("Статус тестирования"), | |
| }, | |
| hide_index=True, | |
| use_container_width=True | |
| ) | |
| # Графическое сравнение | |
| st.subheader("📈 Сравнение производительности") | |
| # Создаем данные для графика | |
| tested_models = [(name, info) for name, info in RUSSIAN_MODELS.items() if info.get("tested", False)] | |
| if tested_models: | |
| import plotly.graph_objects as go | |
| model_names = [name for name, _ in tested_models] | |
| scores = [info.get("score", 0) for _, info in tested_models] | |
| perplexities = [info.get("perplexity", 100) for _, info in tested_models] | |
| fig = go.Figure(data=[ | |
| go.Bar(name='Оценка', x=model_names, y=scores, marker_color='lightblue'), | |
| go.Bar(name='Perplexity', x=model_names, y=[p/10 for p in perplexities], marker_color='lightcoral') | |
| ]) | |
| fig.update_layout( | |
| title='Сравнение моделей', | |
| barmode='group', | |
| yaxis_title='Значения', | |
| showlegend=True | |
| ) | |
| st.plotly_chart(fig) | |
| # Выводы | |
| st.subheader("🎯 Выводы") | |
| best_tested = max(tested_models, key=lambda x: x[1].get("score", 0)) | |
| st.markdown(f""" | |
| 1. **{best_tested[0]}** показывает лучшие результаты среди протестированных | |
| 2. **ruGPT-3 Medium** имеет самый низкий perplexity (лучшее качество генерации) | |
| 3. **ruGPT-3 Small** быстрее, но чуть хуже по качеству | |
| 4. **ruGPT-3.5 (1.3B)** не тестировалась, но потенциально может дать лучшее качество | |
| """) | |
| if __name__ == "__main__": | |
| main() |