Remostart committed on
Commit
c82de0f
·
verified ·
1 Parent(s): 00275e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -1
app.py CHANGED
@@ -1 +1,204 @@
1
- import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) MODEL_NAME = "ubiodee/Plutus_Tutor_new" # ------------ Model and Tokenizer cache ------------ _TOKENIZER = None _MODEL = None def get_tokenizer(): global _TOKENIZER if _TOKENIZER is None: try: _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) if _TOKENIZER.pad_token_id is None: if _TOKENIZER.eos_token_id is not None: _TOKENIZER.pad_token = _TOKENIZER.eos_token _TOKENIZER.pad_token_id = _TOKENIZER.eos_token_id else: _TOKENIZER.add_special_tokens({"pad_token": "<|endoftext|>"}) _TOKENIZER.pad_token_id = _TOKENIZER.convert_tokens_to_ids("<|endoftext|>") logger.info(f"Set pad_token_id: {_TOKENIZER.pad_token_id}") if _TOKENIZER.eos_token_id is None: _TOKENIZER.eos_token = "<|endoftext|>" _TOKENIZER.eos_token_id = _TOKENIZER.convert_tokens_to_ids("<|endoftext|>") logger.info(f"Set eos_token_id: {_TOKENIZER.eos_token_id}") logger.info("Tokenizer loaded successfully") except Exception as e: logger.error(f"Failed to load tokenizer: {type(e).__name__}: {e}") raise return _TOKENIZER def get_model(): global _MODEL if _MODEL is None: try: logger.info("Loading model on CPU with FP16") _MODEL = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float16, device_map="cpu", low_cpu_mem_usage=True, ) _MODEL.eval() logger.info("Model loaded successfully") except Exception as e: logger.error(f"Failed to load model: {type(e).__name__}: {e}") raise return _MODEL # ------------ Prompt builder ------------ def build_instructions(personality, level, topic): return ( f"Plutus AI tutor for a {personality} learner at {level} level. " f"Explain {topic} in a tone that fits learners personality and tech level with examples. " "Keep it 250–500 words. End with 'Takeaway:'." 
) def build_model_input(tokenizer, personality, level, topic): user_msg = build_instructions(personality, level, topic) # Use a simple text prompt instead of chat template return f"System: You are a personalised Cardano Plutus tutor, your job is to make Plutus easy to learn based on the different learner personalities, you are to adapt your teaching style according to the learner.\nUser: {user_msg}\nAssistant:" # ------------ Generation function ------------ def generate(personality, level, topic, max_new_tokens=500): try: tokenizer = get_tokenizer() model = get_model() prompt = build_model_input(tokenizer, personality, level, topic) device = torch.device("cpu") logger.info("Generating on CPU") inputs = tokenizer(prompt, return_tensors="pt") input_len = inputs["input_ids"].shape[1] inputs = {k: v.to(device) for k, v in inputs.items()} with torch.inference_mode(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=0.3, top_p=0.4, do_sample=True, repetition_penalty=1.3, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, stop_strings=["Takeaway:"] ) gen_ids = outputs[0][input_len:] text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() if not text: text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip() if text.startswith(prompt): text = text[len(prompt):].lstrip() return text if text else "Generation failed. Try regenerating." except Exception as e: logger.error(f"Generation error: {type(e).__name__}: {e}") return f"Error during generation: {str(e)}. Try regenerating or using a smaller model." # ------------ Orchestrator with retry logic ------------ def orchestrator(personality, level, topic, max_retries=3): if not personality or not level or not topic: return "Select personality, expertise, and topic to get an explanation." 
for attempt in range(max_retries): try: logger.info(f"Generation attempt {attempt + 1}/{max_retries}") return generate(personality, level, topic) except Exception as e: logger.error(f"Attempt {attempt + 1}/{max_retries} failed: {type(e).__name__}: {e}") if attempt == max_retries - 1: return ( "Failed to generate after multiple attempts. " "Click **Regenerate** or try again later. " "If this persists, try a smaller model or check logs for errors." ) # ------------ Gradio UI ------------ with gr.Blocks(theme="default") as iface: gr.Markdown( "## Cardano Plutus AI Assistant\n" "Select **Learning Personality**, **Expertise Level**, and **Topic**, then click **Generate**. " "Note: Generation may take ~1–2 minutes on CPU." ) with gr.Row(): personality = gr.Dropdown( choices=["Dyslexic", "Autistic", "Expressive"], label="Learning Personality", value=None, allow_custom_value=False, scale=1, ) level = gr.Dropdown( choices=["Beginner", "Intermediate", "Advanced"], label="Expertise Level", value=None, allow_custom_value=False, scale=1, ) topic = gr.Dropdown( choices=[ "What is Plutus", "Introduction to Plutus Smart Contracts", "Understanding Cardano Blockchain", "Validator Scripts in Plutus", "Plutus Tx", "Datum and Redeemer", "Time Handling in Plutus", "Off-Chain Code", "On-Chain Constraints", "Plutus Core", "Transaction Validation", "Cardano Node Integration", ], label="Topic", value=None, allow_custom_value=False, scale=2, ) with gr.Row(): generate_btn = gr.Button("Generate") regen = gr.Button("🔁 Regenerate") output = gr.Textbox( label="Model Response", lines=12, interactive=False, show_copy_button=True, placeholder="Your tailored explanation will appear here…", ) generate_btn.click(orchestrator, [personality, level, topic], output, queue=True) regen.click(orchestrator, [personality, level, topic], output, queue=True) # Enable queue iface.queue() if __name__ == "__main__": iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub model ID of the fine-tuned Plutus tutor model.
MODEL_NAME = "ubiodee/Plutus_Tutor_new"

