| """ | |
| Finnish Dental QA v3 - Optimized for Hourly GPU Billing | |
| This version is optimized for regular GPU Spaces (T4, L4, A100) that charge by the hour. | |
| The model stays on GPU throughout the session for faster responses since billing is | |
| for the full hour regardless. | |
| v3 CHANGES FROM v2: | |
| - Model: Finnish-DentalQA-v3 (full fine-tuned on Ahma-2-4B-Instruct) | |
| - Precision: BF16 (required for Ahma-2/Gemma-based architecture) | |
| - Parameters: 4B (vs 3B in v2) | |
| - Max new tokens: 800 (up from 600) | |
| - No PEFT/LoRA - this is a fully merged model | |

IMPORTANT NOTE ABOUT ZEROGPU:
As of late 2024, ZeroGPU has compatibility issues with Gradio's ChatInterface that cause
"LookupError: progress context variable" errors. These appear to be infrastructure-level
incompatibilities between ZeroGPU's multiprocessing system and ChatInterface's progress
handling. Regular GPU Spaces work reliably without these issues.

REQUIREMENTS.TXT NOTE:
A requirements.txt file is needed that lists the core dependencies (transformers, torch,
accelerate, gradio), but Hugging Face Spaces often ignores version constraints and installs
whatever versions it considers compatible with its infrastructure. Version pins such as
"gradio>=4.44.1,<5" may be overridden. This is generally fine as long as the core
libraries are present.
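
For reference, a minimal requirements.txt matching the imports in this file could look
like the sketch below (the gradio pin is the example from above; whether the spaces
package needs to be listed depends on the Space runtime, so treat that line as an assumption):

    transformers
    torch
    accelerate
    gradio>=4.44.1,<5
    spaces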

PERSISTENT STORAGE SETUP:
To avoid re-downloading the ~9GB model on every restart:
1. Enable persistent storage in Space settings (Small/20GB recommended)
2. Add the environment variable HF_HOME = /data/.huggingface
3. This makes transformers cache models in persistent storage instead of a temp directory
4. Without this variable, models download fresh on every restart even with persistent storage

TO SWITCH BACK TO ZEROGPU LATER (when compatibility is fixed; see the sketch below):
1. Change model loading: .to("cpu") instead of .to("cuda")
2. Add back the @spaces.GPU(duration=120) decorator on the respond function
3. Add back model.to("cuda") at the start of the respond function
4. Add back model.to("cpu") in the finally block
5. Add back the torch.cuda.empty_cache() calls
6. Change hardware to ZeroGPU in Space settings
7. Remove the progress parameter from the respond function signature (known issue)
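
A rough sketch of that pattern, for orientation only (it mirrors the respond function
defined below and the duration suggested above; it is not the code path currently in use):

    @spaces.GPU(duration=120)
    def respond(message, history):
        model.to("cuda")
        try:
            ...  # tokenize, generate and stream exactly as in the current respond()
        finally:
            model.to("cpu")
            torch.cuda.empty_cache()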

Current setup: the model loads directly to GPU and stays there for optimal performance
under the hourly billing model.
"""
import os

os.environ["OMP_NUM_THREADS"] = "1"  # Suppress libgomp warning

import gradio as gr
import spaces  # Not used in hourly-GPU mode; kept so switching back to ZeroGPU only needs the decorator
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import threading
import time
import gc

# CUDA optimizations for better performance on T4/L4/A100
# Remove or modify these if switching to CPU-based hardware or ZeroGPU
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# ---------------- Configuration ----------------
# Key variables - adjust these based on your hardware and requirements
MODEL_MAX_CONTEXT = 4096  # Extended beyond the 2048-token training context (Ahma-2 supports up to 128K)
GEN_MAX_NEW = 800         # Max new tokens per response (auto-adjusts if > 30% of context)
CONCURRENCY_LIMIT = 2     # Simultaneous users (T4: 2, L4: 3-4, A100: 8+; adjust per hardware)

# Auto-adjust generation length for different model sizes
if GEN_MAX_NEW > MODEL_MAX_CONTEXT * 0.3:  # If > 30% of context, scale down
    GEN_MAX_NEW = int(MODEL_MAX_CONTEXT * 0.3)
    print(f"Auto-adjusted GEN_MAX_NEW to {GEN_MAX_NEW} tokens (30% of {MODEL_MAX_CONTEXT} context)")

# Calculate a dynamic safety buffer based on context size
SAFETY_BUFFER = max(16, MODEL_MAX_CONTEXT // 128)  # 16 tokens minimum, scales with context size
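
# Worked example with the defaults above (assuming the model reports a context of at least
# 4096 tokens): GEN_MAX_NEW=800 is below 30% of 4096 (~1229), so it is not scaled down;
# SAFETY_BUFFER = max(16, 4096 // 128) = 32; and get_dynamic_input_max() below yields an
# input budget of 4096 - 800 - 32 = 3264 tokens.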

# ---------------- Model ----------------
# Load the model directly to GPU for hourly billing efficiency
# v3 uses BF16 (required for the Ahma-2/Gemma-based architecture)
model = AutoModelForCausalLM.from_pretrained(
    "ducklingcodehouse/Finnish-DentalQA-v3",
    torch_dtype=torch.bfloat16
).to("cuda")  # Keep on GPU since we're paying hourly
tokenizer = AutoTokenizer.from_pretrained("ducklingcodehouse/Finnish-DentalQA-v3")

system_prompt = """Olet kokenut suomalainen hammaslääkäri. Vastaat ammattimaisesti kollegojesi
kysymyksiin käyttäen oikeaa hammaslääketieteellistä terminologiaa ja viittaat Käypä hoito
-suosituksiin kun relevanttia."""

# ---------------- Helpers ----------------
def count_tokens_estimate(text: str) -> int:
    return len(text) // 4  # Finnish-ish heuristic (~4 characters per token)


def get_dynamic_input_max(model, max_new=GEN_MAX_NEW, total_context=MODEL_MAX_CONTEXT):
    """Get dynamic context limits with error handling."""
    try:
        # Try to read the model's actual context size
        model_ctx = getattr(model.config, "max_position_embeddings", None)
        if not isinstance(model_ctx, int) or model_ctx <= 0 or model_ctx > 128_000:
            model_ctx = int(getattr(tokenizer, "model_max_length", total_context))
        # Use the configured context size, but validate it against the model's limit
        effective_context = min(total_context, model_ctx)
        input_max = max(512, effective_context - max_new - SAFETY_BUFFER)
        if effective_context != total_context:
            print(f"Using model's max context {effective_context} instead of configured {total_context}")
        return input_max
    except Exception:
        # Fallback calculation
        return max(512, total_context - max_new - SAFETY_BUFFER)


def build_messages_with_budget(history, new_message, input_max):
    """
    Build a message list that fits within the input budget by keeping the newest complete exchanges.
    Uses the full calculated budget efficiently, without arbitrary sub-limits.
    """
    # Calculate required tokens for the fixed components
    system_tokens = count_tokens_estimate(system_prompt)
    current_tokens = count_tokens_estimate(new_message)
    # Single safety margin - scales with model size
    available_for_history = input_max - system_tokens - current_tokens - SAFETY_BUFFER
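    # Rough example using the 4-chars-per-token heuristic: with an input budget of 3264
    # tokens, the ~200-character system prompt (~50 tokens) and an 800-character question
    # (~200 tokens), about 3264 - 50 - 200 - 32 = 2982 tokens remain for history.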
    # Always include the system prompt and the current message
    system_msg = {"role": "system", "content": system_prompt}
    current_msg = {"role": "user", "content": new_message}
    # If there is no room for history, return the minimal viable context
    if available_for_history <= 0:
        return [system_msg, current_msg]
    # Group history into complete user-assistant pairs
    exchanges = []
    i = 0
    while i < len(history):
        if history[i].get("role") == "user":
            user_msg = history[i]
            assistant_msg = None
            # Look for the corresponding assistant message
            if i + 1 < len(history) and history[i + 1].get("role") == "assistant":
                assistant_msg = history[i + 1]
                i += 2
            else:
                i += 1
            exchanges.append((user_msg, assistant_msg))
        else:
            i += 1
    # Add exchanges from newest to oldest using the full available budget
    kept_exchanges = []
    used_tokens = 0
    for user_msg, assistant_msg in reversed(exchanges):
        # Calculate tokens for this complete exchange
        exchange_tokens = count_tokens_estimate(user_msg.get("content", ""))
        if assistant_msg:
            exchange_tokens += count_tokens_estimate(assistant_msg.get("content", ""))
        # Check whether this exchange fits within the remaining history budget
        if used_tokens + exchange_tokens > available_for_history:
            break  # Stop - this exchange would overflow the history budget
        # Keep this exchange
        kept_exchanges.insert(0, (user_msg, assistant_msg))
        used_tokens += exchange_tokens
    # Build the final message list
    messages = [system_msg]
    # Add kept exchanges in chronological order
    for user_msg, assistant_msg in kept_exchanges:
        messages.append(user_msg)
        if assistant_msg:
            messages.append(assistant_msg)
    # Add the current message
    messages.append(current_msg)
    return messages


def safe_tokenize_with_budget(msgs, input_max, fallback_message):
    """
    Tokenize messages with proper budget validation and smart fallbacks.
    Fallbacks use progressively less of the available context rather than arbitrary fixed limits.
    """
    try:
        # Primary strategy: use the carefully constructed message list
        enc = tokenizer.apply_chat_template(
            msgs,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=False,  # Trimming is handled at the message level
            return_attention_mask=True
        )
        # Validate the resulting length
        if isinstance(enc, dict):
            input_ids = enc["input_ids"]
        else:
            input_ids = enc
        if input_ids.shape[1] <= input_max:
            return enc
        else:
            raise Exception(f"Budget-trimmed messages still {input_ids.shape[1]} tokens > {input_max} limit")
    except Exception as e:
        print(f"Primary tokenization failed: {e}")
        # Fallback 1: system prompt + current message only (uses ~80% of the available space)
        try:
            system_tokens = count_tokens_estimate(system_prompt)
            # Use 80% of the remaining space for the current message to leave room for formatting
            available_for_current = int((input_max - system_tokens - SAFETY_BUFFER) * 0.8)
            max_chars_for_current = available_for_current * 4  # Convert back to characters
            trimmed_message = fallback_message[:max_chars_for_current]
            fallback_msgs = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": trimmed_message}
            ]
            enc = tokenizer.apply_chat_template(
                fallback_msgs,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                padding=True,
                truncation=True,  # Safe truncation as the final backstop
                max_length=input_max,
                return_attention_mask=True
            )
            print(f"Using fallback tokenization (system + {len(trimmed_message)} chars of current message)")
            return enc
        except Exception as e2:
            print(f"Fallback tokenization failed: {e2}")
            # Emergency fallback: minimal message using 50% of the available space
            try:
                emergency_tokens = input_max // 2  # Use half the available context
                emergency_chars = emergency_tokens * 4  # Convert to characters
                emergency_message = fallback_message[:emergency_chars]
                emergency_msgs = [{"role": "user", "content": emergency_message}]
                enc = tokenizer.apply_chat_template(
                    emergency_msgs,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=emergency_tokens,
                    return_attention_mask=True
                )
                print(f"Using emergency tokenization ({emergency_chars} chars, ~{emergency_tokens} tokens)")
                return enc
            except Exception as e3:
                print(f"All tokenization strategies failed: {e3}")
                raise Exception("Unable to tokenize input - message may be too long")


def safe_generate(model, generation_kwargs):
    """Safe generation wrapper with error handling."""
    try:
        with torch.no_grad():
            model.generate(**generation_kwargs)
    except torch.cuda.OutOfMemoryError:
        print("CUDA out of memory during generation")
        # For hourly billing, keep the model on GPU but clear the cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Generation thread error: {e}")


# ---------------- Chat Function ----------------
# No @spaces.GPU decorator needed since the model stays on GPU
def respond(message, history):
    # Immediate feedback
    yield "Hetkinen..."
    try:
        # The model is already on GPU - no need to move it
        # Use the configurable max_new tokens
        max_new = GEN_MAX_NEW
        input_max = get_dynamic_input_max(model, max_new)
        # Check whether the current message itself is too long
        current_msg_tokens = count_tokens_estimate(message)
        if current_msg_tokens > input_max - 200:
            yield "Anteeksi, viestisi on liian pitkä. Yritä lyhyempää kysymystä."
            return
        # Build the message list with proper budget management (no arbitrary sub-limits)
        msgs = build_messages_with_budget(history, message, input_max)
        # Tokenize with the budget-aware approach
        try:
            enc = safe_tokenize_with_budget(msgs, input_max, message)
        except Exception:
            yield "Anteeksi, viestin käsittelyssä tapahtui virhe. Yritä lyhyempää kysymystä."
            return
        # Handle both encoding formats and move the tensors to the GPU
        if isinstance(enc, dict):
            input_ids = enc["input_ids"].to("cuda")
            attention_mask = enc["attention_mask"].to("cuda")
        else:
            input_ids = enc.to("cuda")
            attention_mask = torch.ones_like(input_ids).to("cuda")
        # Final safety check
        if input_ids.shape[1] > input_max:
            yield "Anteeksi, konteksti on liian pitkä. Aloita uusi keskustelu."
            return
        # Generation relies on the tokenizer defaults for pad/eos tokens
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'max_new_tokens': max_new,   # Uses GEN_MAX_NEW (800)
            'do_sample': False,
            'temperature': 0.1,          # Kept for reference (ignored when do_sample=False)
            'top_p': 0.9,                # Kept for reference (ignored when do_sample=False)
            'repetition_penalty': 1.2,   # Recommended by the Ahma team to prevent repetition
            'streamer': streamer,
        }
        # Start generation in a background thread with error handling
        thread = threading.Thread(
            target=safe_generate,
            args=(model, generation_kwargs)
        )
        thread.start()
        # Stream with timeout protection
        partial, last = "", 0.0
        timeout_start = time.time()
        try:
            for token in streamer:
                # Timeout protection
                if time.time() - timeout_start > 90:  # 90-second timeout (generous for the 4B model)
                    break
                if token is None:
                    continue
                partial += token
                now = time.time()
                if now - last > 0.08:
                    yield partial
                    last = now
        except Exception as e:
            print(f"Streaming error: {e}")
            partial = "Anteeksi, vastauksen generoinnissa tapahtui virhe."
        yield partial  # Final result
        # Wait for the generation thread (generous timeout to allow for slow generation)
        thread.join(timeout=45)
    except torch.cuda.OutOfMemoryError:
        yield "GPU-muisti loppui. Aloita uusi keskustelu tai yritä lyhyempää kysymystä."
        # Clear the cache but keep the model on GPU
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Critical error in respond: {e}")
        yield "Anteeksi, tapahtui odottamaton virhe. Yritä uudelleen tai aloita uusi keskustelu."
    finally:
        # Minimal cleanup - no need to move the model around,
        # just collect garbage (and optionally clear the CUDA cache)
        gc.collect()


# ---------------- UI ----------------
# Warm, earthy theme with Arial font - only valid Gradio theme parameters are used
# NOTE: In Gradio 6.x, theme and css are passed to launch() instead of Blocks()
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.orange,   # Warm orange for primary elements
    secondary_hue=gr.themes.colors.amber,  # Complementary amber for secondary elements
    neutral_hue=gr.themes.colors.stone,    # Stone/beige for a neutral, earthy feel
    font=[
        "Arial",
        "Helvetica",
        "ui-sans-serif",
        "system-ui",
        "sans-serif"
    ],  # Arial as the primary font, with fallbacks
    font_mono=[
        "Monaco",
        "Consolas",
        "ui-monospace",
        "monospace"
    ]
)

# Custom CSS - passed to launch() in Gradio 6.x
custom_css = """
.custom-textbox textarea {
    background-color: #FDF5E6 !important;
    border: 1px solid #CD853F !important;
    border-radius: 6px !important;
}
/* Style the textbox in ChatInterface */
.custom-chat-textbox input {
    background-color: #FDF5E6 !important;
    border: 1px solid #CD853F !important;
    border-radius: 6px !important;
}
/* Style the submit button to be rectangular */
.custom-chat-textbox button[type="submit"] {
    border-radius: 4px !important;
    background-color: #CD853F !important;
    color: white !important;
    border: none !important;
    padding: 8px 16px !important;
}
.custom-chat-textbox button[type="submit"]:hover {
    background-color: #B8860B !important;
}
/* Also target any button with submit styling in the textbox */
.custom-chat-textbox button {
    border-radius: 4px !important;
}
h1 {
    text-align: center !important;
    color: #8B4513 !important;
}
.centered-title {
    text-align: center !important;
    margin: 0 !important;
    padding: 0 !important;
}
"""

# Create the interface using Blocks
# NOTE: In Gradio 6.x, theme and css moved from here to launch()
with gr.Blocks(title="Finnish Dental QA v3") as demo:
    # Use a column to constrain the width
    with gr.Column(scale=1, min_width=600, elem_id="main-container"):
        with gr.Row():
            gr.Column(scale=1, min_width=50)  # Left spacer
            with gr.Column(scale=6, min_width=400):
                # Centered title using HTML
                gr.HTML("""
                <div class="centered-title">
                    <h1 style="text-align: center; margin: 0; padding: 0;">Finnish Dental QA v3</h1>
                </div>
                """)
                # Warning notice in the accordion title; details in the collapsed body
                with gr.Accordion("VAIN TUTKIMUS- JA TESTAUSKÄYTTÖÖN — EI KLIINISIÄ PÄÄTÖKSIÄ VARTEN. KÄYTTÄMÄLLÄ PALVELUA HYVÄKSYT EHDOT. LUE LISÄÄ...", open=False):
                    gr.HTML("""
                    <div style="font-family: Arial, Helvetica, sans-serif; font-size: 0.85rem; line-height: 1.35; color: #8B4513;">
                        <strong>Tarkoitus:</strong> Tämä on maksuton tekninen kokeilu Finnish Dental QA v3 -tekoälymallille.
                        Mallilla voi testata keskustelua hammaslääketieteellisistä aiheista suomeksi.<br><br>
                        <a href="https://huggingface.co/ducklingcodehouse/Finnish-DentalQA-v3" target="_blank"
                           style="color: #8B4513;"><strong>Mallin kortti.</strong></a>
                        <br><br>
                        <strong>Varoitukset:</strong><br>
                        • Malli voi antaa virheellisiä vastauksia ja hallusinoida lääketieteellisiä faktoja, hoitomenetelmiä tai lääkeaineiden yhteisvaikutuksia.<br>
                        • Mallia ei ole kliinisesti validoitu.<br>
                        • Malli on koulutettu synteettisellä aineistolla, joka mallintaa hammaslääkäreiden välisiä keskusteluja, ja aineisto voi sisältää vääristymiä tai rajoitteita.<br>
                        • Malli voi hallusinoida erityisesti hammaslääketieteen potilastapausten ulkopuolisissa aiheissa, joihin sitä ei ole pääasiallisesti koulutettu.<br>
                        • Kehittäjä ei vastaa mallin tuottamasta sisällöstä eikä ota vastuuta palvelun käytöstä.<br><br>
                        <strong>Yhteystiedot ja tietosuoja:</strong><br>
                        • Palvelun ylläpitäjä: Heikki Saxén / Duckling Codehouse Oy (<a href="mailto:heikki@duckling.fi" style="color: #8B4513;">heikki@duckling.fi</a>, <a href="https://huggingface.co/ducklingcodehouse" target="_blank" style="color: #8B4513;">Hugging Face -profiili</a>)<br>
                        • Palvelu toimii Hugging Face -alustalla, jonka <a href="https://huggingface.co/terms-of-service" target="_blank" style="color: #8B4513;">käyttöehtoja</a> ja <a href="https://huggingface.co/privacy" target="_blank" style="color: #8B4513;">tietosuojakäytäntöä</a> sovelletaan.<br>
                        • Keskusteluja ei tallenneta pysyvästi palvelimelle.<br>
                        • Älä jaa arkaluonteisia henkilötietoja tai potilastietoja.<br>
                        • Käyttämällä palvelua hyväksyt ehdot.
                    </div>
                    """)
                # Example question with note - always visible
                gr.HTML("""
                <div style="font-family: Arial, Helvetica, sans-serif; margin-top: 1rem; padding: 0.75rem;
                            background-color: #FDF5E6; border-radius: 6px; border-left: 4px solid #CD853F;">
                    <em style="color: #654321; font-size: 0.85rem;">Huomaa, että malli muistaa aiemman keskustelun rajatusti. Jos vaihdat aihetta kokonaan, on suositeltavaa painaa keskusteluikkunan oikean yläkulman roskakori-kuvaketta ja tyhjentää edellinen keskustelu.</em>
                    <br><br>
                    <strong style="color: #8B4513;">Esimerkkikysymys</strong><br>
                    <em style="color: #654321; font-size: 0.85rem;">60-vuotias mies, jolla tyypin 2 diabetes (HbA1c 7,8 %) ja tupakointihistoria (lopettanut 2 vuotta sitten), tarvitsee 24-alueen hampaan poiston akuutin parodontiitin vuoksi. CBCT:ssä bukkaalinen luu on resorboitunut, jäljellä 4 mm vertikaalista ja 5 mm horisontaalista luuta. Potilas toivoo välitöntä implantointia. Onko välitön implantointi ja samanaikainen GBR perusteltua tässä tilanteessa, ja mitä riskitekijöitä ja erityishuomioita ottaisit huomioon suunnittelussa ja jälkihoidossa?</em>
                </div>
                """)
                # ChatInterface with Finnish buttons and placeholder text
                # NOTE: In Gradio 6.x, show_share_button was removed from Chatbot
                gr.ChatInterface(
                    fn=respond,
                    show_progress="full",
                    concurrency_limit=CONCURRENCY_LIMIT,
                    textbox=gr.Textbox(
                        placeholder="Kirjoita kysymys",
                        container=False,
                        scale=7,
                        submit_btn="Lähetä",
                        stop_btn="Keskeytä",
                        elem_classes=["custom-chat-textbox"]
                    ),
                    chatbot=gr.Chatbot(
                        show_label=False,
                        render_markdown=True
                    )
                )
            gr.Column(scale=1, min_width=50)  # Right spacer

# Print a startup message right before launch
print("READY TO RUN, PLEASE CLOSE THIS STARTER CONSOLE")

if __name__ == "__main__":
    # Gradio 6.x: theme and css are passed to launch() instead of Blocks()
    demo.queue().launch(theme=theme, css=custom_css, ssr_mode=False)
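
# Local run note (an assumption, not part of the Spaces deployment): with the requirements
# installed and a CUDA GPU that has room for the ~9GB model, running `python app.py` starts
# the same interface on a local URL printed by Gradio; the app.py file name follows the
# Spaces convention and may differ in your checkout.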