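"""Gradio app: Cardano Plutus AI tutor.

Serves ubiodee/Plutus_Tutor_new behind a small Gradio UI. The model is loaded
on demand inside a ZeroGPU-decorated function (4-bit on GPU, float32 on CPU as
a fallback) and generation is retried on transient GPU errors.
"""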
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import spaces

MODEL_NAME = "ubiodee/Plutus_Tutor_new"

# ------------ Tokenizer cache ------------
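# Cache the tokenizer at module level so repeated requests reuse one instance
# instead of re-instantiating it on every call.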
_TOKENIZER = None
def get_tokenizer():
    global _TOKENIZER
    if _TOKENIZER is None:
        tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        # Ensure pad/eos exist to avoid generation crashes
        if tok.pad_token_id is None:
            if tok.eos_token_id is not None:
                tok.pad_token = tok.eos_token
            elif tok.bos_token_id is not None:
                tok.pad_token = tok.bos_token
            else:
                tok.add_special_tokens({"pad_token": "[PAD]"})
        _TOKENIZER = tok
    return _TOKENIZER

# ------------ Prompt builder ------------
def build_instructions(personality, level, topic):
    return (
        f"You are a friendly Plutus AI tutor for a {personality} learner at {level} level.\n"
        f"Topic: {topic}\n\n"
        "Explain in a conversational, easy tone with concrete examples.\n"
        "Keep it complete and around 120–160 words.\n"
        "End with a one-line takeaway starting with 'Takeaway:'."
    )

def build_model_input(tokenizer, personality, level, topic):
    user_msg = build_instructions(personality, level, topic)

    # hasattr() is true on all modern tokenizers; check for an actual chat
    # template instead, since apply_chat_template raises when none is set.
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": "You are a helpful Cardano Plutus tutor."},
            {"role": "user", "content": user_msg},
        ]
        prompt_str = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return prompt_str
    else:
        return (
            "System: You are a helpful Cardano Plutus tutor.\n\n"
            f"User: {user_msg}\n\nAssistant:"
        )

# ------------ GPU/CPU generation ------------
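# On Hugging Face Spaces, @spaces.GPU allocates a ZeroGPU slice only for the
# duration of this call; the model is loaded inside the function and deleted
# at the end, so nothing stays resident in GPU memory between requests.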
@spaces.GPU
def generate_on_gpu(personality, level, topic, max_new_tokens=256, min_new_tokens=32):
    # 256 new tokens leaves room for the ~120–160-word answer the prompt asks
    # for; the previous default of 100 tended to truncate responses.
    # Log GPU availability for debugging
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU device: {torch.cuda.get_device_name(0)}")

    tokenizer = get_tokenizer()
    prompt = build_model_input(tokenizer, personality, level, topic)

    try:
        # Load the model on GPU with 4-bit quantization. The bare
        # load_in_4bit= kwarg is deprecated in recent transformers;
        # pass a BitsAndBytesConfig instead.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        device = next(model.parameters()).device
    except Exception as e:
        print(f"GPU loading failed: {e}. Falling back to CPU.")
        # Fall back to CPU in float32 (float16 kernels are not generally
        # supported on CPU and can fail at generation time)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        device = torch.device("cpu")
    
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt")
    input_len = inputs["input_ids"].shape[1]
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            temperature=0.5,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode and clean up
    gen_ids = outputs[0][input_len:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    if not text:
        text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        if text.startswith(prompt):
            text = text[len(prompt):].lstrip()

    # Cleanup
    try:
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass

    return text if text else "Generation failed. Try regenerating or adjusting parameters."

# ------------ Orchestrator with retry logic ------------
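# ZeroGPU allocation can fail transiently (e.g. when the shared queue is
# saturated), so wrap generation in a small retry loop instead of surfacing
# the first error to the user.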
def orchestrator(personality, level, topic, max_retries=3):
    if not personality or not level or not topic:
        return "Select your personality, expertise, and topic to get a tailored explanation."
    
    for attempt in range(max_retries):
        try:
            return generate_on_gpu(personality, level, topic)
        except Exception as e:
            print(f"[Attempt {attempt + 1}/{max_retries}] ZeroGPU error: {type(e).__name__}: {e}")
            if attempt == max_retries - 1:
                return (
                    "GPU was not available after multiple attempts. "
                    "Click **Regenerate** or try again later."
                )
            time.sleep(2 ** attempt)  # brief backoff before retrying

# ------------ Gradio UI ------------
with gr.Blocks(theme="default") as iface:
    gr.Markdown(
        "## Cardano Plutus AI Assistant\n"
        "Pick your **Learning Personality**, **Expertise Level**, and **Topic**, then click **Generate**."
    )

    with gr.Row():
        personality = gr.Dropdown(
            choices=["Dyslexic", "Autistic", "Expressive"],
            label="Learning Personality",
            value=None,
            allow_custom_value=False,
            scale=1,
        )
        level = gr.Dropdown(
            choices=["Beginner", "Intermediate", "Advanced"],
            label="Expertise Level",
            value=None,
            allow_custom_value=False,
            scale=1,
        )
        topic = gr.Dropdown(
            choices=[
                "Plutus Basics",
                "Smart Contracts",
                "Cardano Blockchain",
                "Validator Scripts",
                "Plutus Tx",
                "Datum and Redeemer",
                "Time Handling in Plutus",
                "Off-Chain Code",
                "On-Chain Constraints",
                "Plutus Core",
                "Transaction Validation",
                "Cardano Node Integration",
            ],
            label="Topic",
            value=None,
            allow_custom_value=False,
            scale=2,
        )

    with gr.Row():
        generate_btn = gr.Button("Generate")
        regen = gr.Button("🔁 Regenerate")

    output = gr.Textbox(
        label="Model Response",
        lines=12,
        interactive=False,
        show_copy_button=True,
        placeholder="Your tailored explanation will appear here…",
    )

    generate_btn.click(orchestrator, [personality, level, topic], output, queue=True)
    regen.click(orchestrator, [personality, level, topic], output, queue=True)

# Enable queue
iface.queue()

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
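
# Hypothetical local smoke test, bypassing the UI (assumes this file is saved
# as app.py; off-Spaces, the @spaces.GPU decorator is a passthrough):
#   python -c "from app import orchestrator; print(orchestrator('Expressive', 'Beginner', 'Plutus Basics'))"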