pareshmishra commited on
Commit
47d01e6
·
verified ·
1 Parent(s): 880dce5

Create app_1.py

Browse files

file with transformer

Files changed (1) hide show
  1. app_1.py +83 -0
app_1.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from transformers import BitsAndBytesConfig
5
+ import torch
6
+
7
+ # ✅ Load the model and tokenizer
8
+ MODEL_ID = "pareshmishra/mt564-gemma-lora"
9
+ API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
10
+ if not API_TOKEN:
11
+ raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")
12
+
13
+ # Configure 4-bit quantization
14
+ quantization_config = BitsAndBytesConfig(
15
+ load_in_4bit=True, # Enable 4-bit quantization
16
+ bnb_4bit_compute_dtype=torch.float16, # Use fp16 for computation
17
+ bnb_4bit_quant_type="nf4", # Normal Float 4-bit quantization
18
+ bnb_4bit_use_double_quant=True # Nested quantization for better efficiency
19
+ )
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=API_TOKEN)
22
+ model = AutoModelForCausalLM.from_pretrained(
23
+ MODEL_ID,
24
+ token=API_TOKEN,
25
+ torch_dtype=torch.float16, # fp16 as per model card
26
+ device_map="auto", # Auto-map to GPU/CPU
27
+ quantization_config=quantization_config # Use BitsAndBytesConfig
28
+ )
29
+
30
+ def respond(messages, chatbot_history, system_message, max_tokens, temperature, top_p):
31
+ try:
32
+ # Build prompt from history
33
+ prompt = f"{system_message.strip()}\n\n"
34
+ for msg in messages:
35
+ if isinstance(msg, dict):
36
+ role = msg.get("role")
37
+ content = msg.get("content", "")
38
+ if role == "user":
39
+ prompt += f"User: {content.strip()}\n"
40
+ elif role == "assistant":
41
+ prompt += f"Assistant: {content.strip()}\n"
42
+ prompt += "Assistant:"
43
+
44
+ # Tokenize and generate
45
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
+ outputs = model.generate(
47
+ **inputs,
48
+ max_new_tokens=max_tokens,
49
+ temperature=temperature,
50
+ top_p=top_p,
51
+ do_sample=True,
52
+ pad_token_id=tokenizer.eos_token_id
53
+ )
54
+
55
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
+ response = response[len(prompt):].strip()
57
+
58
+ yield response if response else "⚠️ No response returned from the model."
59
+
60
+ except Exception as e:
61
+ yield f"❌ Error: {str(e)}\nDetails: {e.__class__.__name__}"
62
+
63
+ # Gradio Interface
64
+ demo = gr.ChatInterface(
65
+ fn=respond,
66
+ type="messages",
67
+ additional_inputs=[
68
+ gr.Textbox(
69
+ lines=3,
70
+ label="System message",
71
+ value="You are an expert in SWIFT MT564 financial messaging. Analyze, validate, and answer related user questions.",
72
+ ),
73
+ gr.Slider(50, 2048, value=512, step=1, label="Max new tokens"),
74
+ gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
75
+ gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p sampling"),
76
+ ],
77
+ title="💬 MT564 Chat Assistant",
78
+ description="Analyze SWIFT MT564 messages or ask financial-related questions.",
79
+ theme="default"
80
+ )
81
+
82
+ if __name__ == "__main__":
83
+ demo.launch()