Asit03 commited on
Commit
3ca2a07
·
verified ·
1 Parent(s): 0356e9f

Update app.py

Browse files
Files changed (1)
  1. app.py +19 -9
app.py CHANGED
# Legacy (pre-quantization) version of the app: loads the full-precision model
# and serves it through a minimal Gradio text-to-text interface.

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Model repository on the Hugging Face Hub.
model_name = "Asit03/AI_Agent_V2_Merged"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def generate_response(prompt):
    """Generate a continuation of *prompt* and return the full decoded text.

    The returned string includes the prompt itself, since the whole output
    sequence is decoded (special tokens stripped).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # pad_token_id silences the "no pad token set" warning for decoder-only
    # models (e.g. Llama-family) that define only an EOS token.
    outputs = model.generate(
        **inputs,
        max_new_tokens=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


gr.Interface(fn=generate_response, inputs="text", outputs="text").launch()
 
# if __name__ == "__main__":
#     demo.launch()

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr

# Model repository on the Hugging Face Hub.
model_id = "Asit03/AI_Agent_V2_Merged"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load model with 4-bit quantization.
# NOTE: passing load_in_4bit=True directly to from_pretrained is deprecated
# (and rejected by recent transformers releases); BitsAndBytesConfig is the
# supported path and also lets us pin the compute dtype used inside the
# quantized matmuls, which the bare flag left at its default.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # use torch.float16 on GPUs without bf16
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,  # fallback to torch.float16 if needed
    trust_remote_code=True,
)


# Generation function
def chat(prompt):
    """Generate up to 200 new tokens continuing *prompt*.

    Returns the full decoded sequence (prompt included), with special
    tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # pad_token_id avoids the "no pad token set" warning for decoder-only models.
    outputs = model.generate(**inputs, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Launch Gradio app
gr.Interface(fn=chat, inputs="text", outputs="text", title="💬 AI Agent V2").launch()