ubiodee commited on
Commit
07e6cbc
·
verified ·
1 Parent(s): 612e097

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -29
app.py CHANGED
@@ -1,46 +1,89 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
4
 
5
  # Load model & tokenizer
6
  MODEL_NAME = "ubiodee/Cardano_plutus"
7
 
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
- model.eval()
11
-
12
- if torch.cuda.is_available():
13
- model.to("cuda")
 
 
 
 
 
 
 
 
 
14
 
15
- # Response function
16
- def generate_response(prompt):
17
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
18
 
19
- with torch.no_grad():
20
- outputs = model.generate(
 
 
 
 
 
 
 
 
 
21
  **inputs,
22
- max_new_tokens=250,
23
- temperature=0.1,
24
- top_p=0.1,
25
- do_sample=True,
26
- eos_token_id=tokenizer.eos_token_id,
27
- pad_token_id=tokenizer.pad_token_id,
28
- )
29
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
-
31
- # Remove the prompt from the output to return only the answer
32
- if response.startswith(prompt):
33
- response = response[len(prompt):].strip()
34
-
35
- return response
 
 
 
 
 
 
 
 
36
 
37
  # Gradio UI
38
  demo = gr.Interface(
39
  fn=generate_response,
40
- inputs=gr.Textbox(label="Enter your prompt", lines=4, placeholder="Ask about Plutus..."),
 
 
 
 
41
  outputs=gr.Textbox(label="Model Response"),
42
  title="Cardano Plutus AI Assistant",
43
- description="Ask questions about Plutus smart contracts or Cardano blockchain."
 
44
  )
45
 
46
- demo.launch()
 
 
 
 
 
 
 
# Standard library
import logging
from threading import Thread

# Third-party
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
10
 
# Load model & tokenizer
MODEL_NAME = "ubiodee/Cardano_plutus"

try:
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Some causal-LM tokenizers ship without a pad token; generation below
    # passes pad_token_id, which would otherwise be None. Fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading model...")
    # fp16 halves memory on GPU but is slow and poorly supported on CPU;
    # with device_map="auto" the model may land on CPU, so choose the dtype
    # from the hardware actually available instead of forcing float16.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    model.eval()
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Error loading model or tokenizer: %s", e)
    raise
29
 
30
# Prompt template to guide the model (simple, since no model card details)
def format_prompt(user_prompt):
    """Wrap the raw user text in the chat-style prompt the model expects."""
    return "User: " + user_prompt + "\nAssistant:"
33
 
34
# Response function with proper streaming
def generate_response(user_prompt):
    """Stream a generated answer for *user_prompt*.

    Yields the accumulated response text after each new token so Gradio can
    render it incrementally. On failure, yields a single error message.
    """
    try:
        # Guard against empty/whitespace-only input before touching the model.
        if not user_prompt or not user_prompt.strip():
            yield "Please enter a prompt."
            return

        logger.info("Processing prompt...")
        prompt = format_prompt(user_prompt)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Use streamer for token-by-token generation.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": 300,
            "do_sample": True,
            "temperature": 0.1,
            "top_p": 0.1,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
        }

        # model.generate blocks, so run it in a worker thread and consume
        # tokens from the streamer in this generator.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text.strip()

        # Fix: the worker thread was never joined, so it could outlive the
        # request and any late failure in generate() was silently dropped.
        thread.join()

        logger.info("Response generated successfully.")
    except Exception as e:
        # logger.exception records the traceback for debugging.
        logger.exception("Error during generation: %s", e)
        yield f"Error: {str(e)}"
68
 
69
# Gradio UI
prompt_box = gr.Textbox(
    label="Enter your prompt",
    lines=4,
    placeholder="Ask about Plutus or Cardano...",
)

demo = gr.Interface(
    fn=generate_response,
    inputs=prompt_box,
    outputs=gr.Textbox(label="Model Response"),
    title="Cardano Plutus AI Assistant",
    description="Your Personalised Plutus Tutor. Optimized with sampling to avoid repetition.",
    allow_flagging="never",
)
82
 
83
# Launch the app
try:
    logger.info("Launching Gradio interface...")
    demo.launch()
except Exception as e:
    # logger.exception logs the traceback and uses lazy %-formatting,
    # unlike logger.error(f"...") which loses the stack trace.
    logger.exception("Error launching Gradio: %s", e)
    raise