# ------------ Model and Tokenizer cache ------------
# Module-level singletons, lazily populated by get_tokenizer()/get_model()
# so the (slow) downloads and loads happen at most once per process.
_TOKENIZER = None
_MODEL = None
16
+
17
def get_tokenizer():
    """Return the cached tokenizer, loading and normalising it on first use.

    Ensures both pad and eos token ids are defined, since some causal LMs
    ship without one or both and generation needs them.
    """
    global _TOKENIZER
    if _TOKENIZER is not None:
        return _TOKENIZER
    try:
        _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        tok = _TOKENIZER
        if tok.pad_token_id is None:
            if tok.eos_token_id is not None:
                # Reuse EOS for padding — the usual choice for causal LMs.
                tok.pad_token = tok.eos_token
                tok.pad_token_id = tok.eos_token_id
            else:
                tok.add_special_tokens({"pad_token": "<|endoftext|>"})
                tok.pad_token_id = tok.convert_tokens_to_ids("<|endoftext|>")
            logger.info(f"Set pad_token_id: {tok.pad_token_id}")
        if tok.eos_token_id is None:
            tok.eos_token = "<|endoftext|>"
            tok.eos_token_id = tok.convert_tokens_to_ids("<|endoftext|>")
            logger.info(f"Set eos_token_id: {tok.eos_token_id}")
        logger.info("Tokenizer loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {type(e).__name__}: {e}")
        raise
    return _TOKENIZER
39
+
40
def get_model():
    """Return the cached causal LM, loading it on first use (CPU, FP16)."""
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    try:
        logger.info("Loading model on CPU with FP16")
        _MODEL = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        # Inference only — disable dropout etc.
        _MODEL.eval()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {type(e).__name__}: {e}")
        raise
    return _MODEL
57
+
58
+ # ------------ Prompt builder ------------
59
def build_instructions(personality, level, topic):
    """Compose the user-facing instruction string for the tutor prompt."""
    intro = f"Plutus AI tutor for a {personality} learner at {level} level. "
    task = f"Explain {topic} in a tone that fits learner's personality and tech level with examples. "
    constraints = "Keep it 250–500 words. End with 'Takeaway:'."
    return intro + task + constraints
65
+
66
def build_model_input(tokenizer, personality, level, topic):
    """Assemble the plain-text prompt (system + user + assistant cue).

    NOTE: `tokenizer` is kept for interface compatibility but is not used —
    the prompt is plain text rather than a chat template.
    """
    instructions = build_instructions(personality, level, topic)
    system_msg = (
        "You are a personalised Cardano Plutus tutor, your job is to make "
        "Plutus easy to learn based on different learner personalities, "
        "adapt your teaching style accordingly."
    )
    return f"System: {system_msg}\nUser: {instructions}\nAssistant:"
69
+
70
+ # ------------ Generation function ------------
71
def generate(personality, level, topic, max_new_tokens=200):
    """Generate a tutoring explanation for the given learner profile.

    Args:
        personality: Learning personality (e.g. "Dyslexic").
        level: Expertise level (e.g. "Beginner").
        topic: Plutus topic to explain.
        max_new_tokens: Generation budget; kept small for CPU latency.

    Returns:
        The generated explanation text, or a human-readable error string
        (this function never raises — all failures are caught and reported).
    """
    try:
        tokenizer = get_tokenizer()
        model = get_model()
        prompt = build_model_input(tokenizer, personality, level, topic)

        device = torch.device("cpu")
        logger.info("Generating on CPU")
        inputs = tokenizer(prompt, return_tensors="pt")
        # Remember prompt length so we can decode only the continuation.
        input_len = inputs["input_ids"].shape[1]
        inputs = {k: v.to(device) for k, v in inputs.items()}

        start_time = time.time()
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.3,
                top_p=0.4,
                do_sample=True,
                repetition_penalty=1.5,  # Increased to prevent repetition
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                stop_strings=["Takeaway:"],
                # BUGFIX: stop_strings requires the tokenizer to be passed to
                # generate(); without it transformers raises a ValueError and
                # every generation attempt fails.
                tokenizer=tokenizer,
            )
        logger.info(f"Generation took {time.time() - start_time:.2f} seconds")

        gen_ids = outputs[0][input_len:]
        text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        if not text:
            # BUGFIX: fall back to decoding the FULL output. Previously this
            # re-decoded the same (empty) continuation slice, which could
            # never recover any text.
            text = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        # Remove the echoed prompt if the full decode included it.
        if prompt in text:
            text = text.replace(prompt, "").strip()

        # Truncate shortly after the "Takeaway:" marker, or flag a missing one.
        if "Takeaway:" in text:
            text = text[:text.index("Takeaway:") + len("Takeaway:") + 100].strip()
        elif not text.endswith("Takeaway:"):
            text += "\nTakeaway: (Summary not fully generated due to token limit)."

        return text if text else "Generation failed. Try regenerating."
    except Exception as e:
        logger.error(f"Generation error: {type(e).__name__}: {e}")
        return f"Error during generation: {str(e)}. Try regenerating or using a smaller model."
