Spaces:
Runtime error
Runtime error
import logging
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
# ✅ Load the model and tokenizer
MODEL_ID = "pareshmishra/mt564-gemma-lora"
API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")

# Configure 4-bit quantization.
# NOTE(review): bitsandbytes 4-bit loading has historically required a CUDA
# GPU; on CPU-only hardware the unconditional request makes the import of this
# module fail. The config is therefore only built when CUDA is available and
# is None otherwise (the model is then loaded unquantized).
if torch.cuda.is_available():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,  # Enable 4-bit quantization
        bnb_4bit_compute_dtype=torch.float16,  # Use fp16 for computation
        bnb_4bit_quant_type="nf4",  # Normal Float 4-bit quantization
        bnb_4bit_use_double_quant=True,  # Nested quantization for efficiency
    )
else:
    quantization_config = None

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=API_TOKEN)

_model_kwargs = {
    "token": API_TOKEN,
    "device_map": "auto",  # Auto-map to GPU/CPU
}
if quantization_config is not None:
    _model_kwargs["torch_dtype"] = torch.float16  # fp16 as per model card
    _model_kwargs["quantization_config"] = quantization_config
else:
    # Half precision is poorly supported for CPU inference, so fall back to
    # full precision when no GPU is present.
    _model_kwargs["torch_dtype"] = torch.float32

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_kwargs)
def _build_prompt(messages, chatbot_history, system_message):
    """Assemble the plain-text prompt sent to the model.

    Args:
        messages: Either the latest user message (a string, or a single
            ``{"role", "content"}`` dict), or a full list of message dicts.
        chatbot_history: Earlier turns as ``{"role", "content"}`` dicts.
            Ignored when ``messages`` is itself a list of turns.
        system_message: Instruction text placed at the top of the prompt.

    Returns:
        The system message, one ``User:`` / ``Assistant:`` line per turn, and
        a trailing ``Assistant:`` cue for the model to continue from.
    """
    if isinstance(messages, (list, tuple)):
        # Full conversation supplied directly as the first argument.
        turns = list(messages)
    else:
        # Latest message plus separate history: prior turns first, then the
        # new user message so it is actually part of the prompt.
        turns = list(chatbot_history or [])
        if isinstance(messages, dict):
            turns.append(messages)
        elif messages:
            turns.append({"role": "user", "content": str(messages)})

    speaker_labels = {"user": "User", "assistant": "Assistant"}
    lines = []
    for turn in turns:
        if not isinstance(turn, dict):
            continue
        label = speaker_labels.get(turn.get("role"))
        if label is None:
            continue  # Skip roles the prompt format has no label for.
        content = turn.get("content") or ""
        lines.append(f"{label}: {str(content).strip()}\n")

    header = f"{(system_message or '').strip()}\n\n"
    return header + "".join(lines) + "Assistant:"


def respond(messages, chatbot_history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio chat interface.

    ``gr.ChatInterface`` invokes its ``fn`` as
    ``fn(message, history, *additional_inputs)``, so ``messages`` normally
    holds the latest user message and ``chatbot_history`` the earlier turns.

    Args:
        messages: Latest user message (or a full list of message dicts).
        chatbot_history: Previous conversation turns as message dicts.
        system_message: System instruction prepended to the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The generated reply, a warning string when the model produced no
        text, or an error string when generation raised an exception.
    """
    try:
        # Build prompt from history plus the current user message
        prompt = _build_prompt(messages, chatbot_history, system_message)

        # Tokenize and generate
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # UI sliders may deliver floats
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Drop the prompt by token count: the decoded text is not guaranteed
        # to start with the exact prompt string once special tokens are
        # stripped, so slicing by len(prompt) can cut in the wrong place.
        generated_ids = outputs[0][prompt_token_count:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        yield response if response else "⚠️ No response returned from the model."
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the server log and
        # show a short message to the user instead of crashing the app.
        logging.getLogger(__name__).exception("Text generation failed")
        yield f"❌ Error: {str(e)}\nDetails: {e.__class__.__name__}"
# Gradio Interface
# The extra controls are declared first and handed to the chat interface in
# the order system message, max tokens, temperature, top-p.
_system_message_box = gr.Textbox(
    value="You are an expert in SWIFT MT564 financial messaging. Analyze, validate, and answer related user questions.",
    label="System message",
    lines=3,
)
_max_tokens_slider = gr.Slider(
    minimum=50, maximum=2048, value=512, step=1, label="Max new tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p sampling"
)

demo = gr.ChatInterface(
    fn=respond,
    type="messages",
    title="💬 MT564 Chat Assistant",
    description="Analyze SWIFT MT564 messages or ask financial-related questions.",
    theme="default",
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()