rishu834763 committed on
Commit
6729932
·
verified ·
1 Parent(s): 94da46e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -50
app.py CHANGED
@@ -1,68 +1,105 @@
 
1
  import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from peft import PeftModel
4
  import gradio as gr
5
 
6
- # Direct base model (no auto-detection needed)
7
- BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
8
- PEFT_ID = "rishu834763/java-explainer-lora"
 
 
9
 
10
- # Load in 8-bit β†’ super stable & fast cold start on free tier
11
- model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
 
 
 
 
12
  BASE_MODEL,
13
- device_map="auto",
14
- load_in_8bit=True,
15
  torch_dtype=torch.bfloat16,
 
16
  )
17
 
18
- # Load your LoRA on top (no merge = instant)
19
- model = PeftModel.from_pretrained(model, PEFT_ID)
20
 
21
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
22
  tokenizer.pad_token = tokenizer.eos_token
23
 
24
- def chat(message, history):
25
- messages = []
26
- for user_msg, bot_msg in history:
27
- messages.append({"role": "user", "content": user_msg})
28
- if bot_msg:
29
- messages.append({"role": "assistant", "content": bot_msg})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  messages.append({"role": "user", "content": message})
31
 
32
- input_ids = tokenizer.apply_chat_template(
33
  messages,
 
34
  add_generation_prompt=True,
35
- return_tensors="pt"
36
- ).to(model.device)
37
-
38
- terminators = [
39
- tokenizer.eos_token_id,
40
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
41
- ]
42
-
43
- output_ids = model.generate(
44
- input_ids,
45
- max_new_tokens=1024,
46
- do_sample=True,
47
- temperature=0.6,
48
- top_p=0.9,
49
- eos_token_id=terminators,
50
- pad_token_id=tokenizer.eos_token_id,
51
  )
52
 
53
- response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
54
- return response
55
-
56
- # Minimal interface – starts instantly
57
- gr.ChatInterface(
58
- chat,
59
- title="Java Explainer – Your Model",
60
- examples=[
61
- "Explain this Java code simply:\npublic class Test {\n public static void main(String[] args) {\n System.out.println(\"Hello\");\n }\n}",
62
- "What is the difference between HashMap and Hashtable?",
63
- "Why main method is public static void?",
64
- ],
65
- cache_examples=False,
66
- submit_btn="Send",
67
- retry_btn=None, # removes the retry that sometimes causes blank replies
68
- ).queue(max_size=20).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel
import gradio as gr
6
 
7
# ===================================
# 1. Model & LoRA (your exact repo)
# ===================================
BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # do NOT change
LORA_ADAPTER = "rishu834763/java-explainer-lora"  # ← your LoRA

# NF4 double-quantized 4-bit config so the 8B base fits on a single
# consumer/datacenter GPU (A100, 4090; T4 with some CPU offloading).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print("Loading base model (Llama-3-8B-Instruct 4-bit)...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # auto-offload to CPU if needed
    trust_remote_code=True,
)

print("Loading your LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)

# Tokenizer from the same base repo; Llama-3 ships no pad token, so reuse EOS.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
35
 
36
# ===================================
# 2. Inference pipeline
# ===================================
# BUG FIX: `pipeline` is a transformers factory, not a torch attribute —
# `torch.pipeline(...)` raised AttributeError the moment the app started.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1024,
    temperature=0.3,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=1.15,
    return_full_text=False,  # only the newly generated text, not the prompt
)

# System prompt tuned for Java explanations
SYSTEM_PROMPT = "You are an expert Java teacher. Explain concepts clearly, provide code examples, and answer in a concise but complete way."
53
+
54
def chat(message: str, history):
    """Generate an assistant reply for *message* given the Gradio chat history.

    Args:
        message: the user's new question.
        history: Gradio tuple-format history, a list of
            ``(user_text, assistant_text)`` pairs.

    Returns:
        ``("", updated_history)`` — BUG FIX: the click/submit handlers are
        registered with ``outputs=[msg, chatbot]`` (two components), but the
        old code returned only the reply string, which made Gradio fail with
        a "returned too few output values" error. Returning the pair clears
        the textbox and appends the new turn to the chatbot in one step.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Convert Gradio history → Llama-3 chat-template format.
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})

    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # return_full_text=False on the pipeline means this is only the reply.
    output = pipe(prompt)[0]["generated_text"]
    return "", history + [(message, output)]
73
+
74
# ===================================
# 3. Modern Gradio UI (2025)
# ===================================
# NOTE: UI strings below repair mojibake emoji from a bad UTF-8 round-trip.
with gr.Blocks(theme=gr.themes.Soft(), title="Java Explainer (Llama-3-8B + Your LoRA)") as demo:
    gr.Markdown("# 🧑‍💻 Java Explainer\nPowered by **rishu834763/java-explainer-lora** on Llama-3-8B-Instruct")

    chatbot = gr.Chatbot(height=620)
    msg = gr.Textbox(
        placeholder="Ask anything about Java (e.g. 'Explain Spring Boot @Autowired with example')",
        label="Your question",
        container=False,
    )

    with gr.Row():
        send = gr.Button("Send 🚀", variant="primary")
        clear = gr.Button("Clear 🗑️")

    with gr.Row():
        retry = gr.Button("🔄 Retry")
        undo = gr.Button("↶ Undo")

    # Events
    send.click(chat, [msg, chatbot], [msg, chatbot]).then(lambda: "", outputs=msg)
    msg.submit(chat, [msg, chatbot], [msg, chatbot]).then(lambda: "", outputs=msg)
    clear.click(lambda: None, None, chatbot, queue=False)
    # BUG FIX: Retry ran the exact same lambda as Undo (drop last turn), so
    # it never retried anything. Now it drops the last turn AND restores
    # that turn's user message into the textbox, ready to resend.
    retry.click(
        lambda h: (h[-1][0], h[:-1]) if h else ("", h),
        chatbot,
        [msg, chatbot],
        queue=False,
    )
    undo.click(lambda h: h[:-1], chatbot, chatbot, queue=False)

demo.queue(max_size=64).launch(
    server_name="0.0.0.0",
    server_port=7860,
)