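"""Cardano Plutus AI Assistant.

Gradio app for a Hugging Face Space: loads the ubiodee/Plutus_Tutor_new model
(4-bit on GPU when available, CPU fallback) and generates short Plutus
tutorials tailored to a learning personality, expertise level, and topic.

Assumed dependencies (not pinned here): gradio, torch, transformers,
accelerate, bitsandbytes, and the Hugging Face `spaces` package.
"""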
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import spaces

MODEL_NAME = "ubiodee/Plutus_Tutor_new"
# ------------ Tokenizer cache ------------
_TOKENIZER = None


def get_tokenizer():
    global _TOKENIZER
    if _TOKENIZER is None:
        tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        # Ensure a pad token exists to avoid generation crashes.
        if tok.pad_token_id is None:
            if tok.eos_token_id is not None:
                tok.pad_token = tok.eos_token
            elif tok.bos_token_id is not None:
                tok.pad_token = tok.bos_token
            else:
                # Last resort: a brand-new token. Note the model embeddings
                # would also need resizing (model.resize_token_embeddings)
                # for this token to be usable.
                tok.add_special_tokens({"pad_token": "[PAD]"})
        _TOKENIZER = tok
    return _TOKENIZER
# ------------ Prompt builder ------------
def build_instructions(personality, level, topic):
    return (
        f"You are a friendly Plutus AI tutor for a {personality} learner at {level} level.\n"
        f"Topic: {topic}\n\n"
        "Explain in a conversational, easy tone with concrete examples.\n"
        "Keep it complete and around 120–160 words.\n"
        "End with a one-line takeaway starting with 'Takeaway:'."
    )
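# For example, build_instructions("Dyslexic", "Beginner", "Plutus Basics")
# renders exactly to:
#   You are a friendly Plutus AI tutor for a Dyslexic learner at Beginner level.
#   Topic: Plutus Basics
#
#   Explain in a conversational, easy tone with concrete examples.
#   Keep it complete and around 120–160 words.
#   End with a one-line takeaway starting with 'Takeaway:'.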
def build_model_input(tokenizer, personality, level, topic):
    user_msg = build_instructions(personality, level, topic)
    # Recent tokenizers all define apply_chat_template, so check for an
    # actual chat template rather than for the method itself.
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": "You are a helpful Cardano Plutus tutor."},
            {"role": "user", "content": user_msg},
        ]
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    # Plain-text fallback for tokenizers without a chat template.
    return (
        "System: You are a helpful Cardano Plutus tutor.\n\n"
        f"User: {user_msg}\n\nAssistant:"
    )
# ------------ GPU/CPU generation ------------
@spaces.GPU  # request a ZeroGPU device for the duration of this call
def generate_on_gpu(personality, level, topic, max_new_tokens=256, min_new_tokens=32):
    # The prompt asks for ~120–160 words, which needs well over 100 tokens,
    # hence the 256-token default budget.
    # Log GPU availability for debugging.
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU device: {torch.cuda.get_device_name(0)}")

    tokenizer = get_tokenizer()
    prompt = build_model_input(tokenizer, personality, level, topic)

    try:
        # Try loading the model on GPU with 4-bit quantization
        # (the bare load_in_4bit= kwarg is deprecated in recent transformers).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        device = next(model.parameters()).device
    except Exception as e:
        print(f"GPU loading failed: {e}. Falling back to CPU.")
        # Fall back to CPU in float32 (float16 is slow and only partially
        # supported on CPU).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        device = torch.device("cpu")
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt")
    input_len = inputs["input_ids"].shape[1]
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            temperature=0.5,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens.
    gen_ids = outputs[0][input_len:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    if not text:
        # Fallback: decode the full sequence and strip the prompt prefix.
        text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        if text.startswith(prompt):
            text = text[len(prompt):].lstrip()

    # Free GPU memory between requests.
    try:
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass

    return text if text else "Generation failed. Try regenerating or adjusting parameters."
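# Quick local smoke test (hypothetical call, assuming the model weights are
# accessible from your environment):
#   print(generate_on_gpu("Dyslexic", "Beginner", "Plutus Basics"))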
# ------------ Orchestrator with retry logic ------------
def orchestrator(personality, level, topic, max_retries=3):
    if not personality or not level or not topic:
        return "Select your personality, expertise, and topic to get a tailored explanation."
    for attempt in range(max_retries):
        try:
            return generate_on_gpu(personality, level, topic)
        except Exception as e:
            print(f"[Attempt {attempt + 1}/{max_retries}] ZeroGPU error: {type(e).__name__}: {e}")
            if attempt == max_retries - 1:
                return (
                    "GPU was not available after multiple attempts. "
                    "Click **Regenerate** or try again later."
                )
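# On ZeroGPU Spaces, quota or scheduling failures typically surface as
# exceptions raised from the @spaces.GPU-decorated call; the retry loop
# above absorbs those before giving up with a user-facing message.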
# ------------ Gradio UI ------------
with gr.Blocks(theme="default") as iface:
    gr.Markdown(
        "## Cardano Plutus AI Assistant\n"
        "Pick your **Learning Personality**, **Expertise Level**, and **Topic**, then click **Generate**."
    )
    with gr.Row():
        personality = gr.Dropdown(
            choices=["Dyslexic", "Autistic", "Expressive"],
            label="Learning Personality",
            value=None,
            allow_custom_value=False,
            scale=1,
        )
        level = gr.Dropdown(
            choices=["Beginner", "Intermediate", "Advanced"],
            label="Expertise Level",
            value=None,
            allow_custom_value=False,
            scale=1,
        )
        topic = gr.Dropdown(
            choices=[
                "Plutus Basics",
                "Smart Contracts",
                "Cardano Blockchain",
                "Validator Scripts",
                "Plutus Tx",
                "Datum and Redeemer",
                "Time Handling in Plutus",
                "Off-Chain Code",
                "On-Chain Constraints",
                "Plutus Core",
                "Transaction Validation",
                "Cardano Node Integration",
            ],
            label="Topic",
            value=None,
            allow_custom_value=False,
            scale=2,
        )
    with gr.Row():
        generate_btn = gr.Button("Generate")
        regen = gr.Button("🔁 Regenerate")

    output = gr.Textbox(
        label="Model Response",
        lines=12,
        interactive=False,
        show_copy_button=True,
        placeholder="Your tailored explanation will appear here…",
    )

    generate_btn.click(orchestrator, [personality, level, topic], output, queue=True)
    regen.click(orchestrator, [personality, level, topic], output, queue=True)

# Enable the request queue.
iface.queue()

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)