rishu834763 committed on
Commit 6a9665a · verified · 1 Parent(s): a7ff14b

Update app.py

Files changed (1)
  1. app.py +37 -35
app.py CHANGED
@@ -1,60 +1,62 @@
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from peft import PeftModel, PeftConfig
  import gradio as gr

  PEFT_ID = "rishu834763/java-explainer-lora"

- # Get base model name
- config = PeftConfig.from_pretrained(PEFT_ID)
- base = config.base_model_name_or_path

- # Load base model in 4-bit
  model = AutoModelForCausalLM.from_pretrained(
-     base,
      device_map="auto",
      torch_dtype=torch.bfloat16,
-     load_in_4bit=True,
  )

- # Load LoRA weights on top BUT DO NOT MERGE (this is the trick!)
  model = PeftModel.from_pretrained(model, PEFT_ID)

- # Tokenizer
- tokenizer = AutoTokenizer.from_pretrained(base)
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

- # Pipeline
- pipe = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_new_tokens=1024,
-     temperature=0.6,
-     do_sample=True,
-     top_p=0.9,
- )
-
  def chat(message, history):
      messages = []
-     for user_msg, assistant_msg in history:
-         messages.append({"role": "user", "content": user_msg})
-         if assistant_msg:
-             messages.append({"role": "assistant", "content": assistant_msg})
      messages.append({"role": "user", "content": message})

-     output = pipe(messages)[0]["generated_text"]
-     return output[-1]["content"]
-
  gr.ChatInterface(
      chat,
-     title="Java Explainer Your Model (Running!)",
-     description="100% your fine-tuned LoRA · No OpenAI · Instant start",
      examples=[
-         "Explain this Java code: public static void main(String[] args) { System.out.println(\"Hello\"); }",
-         "What does public static void main mean?",
-         "Difference between String and StringBuilder?",
      ],
      cache_examples=False,
- ).queue().launch()
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
  import gradio as gr

+ # Your LoRA
  PEFT_ID = "rishu834763/java-explainer-lora"

+ # Load base model in 8-bit instead of 4-bit (much faster & more stable cold start on free tier)
+ base_model = "mistralai/Mistral-7B-Instruct-v0.2"

  model = AutoModelForCausalLM.from_pretrained(
+     base_model,
      device_map="auto",
+     load_in_8bit=True,  # ← 8-bit instead of 4-bit = instant start
      torch_dtype=torch.bfloat16,
  )

+ # Apply your LoRA (no merge = super fast)
  model = PeftModel.from_pretrained(model, PEFT_ID)

+ tokenizer = AutoTokenizer.from_pretrained(base_model)
  if tokenizer.pad_token is None:
      tokenizer.pad_token = tokenizer.eos_token

+ # Proper generation function
  def chat(message, history):
      messages = []
+     for user, assistant in history:
+         messages.append({"role": "user", "content": user})
+         if assistant:
+             messages.append({"role": "assistant", "content": assistant})
      messages.append({"role": "user", "content": message})

+     input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
+
+     output_ids = model.generate(
+         input_ids,
+         max_new_tokens=1024,
+         temperature=0.6,
+         top_p=0.9,
+         do_sample=True,
+         repetition_penalty=1.1,
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+
+     response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
+     return response
+
+ # Interface
  gr.ChatInterface(
      chat,
+     title="Java Explainer Live Now",
+     description="Your own fine-tuned model · Starts in seconds · No OpenAI",
      examples=[
+         "Explain this Java code simply: public static void main(String[] args) { System.out.println(\"Hello\"); }",
+         "What is the difference between == and .equals()?",
+         "Why do we need the 'static' keyword in main()?",
      ],
      cache_examples=False,
+ ).queue(max_size=30).launch()
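
A few notes on the new version, with hedged sketches rather than drop-in replacements. First, recent transformers releases deprecate passing load_in_8bit=True directly to from_pretrained and expect the setting in a BitsAndBytesConfig instead; a minimal sketch of the equivalent load (version-dependent, same base model and kwargs as the commit):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 8-bit load as the commit, written against the newer quantization API.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    device_map="auto",
    quantization_config=bnb_config,  # replaces the bare load_in_8bit=True kwarg
    torch_dtype=torch.bfloat16,      # dtype for the modules bitsandbytes leaves unquantized
)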
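Second, the chat() loop unpacks history as (user, assistant) pairs, which matches Gradio's older tuple-style chat history; on Gradio versions where ChatInterface runs in "messages" mode (type="messages"), history already arrives as role/content dicts and the tuple unpacking would break. A sketch under that assumption, reusing the commit's tokenizer and model:

def chat(message, history):
    # In messages mode, history is already [{"role": ..., "content": ...}, ...].
    messages = list(history)
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
    # ...generate and decode exactly as in the committed chat() above.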
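Finally, the "no merge = super fast" comment reflects a real trade-off: keeping the adapter unmerged adds a small LoRA overhead on every forward pass but avoids rewriting the base weights at load time, and merging into bitsandbytes-quantized weights has historically been limited in peft anyway. For contrast, a sketch of the merged variant, assuming a full-precision base (hypothetical setup, not part of this commit):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    device_map="auto",
    torch_dtype=torch.bfloat16,  # no 8-bit quantization here, so merging is possible
)
model = PeftModel.from_pretrained(base, "rishu834763/java-explainer-lora")
model = model.merge_and_unload()  # folds the LoRA deltas into the base weights

Once merged, generate() runs at plain base-model speed with no adapter wrapper in the call path, at the cost of the slower full-precision load the commit is trying to avoid.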