117
+
118
+ # ------------ Orchestrator with retry logic ------------
119
def orchestrator(personality, level, topic, max_retries=3):
    """Drive generation with status updates and retry logic.

    This is a generator consumed by Gradio: each `yield` updates the output
    textbox. BUGFIX: the previous version mixed `yield` (status message) with
    `return <value>` — inside a generator, a returned value only lands in
    StopIteration and is never delivered, so neither the final answer, the
    validation message, nor the failure message ever reached the UI. All
    user-visible strings are now yielded.

    Args:
        personality/level/topic: dropdown selections; any falsy value aborts.
        max_retries: number of generation attempts before giving up.
    """
    if not personality or not level or not topic:
        yield "Select personality, expertise, and topic to get an explanation."
        return

    # Interim status so the user sees feedback during the slow CPU generation.
    yield "Generating response, please wait (~1–2 minutes on CPU)..."

    for attempt in range(max_retries):
        try:
            logger.info(f"Generation attempt {attempt + 1}/{max_retries}")
            result = generate(personality, level, topic)
            # generate() reports failures as strings; convert back to an
            # exception so the retry loop handles both failure modes uniformly.
            if result.startswith("Error during generation"):
                raise Exception(result)
            yield result
            return
        except Exception as e:
            logger.error(f"Attempt {attempt + 1}/{max_retries} failed: {type(e).__name__}: {e}")
            if attempt == max_retries - 1:
                yield (
                    "Failed to generate after multiple attempts. "
                    "Click **Regenerate** or try again later. "
                    "If this persists, try a smaller model or check logs for errors."
                )
140
+
141
+ # ------------ Gradio UI ------------
142
# Build the Gradio app: three dropdowns feed orchestrator(), whose output
# lands in a read-only textbox. orchestrator is a generator, so intermediate
# yields update the textbox while generation runs.
with gr.Blocks(theme="default") as iface:
    gr.Markdown(
        "## Cardano Plutus AI Assistant\n"
        "Select **Learning Personality**, **Expertise Level**, and **Topic**, then click **Generate**. "
        "Note: Generation may take ~1–2 minutes on CPU."
    )

    # All selectors start unset (value=None) so orchestrator's validation can
    # prompt the user to pick all three before generating.
    with gr.Row():
        personality = gr.Dropdown(
            choices=["Dyslexic", "Autistic", "Expressive"],
            label="Learning Personality",
            value=None,
            allow_custom_value=False,
            scale=1,
        )
        level = gr.Dropdown(
            choices=["Beginner", "Intermediate", "Advanced"],
            label="Expertise Level",
            value=None,
            allow_custom_value=False,
            scale=1,
        )
        topic = gr.Dropdown(
            choices=[
                "What is Plutus?",
                "Smart Contracts in Plutus",
                "Cardano Blockchain",
                "What is a Validator Script?",
                "Plutus Tx",
                "Datum and Redeemer",
                "Time Handling in Plutus",
                "Off-Chain Code",
                "On-Chain Constraints",
                "Plutus Core",
                "Transaction Validation",
                "Cardano Node Integration",
            ],
            label="Topic",
            value=None,
            allow_custom_value=False,
            scale=2,  # wider than the other two dropdowns
        )

    with gr.Row():
        generate_btn = gr.Button("Generate")
        regen = gr.Button("🔁 Regenerate")

    output = gr.Textbox(
        label="Model Response",
        lines=12,
        interactive=False,
        show_copy_button=True,
        placeholder="Your tailored explanation will appear here…",
    )

    # Both buttons run the same handler; Regenerate simply re-invokes it with
    # the current selections. Events are queued (generation is slow on CPU).
    generate_btn.click(orchestrator, [personality, level, topic], output, queue=True)
    regen.click(orchestrator, [personality, level, topic], output, queue=True)

# Enable queue
iface.queue()

if __name__ == "__main__":
    # 0.0.0.0 exposes the app on all interfaces (standard for Spaces/containers).
    iface.launch(server_name="0.0.0.0", server_port=7860)