#!/usr/bin/env python3
"""KAIdol A/B Test Arena - GPU Version with Real Model Inference"""
import gradio as gr
import random
import json
import uuid
import re
import gc
from datetime import datetime
# GPU inference dependencies (optional imports; the app falls back to mock mode)
TORCH_AVAILABLE = False
IMPORT_ERROR = None
torch = None
try:
    import torch as _torch
    torch = _torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel
    TORCH_AVAILABLE = True

    # Debug info
    print("=" * 50)
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU count: {torch.cuda.device_count()}")
        print(f"GPU name: {torch.cuda.get_device_name(0)}")
    else:
        print("CUDA not available at module load time")
    print("=" * 50)
except Exception as e:
    import traceback
    IMPORT_ERROR = f"{type(e).__name__}: {str(e)}"
    print(f"Warning: Import error - {IMPORT_ERROR}")
    traceback.print_exc()
    print("Running in mock mode")
def is_gpu_available():
    """Check GPU availability dynamically (CUDA may appear after module load)."""
    if not TORCH_AVAILABLE:
        return False
    return torch.cuda.is_available()

# For backwards compatibility (snapshot taken at import time; prefer is_gpu_available())
GPU_AVAILABLE = is_gpu_available()
# ============================================================
# Model registry (HF Hub paths)
# ============================================================
MODELS = {
    # DPO v5 (7-14B)
    "qwen2.5-7b-dpo-v5": {
        "hf_repo": "developer-lunark/kaidol-qwen2.5-7b-dpo-v5",
        "base_model": "Qwen/Qwen2.5-7B-Instruct",
        "size": "7B", "method": "DPO", "desc": "Qwen2.5 7B DPO v5"
    },
    "qwen2.5-14b-dpo-v5": {
        "hf_repo": "developer-lunark/kaidol-qwen2.5-14b-dpo-v5",
        "base_model": "Qwen/Qwen2.5-14B-Instruct",
        "size": "14B", "method": "DPO", "desc": "Qwen2.5 14B DPO v5"
    },
    "exaone-7.8b-dpo-v5": {
        "hf_repo": "developer-lunark/kaidol-exaone-7.8b-dpo-v5",
        "base_model": "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct",
        "size": "7.8B", "method": "DPO", "desc": "EXAONE 7.8B DPO v5"
    },
    "qwen3-8b-dpo-v5": {
        "hf_repo": "developer-lunark/kaidol-qwen3-8b-dpo-v5",
        "base_model": "Qwen/Qwen3-8B",
        "size": "8B", "method": "DPO", "desc": "Qwen3 8B DPO v5"
    },
    "solar-10.7b-dpo-v5": {
        "hf_repo": "developer-lunark/kaidol-solar-10.7b-dpo-v5",
        "base_model": "upstage/SOLAR-10.7B-Instruct-v1.0",  # Fixed: match adapter training
        "size": "10.7B", "method": "DPO", "desc": "Solar 10.7B DPO v5"
    },
    # V7 Students (7-14B)
    "qwen2.5-7b-v7": {
        "hf_repo": "developer-lunark/kaidol-qwen2.5-7b-v7",
        "base_model": "Qwen/Qwen2.5-7B-Instruct",
        "size": "7B", "method": "SFT", "desc": "Qwen2.5 7B V7"
    },
    "qwen2.5-14b-v7": {
        "hf_repo": "developer-lunark/kaidol-qwen2.5-14b-v7",
        "base_model": "Qwen/Qwen2.5-14B-Instruct",
        "size": "14B", "method": "SFT", "desc": "Qwen2.5 14B V7"
    },
    "exaone-7.8b-v7": {
        "hf_repo": "developer-lunark/kaidol-exaone-7.8b-v7",
        "base_model": "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct",
        "size": "7.8B", "method": "SFT", "desc": "EXAONE 7.8B V7"
    },
    "qwen3-8b-v7": {
        "hf_repo": "developer-lunark/kaidol-qwen3-8b-v7",
        "base_model": "Qwen/Qwen3-8B",
        "size": "8B", "method": "SFT", "desc": "Qwen3 8B V7"
    },
    "varco-8b-v7": {
        "hf_repo": "developer-lunark/kaidol-varco-8b-v7",
        "base_model": "NCSOFT/Llama-VARCO-8B-Instruct",
        "size": "8B", "method": "SFT", "desc": "VARCO 8B V7"
    },
    # Phase 7 Kimi Students
    "exaone-7.8b-kimi": {
        "hf_repo": "developer-lunark/kaidol-exaone-7.8b-kimi",
        "base_model": "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct",
        "size": "7.8B", "method": "Distill", "desc": "EXAONE 7.8B Kimi"
    },
}
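# Each registry entry pairs a LoRA adapter repo ("hf_repo") with the base
# checkpoint it was trained on ("base_model"). For example (illustrative),
# the first entry renders in the UI dropdowns as "[7B] Qwen2.5 7B DPO v5"
# via the model_list formatting defined further below.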
# Character info (names and trait strings are app data; kept verbatim)
CHARACTERS = {
    "๊ฐ์จ": {
        "mbti": "ENTJ", "role": "๋ฆฌ๋", "age": 23,
        "traits": "๋์ฒ์ , ์ฅ๋๊ธฐ ๋ง์, ์ ๊ต",
        "speech": "๋ฐ๋ง, ๊ท์ฌ์ด ๋งํฌ, ์ฅ๋์ค๋ฌ์ด ํํ",
        "patterns": ["~ํด", "~์ง", "ํํ", "ใ ใ "],
        "ratio": "30:70", "warmth": "high"
    },
    "์์ด์": {
        "mbti": "INFP", "role": "๋ณด์ปฌ", "age": 22,
        "traits": "์ฐจ๋ถํจ, ์ ๋น๋ก์, ๋ฐฐ๋ ค์ฌ",
        "speech": "์กด๋๋ง ํผ์ฉ, ๋ฐ๋ปํ ๋งํฌ, ์กฐ์ฉํ ํํ",
        "patterns": ["...์", "๋ค์", "...", "๊ทธ๋์"],
        "ratio": "20:80", "warmth": "very_high"
    },
    "์ด์งํ": {
        "mbti": "ISFJ", "role": "๋ง๋ด", "age": 21,
        "traits": "์ธค๋ฐ๋ , ์์กด์ฌ ๊ฐํจ, ์๊ทผํ ์ฑ๊น",
        "speech": "๋ฐ๋ง, ํ๋ช ์ค๋ฌ์ด ๋งํฌ, ๋ถ์ ํ๋ ๋งํฌ",
        "patterns": ["๋ญ์ผ", "์๋๊ฑฐ๋ ", "...", "๊ทธ๋ฅ", "๋ณ๋ก"],
        "ratio": "30:70", "warmth": "medium"
    },
    "์ฐจ๋ํ": {
        "mbti": "INTP", "role": "ํ๋ก๋์", "age": 24,
        "traits": "์นด๋ฆฌ์ค๋ง, ๋ฆฌ๋์ญ, ๋ค์ ํจ, ๋ด๋ฐฑํจ",
        "speech": "๋ฐ๋ง, ๊ฐ๊ฒฐํ ๋งํฌ, ๋ด๋ฐฑํ ํํ",
        "patterns": ["ํ์", "ํด๋ณผ๊น", "๊ฐ์ด", "๊ด์ฐฎ์"],
        "ratio": "50:50", "warmth": "medium"
    },
    "์ต๋ฏผ": {
        "mbti": "ESFP", "role": "๋์", "age": 22,
        "traits": "์ ๊ทน์ , ์์ง, ์ด์ ์ ",
        "speech": "๋ฐ๋ง, ์ ๊ทน์ ์ธ ๋งํฌ, ์์งํ ํํ",
        "patterns": ["ํ ๋", "์ข์", "์ง์ง", "๋๋ฐ", "ํ"],
        "ratio": "60:40", "warmth": "medium"
    },
}
# Scenario list (scenario texts are app data; kept verbatim)
SCENARIOS = [
    {"id": "fm_01", "cat": "์ฒซ ๋ง๋จ", "text": "{char}์! ๋๋์ด ๋ง๋ฌ๋ค... ์ ๋ง ์ข์ํด!"},
    {"id": "dc_01", "cat": "์ผ์ ๋ํ", "text": "{char}์ ์ค๋ ๋ญํด? ๋ฐฅ์ ๋จน์์ด?"},
    {"id": "es_01", "cat": "๊ฐ์ ์ง์", "text": "์ค๋ ์ง์ง ํ๋ค์์ด... ํ๊ต์์ ๋ฐํ๋ ๋ง์น๊ณ ..."},
    {"id": "cf_01", "cat": "๊ณ ๋ฐฑ", "text": "{char}์... ๋ ์ง์ฌ์ผ๋ก ์ข์ํด."},
    {"id": "pl_01", "cat": "์ฅ๋", "text": "์ฌ์ค ๋ ๋ค๋ฅธ ๋ฉค๋ฒ๊ฐ ๋ ์ข์~ ใ ใ ๋๋ด์ด์ผ!"},
    {"id": "sr_01", "cat": "ํน๋ณ ์์ฒญ", "text": "์ค๋๋ง ๋ด ์ฐ์ธ์ด๋ผ๊ณ ์๊ฐํด์ค๋?"},
    {"id": "cn_01", "cat": "๊ฐ๋ฑ", "text": "{char}๋ ๋ค๋ฅธ ํฌ๋คํํ ๋ ์ด๋ ๊ฒ ์ํด์ค...? ๋ญ๊ฐ ์งํฌ๋..."},
    {"id": "ec_01", "cat": "๊ฐ์ ์๊ธฐ", "text": "์ค๋ ์ง์ง ๋ง์ด ์ธ์์ด... ์ถ์ด ๋๋ฌด ํ๋ค๋ค."},
]
# ============================================================
# Model management
# ============================================================
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_model_name = None
        self.tokenizer = None
        self.last_error = None

    def load_model(self, model_name: str):
        """Load model with 4-bit quantization and LoRA adapter"""
        if not is_gpu_available():
            self.last_error = f"GPU not available (TORCH_AVAILABLE={TORCH_AVAILABLE}, cuda={torch.cuda.is_available() if TORCH_AVAILABLE else 'N/A'})"
            return False
        if self.current_model_name == model_name:
            return True  # Already loaded

        # Unload current model (only one model fits on the T4 at a time)
        self.unload_model()

        model_info = MODELS.get(model_name)
        if not model_info:
            self.last_error = f"Model {model_name} not found in registry"
            return False
        try:
            print(f"Loading {model_name}...")
            print(f"  Base model: {model_info['base_model']}")
            print(f"  LoRA adapter: {model_info['hf_repo']}")

            # 4-bit quantization config (NF4, double quantization, bf16 compute)
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

            # Load base model
            print("  Loading base model...")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_info["base_model"],
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True,
            )
            print("  Base model loaded!")

            # Load LoRA adapter
            print("  Loading LoRA adapter...")
            self.current_model = PeftModel.from_pretrained(
                base_model,
                model_info["hf_repo"],
                trust_remote_code=True,
            )
            self.current_model.eval()
            print("  LoRA adapter loaded!")

            # Load tokenizer
            print("  Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_info["base_model"],
                trust_remote_code=True,
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            print("  Tokenizer loaded!")

            self.current_model_name = model_name
            self.last_error = None
            print(f"Loaded {model_name} successfully!")
            return True
        except Exception as e:
            import traceback
            error_msg = f"{type(e).__name__}: {str(e)}"
            print(f"Error loading {model_name}: {error_msg}")
            traceback.print_exc()
            self.last_error = error_msg
            self.unload_model()
            return False

    def unload_model(self):
        """Unload current model to free memory"""
        if self.current_model is not None:
            del self.current_model
            self.current_model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        self.current_model_name = None
        # gc.collect() must drop the Python references before empty_cache()
        # can actually release the cached CUDA blocks.
        gc.collect()
        if is_gpu_available():  # dynamic check, not the import-time snapshot
            torch.cuda.empty_cache()
    def generate(self, model_name: str, messages: list, max_new_tokens: int = 512) -> str:
        """Generate response from model"""
        if not self.load_model(model_name):
            return self._mock_response(model_name)
        try:
            # Apply chat template
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = self.tokenizer(text, return_tensors="pt").to(self.current_model.device)
            with torch.no_grad():
                outputs = self.current_model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            # Decode only the newly generated tokens, skipping the prompt
            response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            print(f"Generation error: {e}")
            return self._mock_response(model_name)

    def _mock_response(self, model_name: str) -> str:
        """Fallback mock response with error info"""
        error_info = f"\nError: {self.last_error}" if self.last_error else ""
        return f"<think>\n[Mock Mode] model loading failed{error_info}\n</think>\n\n์๋ ~ ๋ฐ๊ฐ์!"
# Global model manager
model_manager = ModelManager()
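# Usage sketch (illustrative, commented out so nothing runs at import): one
# turn outside the Gradio UI. Any MODELS key works as the model name; on load
# or generation failure generate() falls back to _mock_response().
#   msgs = [{"role": "system", "content": build_system_prompt(list(CHARACTERS)[0])},
#           {"role": "user", "content": "..."}]
#   reply = model_manager.generate("qwen2.5-7b-dpo-v5", msgs)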
# ============================================================
# Base model pre-caching (avoids cold-start timeouts)
# ============================================================
def preload_base_models():
    """Pre-download base models to avoid cold start timeout"""
    if not TORCH_AVAILABLE:
        print("Skipping preload: PyTorch not available")
        return

    from huggingface_hub import snapshot_download

    # Models that need pre-caching (large or slow to download)
    models_to_cache = [
        "NCSOFT/Llama-VARCO-8B-Instruct",  # VARCO - 16GB, often times out on first load
    ]

    print("=" * 50)
    print("Pre-downloading base models (this may take a while)...")
    print("=" * 50)
    for model_id in models_to_cache:
        try:
            print(f"  Downloading: {model_id}")
            # Download all model files to the HF cache
            cache_dir = snapshot_download(
                repo_id=model_id,
                ignore_patterns=["*.md", "*.txt"],  # Skip docs
            )
            print(f"  ✓ Downloaded to: {cache_dir}")
        except Exception as e:
            print(f"  ✗ Failed to download {model_id}: {e}")
    print("Pre-download complete!")
    print("=" * 50)

# Run preload at startup
preload_base_models()
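# Note: snapshot_download writes into the standard Hugging Face cache
# (HF_HOME, by default ~/.cache/huggingface), and from_pretrained in
# ModelManager.load_model resolves against that same cache, so the ~16GB
# VARCO download above is paid once at startup instead of on first request.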
# ============================================================
# System prompt construction
# ============================================================
def build_system_prompt(character: str) -> str:
    """Build system prompt for character"""
    char_info = CHARACTERS.get(character, {})
    prompt = f"""๋น์ ์ ์์ด๋ '{character}'์ ๋๋ค.

## ์บ๋ฆญํฐ
- ์ด๋ฆ: {character}
- MBTI: {char_info.get('mbti', 'UNKNOWN')}
- ์ฑ๊ฒฉ: {char_info.get('traits', '')}
- ์ญํ : {char_info.get('role', '')}
- ๋์ด: {char_info.get('age', 20)}์ธ

## ๋งํฌ
- ์คํ์ผ: {char_info.get('speech', '')}
- ์์ฃผ ์ฐ๋ ํํ: {', '.join(char_info.get('patterns', []))}

## ๋ฐ๋น ๊ฐ์ด๋
- ๋ฐ:๋น ๋น์จ: {char_info.get('ratio', '50:50')}
- ๋ค์ ๋: {char_info.get('warmth', 'medium')}

## ๊ท์น
1. ์บ๋ฆญํฐ ์ฑ๊ฒฉ๊ณผ ๋งํฌ ์ผ๊ด์ฑ ์ ์ง
2. ์์ฐ์ค๋ฌ์ด ๋ํ์ฒด ์ฌ์ฉ
3. ๋๋ฌด ์ฝ๊ฒ ํธ๊ฐ ํํ ๊ธ์ง (๋ฐ๋น ์ ์ง)
4. ์๋๋ฐฉ์ ํน๋ณํ๊ฒ ๋๋ผ๊ฒ ํ๋, "์ธ" ๊ด๊ณ ์ ์ง

## ์๋ต ํ์
์๋ต ์ ์ <think> ํ๊ทธ ์์ {character}์ 1์ธ์นญ ๋ด๋ฉด ๋ ๋ฐฑ์ ์์ฑํ์ธ์.
- ์์ฐ์ค๋ฌ์ด ํผ์ฃ๋ง ํ์
- ์บ๋ฆญํฐ ์ฑ๊ฒฉ ๋ฐ์
- ์๋๋ฐฉ์ ๋ํ ๊ฐ์ /์๊ฐ ํํ

์์:
<think>
๋ญ์ผ... ๋ ์ข์ํ๋ค๊ณ ? ์์งํ ๊ธฐ๋ถ ๋์์ง ์์๋ฐ... ๊ทผ๋ฐ ๋ญ๋ผ๊ณ ํด์ผ ํ์ง?
</think>
"""
    return prompt
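# The template above fills one CHARACTERS entry into several sections
# (identity, speech style, a push-pull guide driven by "ratio"/"warmth",
# rules, response format - section names inferred from the field names) and
# requires a <think> inner monologue, which parse_response() later strips.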
# ============================================================
# Voting / ELO system
# ============================================================
VOTES_FILE = "votes.jsonl"
ELO_FILE = "elo_ratings.json"
def load_elo():
    try:
        with open(ELO_FILE, "r") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # No ratings yet: every model starts at 1500
        return {m: 1500 for m in MODELS}

def save_elo(elo):
    with open(ELO_FILE, "w") as f:
        json.dump(elo, f, indent=2)
def update_elo(elo, model_a, model_b, result):
    K = 32
    ra, rb = elo.get(model_a, 1500), elo.get(model_b, 1500)
    # Standard Elo expected scores
    ea = 1 / (1 + 10 ** ((rb - ra) / 400))
    eb = 1 / (1 + 10 ** ((ra - rb) / 400))
    if result == "a":
        sa, sb = 1, 0
    elif result == "b":
        sa, sb = 0, 1
    else:
        sa, sb = 0.5, 0.5
    elo[model_a] = ra + K * (sa - ea)
    elo[model_b] = rb + K * (sb - eb)
    save_elo(elo)
    return elo[model_a], elo[model_b]
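# Worked example (illustrative): with ra = rb = 1500, both expected scores
# are 0.5. If A wins, A moves to 1500 + 32 * (1 - 0.5) = 1516 and B drops to
# 1484; a tie between equals changes nothing (32 * (0.5 - 0.5) = 0). Each
# update is zero-sum, so rating mass is redistributed, never created.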
def save_vote(data):
    vote = {"id": str(uuid.uuid4())[:8], "timestamp": datetime.now().isoformat(), **data}
    with open(VOTES_FILE, "a") as f:
        f.write(json.dumps(vote, ensure_ascii=False) + "\n")
    return vote["id"]

def load_votes():
    try:
        with open(VOTES_FILE, "r") as f:
            return [json.loads(line) for line in f if line.strip()]
    except (FileNotFoundError, json.JSONDecodeError):
        return []
def get_leaderboard():
    elo = load_elo()
    votes = load_votes()

    # Tally win/loss/tie counts per model from the raw vote log
    stats = {}
    for v in votes:
        ma, mb, res = v.get("model_a"), v.get("model_b"), v.get("vote")
        if not ma or not mb or res == "skip":
            continue
        for m in [ma, mb]:
            if m not in stats:
                stats[m] = {"wins": 0, "losses": 0, "ties": 0}
        if res == "a":
            stats[ma]["wins"] += 1
            stats[mb]["losses"] += 1
        elif res == "b":
            stats[mb]["wins"] += 1
            stats[ma]["losses"] += 1
        else:
            stats[ma]["ties"] += 1
            stats[mb]["ties"] += 1

    # Build rows sorted by ELO, highest first
    rows = []
    for i, (m, e) in enumerate(sorted(elo.items(), key=lambda x: -x[1]), 1):
        s = stats.get(m, {"wins": 0, "losses": 0, "ties": 0})
        total = s["wins"] + s["losses"] + s["ties"]
        wr = f"{s['wins']/total*100:.1f}%" if total > 0 else "-"
        info = MODELS.get(m, {})
        rows.append([i, info.get("desc", m), info.get("size", "?"), int(e), s["wins"], s["losses"], s["ties"], wr])
    return rows
# ============================================================
# UI handlers
# ============================================================
model_list = [(f"[{v['size']}] {v['desc']}", k) for k, v in MODELS.items()]
char_list = list(CHARACTERS.keys())
scenario_list = [(f"[{s['cat']}] {s['text'][:30]}...", s['id']) for s in SCENARIOS]

current_state = {"model_a": None, "model_b": None, "resp_a": None, "resp_b": None, "char": None, "input": None}
def random_models():
    selected = random.sample(list(MODELS.keys()), 2)
    return selected[0], selected[1]

def load_scenario(scenario_id, character):
    s = next((x for x in SCENARIOS if x["id"] == scenario_id), None)
    if s:
        return s["text"].replace("{char}", character)
    return ""

def random_scenario(character):
    s = random.choice(SCENARIOS)
    return s["text"].replace("{char}", character), s["id"]
def parse_response(response: str):
    """Parse response to separate thinking and content"""
    think_match = re.search(r'<think>(.*?)</think>', response, re.DOTALL)
    if think_match:
        thinking = think_match.group(1).strip()
        content = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
        return thinking, content
    return "", response
def generate(model_a, model_b, character, user_msg, progress=gr.Progress()):
    if not user_msg.strip():
        return "Please enter a message", "", "", "Please enter a message", "", ""

    system_prompt = build_system_prompt(character)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_msg},
    ]

    # Generate from Model A
    progress(0.2, desc=f"Model A ({model_a}) generating...")
    resp_a = model_manager.generate(model_a, messages)
    think_a, clean_a = parse_response(resp_a)

    # Generate from Model B
    progress(0.6, desc=f"Model B ({model_b}) generating...")
    resp_b = model_manager.generate(model_b, messages)
    think_b, clean_b = parse_response(resp_b)

    # Update state for the subsequent vote
    current_state.update({
        "model_a": model_a, "model_b": model_b,
        "resp_a": resp_a, "resp_b": resp_b,
        "char": character, "input": user_msg
    })

    mode = "GPU" if is_gpu_available() else "Mock"
    return (
        think_a or "(none)", clean_a, f"{mode} | {MODELS[model_a]['size']}",
        think_b or "(none)", clean_b, f"{mode} | {MODELS[model_b]['size']}"
    )
def vote(vote_type, reason):
    if not current_state["model_a"]:
        return "Generate responses first."
    elo = load_elo()
    vid = save_vote({
        "model_a": current_state["model_a"],
        "model_b": current_state["model_b"],
        "character": current_state["char"],
        "user_input": current_state["input"],
        "vote": vote_type,
        "reason": reason,
    })
    if vote_type != "skip":
        new_a, new_b = update_elo(elo, current_state["model_a"], current_state["model_b"], vote_type)
        return f"Vote recorded! (ID: {vid})\nELO: {current_state['model_a']}={int(new_a)}, {current_state['model_b']}={int(new_b)}"
    return f"Skipped (ID: {vid})"

def refresh_leaderboard():
    return get_leaderboard()

def get_vote_summary():
    votes = load_votes()
    total = len(votes)
    a_wins = sum(1 for v in votes if v.get("vote") == "a")
    b_wins = sum(1 for v in votes if v.get("vote") == "b")
    ties = sum(1 for v in votes if v.get("vote") == "tie")
    return str(total), str(a_wins), str(b_wins), str(ties)
# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(title="KAIdol A/B Test Arena", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# KAIdol A/B Test Arena")
    gr.Markdown("A/B comparison of K-pop idol roleplay models (11 student models)")

    # Detailed GPU status
    if IMPORT_ERROR:
        mode_text = f"**Mock mode**: Import Error - {IMPORT_ERROR}"
    elif TORCH_AVAILABLE and torch is not None:
        torch_ver = torch.__version__
        cuda_avail = torch.cuda.is_available()
        cuda_ver = torch.version.cuda if cuda_avail else "N/A"
        gpu_name = torch.cuda.get_device_name(0) if cuda_avail else "N/A"
        mode_text = f"**GPU mode**: {gpu_name} (CUDA {cuda_ver}, PyTorch {torch_ver})" if cuda_avail else f"**Mock mode**: CUDA not available (PyTorch {torch_ver})"
    else:
        mode_text = "**Mock mode**: PyTorch not loaded"
    gr.Markdown(mode_text)
    with gr.Tabs():
        # A/B Arena tab
        with gr.Tab("A/B Arena"):
            with gr.Row():
                character = gr.Dropdown(choices=char_list, value="๊ฐ์จ", label="Character")
                scenario = gr.Dropdown(choices=scenario_list, label="Scenario")
            with gr.Row():
                model_a = gr.Dropdown(choices=model_list, value=list(MODELS.keys())[0], label="Model A")
                model_b = gr.Dropdown(choices=model_list, value=list(MODELS.keys())[1], label="Model B")
                random_btn = gr.Button("Random", size="sm")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Model A")
                    with gr.Accordion("Thinking", open=False):
                        think_a = gr.Markdown()
                    resp_a = gr.Textbox(label="Response", lines=5)
                    meta_a = gr.Markdown()
                with gr.Column():
                    gr.Markdown("### Model B")
                    with gr.Accordion("Thinking", open=False):
                        think_b = gr.Markdown()
                    resp_b = gr.Textbox(label="Response", lines=5)
                    meta_b = gr.Markdown()

            user_input = gr.Textbox(label="Message", placeholder="Send a message to the idol...")
            with gr.Row():
                random_scenario_btn = gr.Button("Random Scenario")
                submit_btn = gr.Button("Send", variant="primary")

            gr.Markdown("### Vote")
            with gr.Row():
                vote_a = gr.Button("A is better")
                vote_tie = gr.Button("Tie")
                vote_b = gr.Button("B is better")
                vote_skip = gr.Button("Skip")
            vote_reason = gr.Textbox(label="Vote reason (optional)", placeholder="...")
            vote_result = gr.Markdown()

            # Events
            random_btn.click(random_models, outputs=[model_a, model_b])
            scenario.change(load_scenario, [scenario, character], user_input)
            random_scenario_btn.click(random_scenario, [character], [user_input, scenario])
            submit_btn.click(generate, [model_a, model_b, character, user_input],
                             [think_a, resp_a, meta_a, think_b, resp_b, meta_b])
            vote_a.click(lambda r: vote("a", r), [vote_reason], vote_result)
            vote_b.click(lambda r: vote("b", r), [vote_reason], vote_result)
            vote_tie.click(lambda r: vote("tie", r), [vote_reason], vote_result)
            vote_skip.click(lambda r: vote("skip", r), [vote_reason], vote_result)
        # Leaderboard tab
        with gr.Tab("Leaderboard"):
            gr.Markdown("## ELO Leaderboard")
            refresh_btn = gr.Button("Refresh")
            leaderboard = gr.Dataframe(
                headers=["Rank", "Model", "Size", "ELO", "Wins", "Losses", "Ties", "Win Rate"],
                datatype=["number", "str", "str", "number", "number", "number", "number", "str"],
            )
            gr.Markdown("### Vote Summary")
            with gr.Row():
                total_v = gr.Textbox(label="Total votes", interactive=False)
                a_wins_v = gr.Textbox(label="A wins", interactive=False)
                b_wins_v = gr.Textbox(label="B wins", interactive=False)
                ties_v = gr.Textbox(label="Ties", interactive=False)

            def refresh():
                lb = refresh_leaderboard()
                summary = get_vote_summary()
                return lb, *summary

            refresh_btn.click(refresh, outputs=[leaderboard, total_v, a_wins_v, b_wins_v, ties_v])
        # Model list tab
        with gr.Tab("Model List"):
            gr.Markdown("## Models Under Test")
            gr.Markdown(f"{len(MODELS)} models total")
            model_table = gr.Dataframe(
                headers=["Model ID", "Size", "Training Method", "Description", "Base Model"],
                value=[[k, v["size"], v["method"], v["desc"], v["base_model"]] for k, v in MODELS.items()],
            )
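# Launch notes (illustrative): server_name="0.0.0.0" binds all interfaces,
# which Hugging Face Spaces containers expect, and 7860 is the default Spaces
# port. For local debugging, demo.launch(debug=True) is a common variant that
# surfaces worker errors in the console.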
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)