Remostart committed on
Commit
f031212
·
verified ·
1 Parent(s): da629fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -193
app.py CHANGED
@@ -1,204 +1,46 @@
1
  import gradio as gr
2
  import torch
3
- import time
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
- import logging
6
-
7
- # Set up logging
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger(__name__)
10
 
 
11
  MODEL_NAME = "ubiodee/Plutus_Tutor_new"
12
 
13
- # ------------ Model and Tokenizer cache ------------
14
- _TOKENIZER = None
15
- _MODEL = None
16
-
17
- def get_tokenizer():
18
- global _TOKENIZER
19
- if _TOKENIZER is None:
20
- try:
21
- _TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
22
- if _TOKENIZER.pad_token_id is None:
23
- if _TOKENIZER.eos_token_id is not None:
24
- _TOKENIZER.pad_token = _TOKENIZER.eos_token
25
- _TOKENIZER.pad_token_id = _TOKENIZER.eos_token_id
26
- else:
27
- _TOKENIZER.add_special_tokens({"pad_token": "<|endoftext|>"})
28
- _TOKENIZER.pad_token_id = _TOKENIZER.convert_tokens_to_ids("<|endoftext|>")
29
- logger.info(f"Set pad_token_id: {_TOKENIZER.pad_token_id}")
30
- if _TOKENIZER.eos_token_id is None:
31
- _TOKENIZER.eos_token = "<|endoftext|>"
32
- _TOKENIZER.eos_token_id = _TOKENIZER.convert_tokens_to_ids("<|endoftext|>")
33
- logger.info(f"Set eos_token_id: {_TOKENIZER.eos_token_id}")
34
- logger.info("Tokenizer loaded successfully")
35
- except Exception as e:
36
- logger.error(f"Failed to load tokenizer: {type(e).__name__}: {e}")
37
- raise
38
- return _TOKENIZER
39
-
40
- def get_model():
41
- global _MODEL
42
- if _MODEL is None:
43
- try:
44
- logger.info("Loading model on CPU with FP16")
45
- _MODEL = AutoModelForCausalLM.from_pretrained(
46
- MODEL_NAME,
47
- torch_dtype=torch.float16,
48
- device_map="cpu",
49
- low_cpu_mem_usage=True,
50
- )
51
- _MODEL.eval()
52
- logger.info("Model loaded successfully")
53
- except Exception as e:
54
- logger.error(f"Failed to load model: {type(e).__name__}: {e}")
55
- raise
56
- return _MODEL
57
-
58
- # ------------ Prompt builder ------------
59
- def build_instructions(personality, level, topic):
60
- return (
61
- f"Plutus AI tutor for a {personality} learner at {level} level. "
62
- f"Explain {topic} in a tone that fits learner's personality and tech level with examples. "
63
- "Keep it 250–500 words. End with 'Takeaway:'."
64
- )
65
-
66
- def build_model_input(tokenizer, personality, level, topic):
67
- user_msg = build_instructions(personality, level, topic)
68
- return f"System: You are a personalised Cardano Plutus tutor, your job is to make Plutus easy to learn based on different learner personalities, adapt your teaching style accordingly.\nUser: {user_msg}\nAssistant:"
69
-
70
- # ------------ Generation function ------------
71
- def generate(personality, level, topic, max_new_tokens=200):
72
- try:
73
- tokenizer = get_tokenizer()
74
- model = get_model()
75
- prompt = build_model_input(tokenizer, personality, level, topic)
76
-
77
- device = torch.device("cpu")
78
- logger.info("Generating on CPU")
79
- inputs = tokenizer(prompt, return_tensors="pt")
80
- input_len = inputs["input_ids"].shape[1]
81
- inputs = {k: v.to(device) for k, v in inputs.items()}
82
-
83
- start_time = time.time()
84
- with torch.inference_mode():
85
- outputs = model.generate(
86
- **inputs,
87
- max_new_tokens=max_new_tokens,
88
- temperature=0.3,
89
- top_p=0.4,
90
- do_sample=True,
91
- repetition_penalty=1.3,
92
- eos_token_id=tokenizer.eos_token_id,
93
- pad_token_id=tokenizer.pad_token_id,
94
- stop_strings=["Takeaway:"]
95
- )
96
- logger.info(f"Generation took {time.time() - start_time:.2f} seconds")
97
-
98
- gen_ids = outputs[0][input_len:]
99
- text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
100
- if not text:
101
- text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
102
-
103
- # Remove prompt if present
104
- if prompt in text:
105
- text = text.replace(prompt, "").strip()
106
-
107
- # Truncate at Takeaway
108
- if "Takeaway:" in text:
109
- text = text[:text.index("Takeaway:") + len("Takeaway:") + 100].strip()
110
- elif not text.endswith("Takeaway:"):
111
- text += "\nTakeaway: (Summary not fully generated due to token limit)."
112
-
113
- return text if text else "Generation failed. Try regenerating."
114
- except Exception as e:
115
- logger.error(f"Generation error: {type(e).__name__}: {e}")
116
- return f"Error during generation: {str(e)}. Try regenerating or using a smaller model."
117
-
118
- # ------------ Orchestrator with retry logic ------------
119
- def orchestrator(personality, level, topic, max_retries=3):
120
- if not personality or not level or not topic:
121
- return "Select personality, expertise, and topic to get an explanation."
122
-
123
- logger.info("Yielding loading message")
124
- yield "Generating response, please wait (~1–2 minutes on CPU)..."
125
-
126
- for attempt in range(max_retries):
127
- try:
128
- logger.info(f"Generation attempt {attempt + 1}/{max_retries}")
129
- result = generate(personality, level, topic)
130
- if result.startswith("Error during generation"):
131
- raise Exception(result)
132
- logger.info("Generation completed successfully")
133
- return result
134
- except Exception as e:
135
- logger.error(f"Attempt {attempt + 1}/{max_retries} failed: {type(e).__name__}: {e}")
136
- if attempt == max_retries - 1:
137
- return (
138
- "Failed to generate after multiple attempts. "
139
- "Click **Regenerate** or try again later. "
140
- "If this persists, try a smaller model or check logs for errors."
141
- )
142
-
143
- # ------------ Gradio UI ------------
144
- with gr.Blocks(theme="default") as iface:
145
- gr.Markdown(
146
- "## Cardano Plutus AI Assistant\n"
147
- "Select **Learning Personality**, **Expertise Level**, and **Topic**, then click **Generate**. "
148
- "Note: Generation may take ~1–2 minutes on CPU."
149
- )
150
-
151
- with gr.Row():
152
- personality = gr.Dropdown(
153
- choices=["Dyslexic", "Autistic", "Expressive"],
154
- label="Learning Personality",
155
- value=None,
156
- allow_custom_value=False,
157
- scale=1,
158
- )
159
- level = gr.Dropdown(
160
- choices=["Beginner", "Intermediate", "Advanced"],
161
- label="Expertise Level",
162
- value=None,
163
- allow_custom_value=False,
164
- scale=1,
165
  )
