lingadevaruhp committed
Commit 8d219ad · verified · 1 Parent(s): 2c03519

Remove LoRA dependencies, use base Gemma-2-9B model

Files changed (1):
app.py (+28 -18)
app.py CHANGED
@@ -1,38 +1,48 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from peft import PeftModel
 import torch
 import gradio as gr
 
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
 
-# Load base model on CPU
-base_model = AutoModelForCausalLM.from_pretrained(
+# Load base model directly (no LoRA adapters)
+model = AutoModelForCausalLM.from_pretrained(
     "google/gemma-2-9b-it",
     torch_dtype=torch.bfloat16,
-    device_map="cpu",
+    device_map="auto",
     low_cpu_mem_usage=True
 )
 
-# Load LoRA adapters (replace with your repo once pushed)
-model = PeftModel.from_pretrained(
-    base_model,
-    "lingadevaruhp/flirt-ai-gemma2-9b",  # Update after pushing
-    device_map="cpu"
-)
-
 def generate_response(prompt, max_new_tokens=50):
-    inputs = tokenizer(prompt, return_tensors="pt")
-    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # Format the prompt with Gemma's chat turn markers
+    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)  # keep inputs on the model's device
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    # Decode only the generated part (excluding the input prompt)
+    generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+    return generated_text.strip()
 
 # Gradio interface
 iface = gr.Interface(
     fn=generate_response,
-    inputs=["text", gr.Slider(minimum=10, maximum=200, value=50, label="Max New Tokens")],
-    outputs="text",
-    title="Flirt-AI Gemma2-9B",
-    description="Chat with a flirty AI powered by Gemma-2-9B!"
+    inputs=[
+        gr.Textbox(label="Your message", placeholder="Type your message here..."),
+        gr.Slider(minimum=10, maximum=200, value=50, label="Max New Tokens")
+    ],
+    outputs=gr.Textbox(label="AI Response"),
+    title="Flirt-AI Gemma2-9B (Base Model)",
+    description="Chat with AI powered by the base Gemma-2-9B model!"
 )
 
 if __name__ == "__main__":
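
Note: the new generate_response hand-builds Gemma's <start_of_turn> markers. The commit does not use it, but the google/gemma-2-9b-it tokenizer also ships a chat template that produces the same format. A sketch of the same function written against apply_chat_template (everything else unchanged):

    def generate_response(prompt, max_new_tokens=50):
        # Build the "<start_of_turn>user ... <start_of_turn>model" prompt via
        # the tokenizer's built-in chat template instead of a hand-written f-string.
        messages = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,  # appends the "<start_of_turn>model\n" cue
            return_tensors="pt",
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens.
        return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()

For single-turn prompts the two versions are equivalent; the template version additionally handles multi-turn histories by appending more entries to messages. A quick smoke test, with a hypothetical prompt, that bypasses the Gradio UI: print(generate_response("Say something charming.", max_new_tokens=40)).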
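
If the fine-tuned adapter is published later, it can be layered back onto the base model without reverting this commit. A minimal sketch, assuming the adapter lands at the repo id the removed code pointed to and that model is the base model loaded in app.py:

    from peft import PeftModel

    # Re-attach the LoRA adapter on top of the already-loaded base model.
    # Hypothetical until the adapter repo actually exists; the id below is
    # taken from the code this commit removed.
    model = PeftModel.from_pretrained(model, "lingadevaruhp/flirt-ai-gemma2-9b")

generate_response then picks up the adapted weights unchanged, since it closes over model.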