lingadevaruhp committed on
Commit
2e861b0
·
verified ·
1 Parent(s): e3669ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -18
app.py CHANGED
@@ -1,25 +1,39 @@
1
- import gradio as gr
2
- from unsloth import FastLanguageModel
3
  import torch
 
4
 
5
- model, tokenizer = FastLanguageModel.from_pretrained(
6
- "lingadevaruhp/flirt-ai-gemma2-2b",
7
- max_seq_length=2048,
8
- dtype=torch.float16,
9
- load_in_4bit=True,
10
- device_map="auto"
 
 
 
11
  )
12
- FastLanguageModel.for_inference(model)
13
 
14
- def chat(prompt):
15
- inputs = tokenizer(f"<s>### Instruction:\n{prompt}\n### Response:\n", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
16
- outputs = model.generate(**inputs, max_new_tokens=100)
 
 
 
 
 
 
 
17
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
18
 
19
- gr.Interface(
20
- fn=chat,
21
- inputs="text",
 
22
  outputs="text",
23
- title="Flirt.AI: Kannada Flirty Chatbot",
24
- description="Chat with a flirty AI using Kannada slang and English! 😎"
25
- ).launch()
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from peft import PeftModel
3
  import torch
4
+ import gradio as gr
5
 
6
+ # Load tokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
8
+
9
+ # Load base model on CPU
10
+ base_model = AutoModelForCausalLM.from_pretrained(
11
+ "google/gemma-2-9b-it",
12
+ torch_dtype=torch.bfloat16,
13
+ device_map="cpu",
14
+ low_cpu_mem_usage=True
15
  )
 
16
 
17
+ # Load LoRA adapters (replace with your repo once pushed)
18
+ model = PeftModel.from_pretrained(
19
+ base_model,
20
+ "lingadevaruhp/flirt-ai-gemma2-9b", # Update after pushing
21
+ device_map="cpu"
22
+ )
def generate_response(prompt, max_new_tokens=50):
    """Generate a completion for *prompt* with the LoRA-adapted model.

    Parameters
    ----------
    prompt : str
        Raw text passed straight to the tokenizer (no chat template
        is applied here).
    max_new_tokens : int | float, default 50
        Upper bound on generated tokens. The Gradio slider wired to this
        function delivers a float, so the value is coerced to ``int``
        before being handed to ``generate``.

    Returns
    -------
    str
        The decoded output sequence (prompt prefix included), with
        special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph, which
    # matters on a CPU-hosted 9B-parameter model.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens))
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Gradio UI -----------------------------------------------------------
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        "text",
        # step=1 keeps the slider on whole-token values; without it the
        # component emits fractional counts.
        gr.Slider(minimum=10, maximum=200, value=50, step=1,
                  label="Max New Tokens"),
    ],
    outputs="text",
    title="Flirt-AI Gemma2-9B",
    description="Chat with a flirty AI powered by Gemma-2-9B!",
)

if __name__ == "__main__":
    iface.launch()