Shriti09 committed on
Commit 025b757 · verified · 1 Parent(s): b15418e

Update app.py

Files changed (1)
  1. app.py +52 -60
app.py CHANGED
@@ -1,83 +1,75 @@
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from peft import PeftModel
  import gradio as gr

- # Model Names
- BASE_MODEL_NAME = "microsoft/phi-2"
- ADAPTER_REPO = "Shriti09/Microsoft-Phi-QLora"

- # Load tokenizer and model
- print("Loading tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
- tokenizer.pad_token = tokenizer.eos_token

- print("Loading base model...")
- base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME, device_map="auto", torch_dtype=torch.float16)

- print("Loading LoRA adapter...")
- model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

- # Merge adapter into the base model
- model = model.merge_and_unload()
- model.eval()

- # Function to generate responses
- def generate_response(message, chat_history, temperature, top_p, max_tokens):
-     # Combine history with the new message
      full_prompt = ""
-     for user_msg, bot_msg in chat_history:
          full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
      full_prompt += f"User: {message}\nAI:"

-     # Tokenize and generate
-     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_length=len(inputs["input_ids"][0]) + max_tokens,
-         do_sample=True,
-         temperature=temperature,
-         top_p=top_p,
-         pad_token_id=tokenizer.eos_token_id
-     )
-
-     # Decode and extract the AI response
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     # Only return the new part of the response
      response = response.split("AI:")[-1].strip()

-     # Update history
-     chat_history.append((message, response))
-     return chat_history, chat_history

- # Gradio UI with Blocks
- with gr.Blocks() as demo:
-     gr.Markdown("<h1><center>🤖 Phi-2 QLoRA Chatbot</center></h1>")
-     gr.Markdown("Chat with Microsoft Phi-2 fine-tuned using QLoRA adapters!")

      chatbot = gr.Chatbot()
-     msg = gr.Textbox(placeholder="Ask me something...", label="Your Message")
-     clear = gr.Button("🗑️ Clear Chat")

-     # Add sliders for controlling generation behavior
-     with gr.Row():
-         temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
-         top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
-         max_tokens_slider = gr.Slider(64, 1024, value=256, step=64, label="Max Tokens")
-
-     # State to hold chat history
      state = gr.State([])

-     # On send message
-     def on_message(message, history, temperature, top_p, max_tokens):
-         return generate_response(message, history, temperature, top_p, max_tokens)
-
-     # Button actions
-     msg.submit(on_message,
-                [msg, state, temp_slider, top_p_slider, max_tokens_slider],
-                [chatbot, state])
-
-     clear.click(lambda: ([], []), None, [chatbot, state])

- # Launch the Gradio app
- demo.launch()
 
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  from peft import PeftModel
  import gradio as gr

+ # Use GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Base model and adapter paths
+ base_model_name = "microsoft/phi-2"    # Pull from HF Hub directly
+ adapter_path = "./phi2-qlora-adapter"  # Your uploaded adapter folder in Space repo

+ print("🔧 Loading base model...")
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_name,
+     device_map="auto",
+     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ )

+ print("🔧 Loading LoRA adapter...")
+ adapter_model = PeftModel.from_pretrained(base_model, adapter_path)

+ print("🔗 Merging adapter into base model...")
+ merged_model = adapter_model.merge_and_unload()
+ merged_model.eval()

+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+ print("✅ Model ready for inference!")
+
+ # Chat function with history
+ def chat_fn(message, history):
+     # Combine conversation history into one prompt
      full_prompt = ""
+     for user_msg, bot_msg in history:
          full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
      full_prompt += f"User: {message}\nAI:"

+     # Tokenize inputs
+     inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = merged_model.generate(
+             **inputs,
+             max_new_tokens=150,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode and return only the AI's latest response
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      response = response.split("AI:")[-1].strip()

+     # Append to history
+     history.append((message, response))
+     return history, history

+ # Gradio UI
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("<h1>🧠 Phi-2 QLoRA Chatbot</h1>")

      chatbot = gr.Chatbot()
+     message = gr.Textbox(label="Your message:")
+     clear = gr.Button("Clear chat")

      state = gr.State([])

+     message.submit(chat_fn, [message, state], [chatbot, state])
+     clear.click(lambda: [], None, chatbot)
+     clear.click(lambda: [], None, state)

+ # Run with queue for multiple users
+ demo.queue(concurrency_count=2).launch()
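
Two notes on the new version, each with a minimal sketch rather than a drop-in fix.

First, `queue(concurrency_count=...)` is the Gradio 3.x signature; Gradio 4.x removed that argument in favor of `default_concurrency_limit`. Assuming the Space's requirements.txt does not pin gradio, a version guard keeps the launch line working under either API:

import gradio as gr

# Sketch: branch on the installed Gradio major version, since 4.x renamed
# the queue concurrency argument. Pinning gradio in requirements.txt is
# the simpler long-term fix.
if int(gr.__version__.split(".")[0]) >= 4:
    demo.queue(default_concurrency_limit=2).launch()
else:
    demo.queue(concurrency_count=2).launch()

Second, phi-2 is a base model, so after answering it often keeps writing the next "User:" turn, and `response.split("AI:")[-1]` can then return text from one of those invented turns. A sketch that instead decodes only the newly generated tokens and truncates at the first invented turn (variable names follow the committed chat_fn):

# Sketch: keep only the tokens generated after the prompt, then cut the
# reply at the first point where the model starts writing the user's
# next turn.
prompt_len = inputs["input_ids"].shape[1]
response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
response = response.split("User:")[0].strip()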