rishu834763 committed
Commit 00d2932 · verified · 1 parent: e30f5ff

Update app.py

Files changed (1)
  1. app.py +42 -16
app.py CHANGED
@@ -3,44 +3,70 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from peft import PeftModel, PeftConfig
 import gradio as gr
 
-# Your LoRA
 PEFT_ID = "rishu834763/java-explainer-lora"
 
+# Load config to know the base model
 config = PeftConfig.from_pretrained(PEFT_ID)
 base = config.base_model_name_or_path
 
+# Load model (4-bit for free tier)
 model = AutoModelForCausalLM.from_pretrained(
     base,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    load_in_4bit=True
+    load_in_4bit=True,
 )
+
+# Apply your LoRA and merge
 model = PeftModel.from_pretrained(model, PEFT_ID)
 model = model.merge_and_unload()
 
+# Tokenizer
 tokenizer = AutoTokenizer.from_pretrained(base)
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
 
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, temperature=0.6, do_sample=True)
+# Pipeline
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=1024,
+    temperature=0.6,
+    do_sample=True,
+    top_p=0.9,
+    repetition_penalty=1.1,
+)
 
+# ========= FIXED CHAT FUNCTION =========
 def chat(message, history):
-    msgs = []
-    for u, a in history:
-        msgs.append({"role": "user", "content": u})
-        if a: msgs.append({"role": "assistant", "content": a})
-    msgs.append({"role": "user", "content": message})
-
-    out = pipe(msgs)
-    return out[0]["generated_text"][-1]["content"]
+    messages = []
+
+    # Rebuild proper alternating messages, skipping empty assistant replies
+    for user_msg, assistant_msg in history:
+        messages.append({"role": "user", "content": user_msg})
+        if assistant_msg:  # ← only add assistant if it's not empty/None
+            messages.append({"role": "assistant", "content": assistant_msg})
+
+    # Add the new user message
+    messages.append({"role": "user", "content": message})
+
+    # Generate
+    output = pipe(messages)[0]["generated_text"]
+
+    # Extract only the last assistant reply
+    return output[-1]["content"]
 
+# ========= GRADIO INTERFACE =========
 gr.ChatInterface(
     chat,
-    title="Java Explainer – Your Own Model (No OpenAI)",
-    description="This is 100% your fine-tuned Java LoRA running locally on Hugging Face",
+    title="Java Explainer – Your Own Fine-Tuned Model",
+    description="Powered 100% by your LoRA on Mistral-7B-Instruct-v0.2",
     examples=[
-        "Explain this Java code: public static void main(String[] args) { System.out.println(\"Hello\"); }",
-        "What does @Override do in Java?",
-        "Difference between HashMap and Hashtable?"
-    ]
+        "Explain this Java code in simple terms:\npublic class Hello {\n    public static void main(String[] args) {\n        System.out.println(\"Hello World!\");\n    }\n}",
+        "What is the difference between ArrayList and LinkedList?",
+        "Why do we use the synchronized keyword?",
+        "Convert this Python factorial function to Java",
+    ],
+    cache_examples=False,  # ← this was causing the caching error too
 ).queue().launch()
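
A note on the 4-bit load: recent transformers releases deprecate passing load_in_4bit straight into from_pretrained and expect an explicit BitsAndBytesConfig instead, and PEFT warns that merge_and_unload() on 4-bit weights is lossy because the adapter has to be merged into dequantized values. A minimal sketch of the quantization-config form, assuming bitsandbytes is installed on the Space:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Equivalent 4-bit load via an explicit quantization config;
# nf4 with bfloat16 compute is a common pairing for LoRA-tuned models.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    base,  # base model name resolved from the PEFT config above
    quantization_config=bnb_config,
    device_map="auto",
)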
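
The new return-value handling leans on the pipeline's chat mode: when pipe is called with a list of role/content dicts, text-generation applies the tokenizer's chat template and generated_text comes back as the full message list rather than a plain string, so the last element is the fresh assistant turn. A quick smoke test of that shape (hypothetical prompt, run after the pipeline is built):

# With chat-style input, generated_text is the whole conversation,
# so the last message holds the model's reply.
result = pipe([{"role": "user", "content": "What does @Override do in Java?"}])
print(result[0]["generated_text"][-1]["content"])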
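
One caveat on the rebuilt loop: for user_msg, assistant_msg in history assumes Gradio's legacy tuple-style history. Newer Gradio releases push ChatInterface toward type="messages", where history already arrives as role/content dicts and the tuple unpacking would raise a ValueError. A hedged variant of chat() that tolerates both shapes:

def chat(message, history):
    # Accept both history shapes: legacy (user, assistant) tuples and
    # the newer type="messages" list of role/content dicts.
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return pipe(messages)[0]["generated_text"][-1]["content"]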