rishu834763 committed on
Commit
cd50342
·
verified ·
1 Parent(s): dc1ec25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -24
app.py CHANGED
@@ -1,39 +1,85 @@
1
- # app.py
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline # ← pipeline is here!
4
  from peft import PeftModel
5
  import gradio as gr
6
 
7
- # ===================================
8
- BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2" # Open, no gate!
9
  LORA_ADAPTER = "rishu834763/java-explainer-lora"
10
 
 
11
  quantization_config = BitsAndBytesConfig(
12
  load_in_4bit=True,
13
  bnb_4bit_quant_type="nf4",
14
  bnb_4bit_compute_dtype=torch.bfloat16,
15
  bnb_4bit_use_double_quant=True,
 
16
  )
17
 
18
- print("Loading Llama-3.1-8B-Instruct 4-bit + your LoRA...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  base_model = AutoModelForCausalLM.from_pretrained(
20
  BASE_MODEL,
21
  quantization_config=quantization_config,
22
- device_map="auto",
23
  torch_dtype=torch.bfloat16,
 
24
  trust_remote_code=True,
25
  )
26
 
 
27
  model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
28
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
 
29
  tokenizer.pad_token = tokenizer.eos_token
30
 
31
- # FIXED: pipeline from transformers, not torch
32
  pipe = pipeline(
33
  "text-generation",
34
  model=model,
35
  tokenizer=tokenizer,
36
- max_new_tokens=1024,
37
  temperature=0.3,
38
  top_p=0.95,
39
  do_sample=True,
@@ -41,32 +87,45 @@ pipe = pipeline(
41
  return_full_text=False,
42
  )
43
 
44
- SYSTEM_PROMPT = "You are an expert Java teacher. Explain concepts clearly with code examples."
45
 
46
  def chat(message: str, history):
47
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
48
- for user, assistant in history:
49
- messages.append({"role": "user", "content": user})
50
- if assistant:
51
- messages.append({"role": "assistant", "content": assistant})
 
 
52
  messages.append({"role": "user", "content": message})
53
 
54
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
55
- outputs = pipe(prompt)
56
- return outputs[0]["generated_text"]
 
 
 
 
 
 
57
 
58
- # ===================================
59
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
60
- gr.Markdown("# Java Explainer\nPowered by your LoRA on Llama-3.1-8B-Instruct (4-bit)")
61
- chatbot = gr.Chatbot(height=620)
62
- msg = gr.Textbox(placeholder="Ask anything about Java...", container=False)
 
63
 
64
  with gr.Row():
65
- send = gr.Button("Send 🚀", variant="primary")
66
- clear = gr.Button("Clear")
67
 
68
  send.click(chat, [msg, chatbot], [msg, chatbot]).then(lambda: "", outputs=msg)
69
  msg.submit(chat, [msg, chatbot], [msg, chatbot]).then(lambda: "", outputs=msg)
70
  clear.click(lambda: None, None, chatbot, queue=False)
71
 
72
- demo.queue(max_size=64).launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
# app.py — low-VRAM serving of Mistral-7B-Instruct-v0.2 + Java-explainer LoRA
# (November 2025, T4-compatible).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from peft import PeftModel
import gradio as gr

# Exact base model the LoRA adapter was trained against.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
LORA_ADAPTER = "rishu834763/java-explainer-lora"

# T4 GPUs (compute capability 7.5) do not support bfloat16 — on that hardware
# bnb 4-bit compute must run in float16.  Pick bf16 only when it is actually
# available so the "T4-compatible" claim holds.
COMPUTE_DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

# Enhanced 4-bit config with CPU offload enabled.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,  # key fix: allow fp32 CPU offload for layers that don't fit
)

print("Loading Mistral-7B-Instruct-v0.2 (4-bit with CPU offload) + your Java LoRA...")

# Custom device map: embeddings, the first 8 transformer layers, the final norm
# and the LM head stay on GPU 0; the remaining layers are offloaded to CPU.
# Mistral-7B has 32 decoder layers — the loop replaces a 35-entry literal dict
# that was easy to get wrong.
device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
for layer_idx in range(32):
    device_map[f"model.layers.{layer_idx}"] = 0 if layer_idx < 8 else "cpu"

# Load the 4-bit base model with the mixed GPU/CPU placement.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=quantization_config,
    device_map=device_map,          # GPU first, CPU fallback
    torch_dtype=COMPUTE_DTYPE,
    low_cpu_mem_usage=True,         # reduces RAM spike while loading
    trust_remote_code=True,
)

# Apply the LoRA adapter (lightweight; negligible memory overhead).
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token  # Mistral ships no pad token

# Text-generation pipeline tuned for the mixed-device placement.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,  # reduced for speed on low VRAM
    temperature=0.3,
    top_p=0.95,
    do_sample=True,
    return_full_text=False,
)

SYSTEM_PROMPT = "You are an expert Java teacher with 15+ years of experience. Always explain concepts clearly, include clean code examples, and use best practices."
92
def chat(message: str, history):
    """Generate a reply and return the updated Gradio chat state.

    Parameters
    ----------
    message : str
        The user's new question.
    history : list[tuple[str, str]]
        Prior (user, assistant) pairs from the gr.Chatbot component.

    Returns
    -------
    tuple[str, list]
        ("", updated_history) — one value for each of the two outputs the UI
        wires this handler to ([msg, chatbot]).  The original version returned
        a single string, which mismatches the two declared outputs, so Gradio
        raised a value-count error and the chatbot never displayed the answer.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Replay the previous turns so the model sees the full conversation.
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # the last pair may not have an assistant reply yet
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    reply = pipe(prompt)[0]["generated_text"]
    # Clear the textbox and append the new exchange to the chatbot history.
    return "", history + [(message, reply)]
106
# ---- Gradio UI (layout unchanged) ---------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Java Explainer Pro") as demo:
    gr.Markdown("# Java Explainer Pro\nFine-tuned on **rishu834763/java-explainer-lora** + **Mistral-7B-v0.2** (Low-VRAM Optimized)")
    gr.Markdown("Ask anything about Java — from basics to Spring Boot, concurrency, JVM internals, and more!")

    chatbot = gr.Chatbot(height=600)
    msg = gr.Textbox(
        placeholder="e.g. Explain CompletableFuture with a real-world example",
        label="Your Java Question",
        container=False,
    )

    with gr.Row():
        send_btn = gr.Button("Send", variant="primary", scale=2)
        clear_btn = gr.Button("Clear Chat")

    def _reset_textbox():
        # Empty the input box once a message has been handled.
        return ""

    # The Send button and pressing Enter both route through the same handler.
    for trigger in (send_btn.click, msg.submit):
        trigger(chat, [msg, chatbot], [msg, chatbot]).then(_reset_textbox, outputs=msg)

    clear_btn.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=50).launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,
)