166
- topic = gr.Dropdown(
167
- choices=[
168
- "What is Plutus?",
169
- "Why should I learn Plutus?",
170
- "How does Plutus work?",
171
- "What are the benefits of learning Plutus?",
172
- "What are some common use cases for Plutus?",
173
- "What are some resources for learning Plutus?",
174
- "What are some best practices for using Plutus?",
175
- "What are some common mistakes to avoid when using Plutus?",
176
- "What are some advanced concepts in Plutus?",
177
- "What are some future developments in Plutus?",
178
- ],
179
- label="Topic",
180
- value=None,
181
- allow_custom_value=False,
182
- scale=2,
183
- )
184
-
185
- with gr.Row():
186
- generate_btn = gr.Button("Generate")
187
- regen = gr.Button("🔁 Regenerate")
188
 
189
- output = gr.Textbox(
190
- label="Model Response",
191
- lines=12,
192
- interactive=False,
193
- show_copy_button=True,
194
- placeholder="Your tailored explanation will appear here…",
195
- )
196
 
197
- generate_btn.click(orchestrator, [personality, level, topic], output, queue=True)
198
- regen.click(orchestrator, [personality, level, topic], output, queue=True)
199
 
200
- # Enable queue
201
- iface.queue()
 
 
 
 
 
 
202
 
203
- if __name__ == "__main__":
204
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
4
 
5
+ # Load model & tokenizer
6
  MODEL_NAME = "ubiodee/Plutus_Tutor_new"
7
 
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
+ model.eval()
11
+
12
+ if torch.cuda.is_available():
13
+ model.to("cuda")
14
+
15
+ # Response function
16
+ def generate_response(prompt):
17
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
18
+
19
+ with torch.no_grad():
20
+ outputs = model.generate(
21
+ **inputs,
22
+ max_new_tokens=500,
23
+ temperature=0.3,
24
+ top_p=0.3,
25
+ do_sample=True,
26
+ eos_token_id=tokenizer.eos_token_id,
27
+ pad_token_id=tokenizer.pad_token_id,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
29
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # Remove the prompt from the output to return only the answer
32
+ if response.startswith(prompt):
33
+ response = response[len(prompt):].strip()
 
 
 
 
34
 
35
+ return response
 
36
 
37
+ # Gradio UI
38
+ demo = gr.Interface(
39
+ fn=generate_response,
40
+ inputs=gr.Textbox(label="Enter your prompt", lines=4, placeholder="Learn about Plutus..."),
41
+ outputs=gr.Textbox(label="Model Response"),
42
+ title="Cardano Plutus AI Assistant",
43
+ description="Your Personalised Plutus Tutor."
44
+ )
45
 
46
+ demo.launch()