| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import warnings |
# Silence every Python warning for the whole process. Note that this also
# hides deprecation notices raised by any of the imported libraries.
warnings.simplefilter("ignore")
|
|
class LlamaAddressCompletion:
    """Wrapper around the fine-tuned address model exposing three text tasks.

    ``__init__`` loads the tokenizer and model immediately (see ``load_model``).
    Each public task method wraps its instruction in a Llama 3 chat prompt,
    samples a continuation and returns only the newly generated text. Task
    methods never raise for generation failures: they return an ``"Error ..."``
    string instead so the UI can display it.
    """

    def __init__(self):
        self.model_name = "shiprocket-ai/open-llama-1b-address-completion"
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the Llama model and tokenizer.

        On CUDA the weights are loaded in float16 with ``device_map="auto"``;
        otherwise they are loaded in float32 and moved to the CPU device.

        Raises:
            Exception: any loading error is printed and then re-raised.
        """
        use_cuda = self.device.type == "cuda"
        try:
            print("Loading Llama 3.2-1B Address Completion model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # NOTE(review): trust_remote_code=True executes Python code shipped
            # with the model repository; keep it only if the repo requires it.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )

            # Without device_map the weights stay where they were loaded, so
            # place them on the CPU device explicitly.
            if not use_cuda:
                self.model = self.model.to(self.device)

            self.model.eval()
            print("✅ Model loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    @staticmethod
    def _build_prompt(instruction):
        """Wrap *instruction* in a single-turn chat prompt ending at the assistant header."""
        return (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )

    @staticmethod
    def _is_blank(text):
        """Return True for ``None``, empty, or whitespace-only input."""
        return not text or not text.strip()

    def _generate(self, instruction, max_new_tokens, temperature):
        """Sample a reply to *instruction* and return the decoded, stripped text.

        Args:
            instruction: user-turn text (task description plus the address).
            max_new_tokens: generation budget passed to ``model.generate``.
            temperature: sampling temperature passed to ``model.generate``.

        Raises:
            Whatever the tokenizer or model raise; the public task methods turn
            exceptions into user-facing error strings.
        """
        prompt = self._build_prompt(instruction)

        # The prompt already begins with the BOS marker. Tokenizers that
        # prepend BOS themselves (Llama 3 tokenizers typically do) would
        # otherwise yield a duplicated BOS token, so only let the tokenizer add
        # special tokens when the prompt does not already start with its BOS.
        bos_token = getattr(self.tokenizer, "bos_token", None)
        add_special_tokens = not (bos_token and prompt.startswith(bos_token))

        # NOTE(review): with the tokenizer's default right-side truncation, an
        # input longer than 512 tokens loses the assistant header. TODO consider
        # shortening the address itself instead.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            add_special_tokens=add_special_tokens,
        )

        # Move the inputs to the device holding the first model parameter,
        # which also works when device_map="auto" placed the weights.
        device = next(self.model.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.05,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][prompt_length:]
        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.strip()

    def extract_address_components(self, address, max_new_tokens=150):
        """Extract address components using the model.

        Args:
            address: complete address text; ``None`` or blank input is rejected.
            max_new_tokens: maximum number of tokens to generate.

        Returns:
            The model's reply, a prompt to enter an address, or an error string.
        """
        if self._is_blank(address):
            return "Please provide an address to extract components from."

        try:
            return self._generate(
                f"Extract address components from: {address}",
                max_new_tokens,
                temperature=0.1,
            )
        except Exception as e:
            return f"Error processing address: {e}"

    def complete_partial_address(self, partial_address, max_new_tokens=100):
        """Complete a partial address.

        Args:
            partial_address: address prefix; ``None`` or blank input is rejected.
            max_new_tokens: maximum number of tokens to generate.

        Returns:
            The model's reply, a prompt to enter an address, or an error string.
        """
        if self._is_blank(partial_address):
            return "Please provide a partial address to complete."

        try:
            return self._generate(
                f"Complete this partial address: {partial_address}",
                max_new_tokens,
                temperature=0.2,
            )
        except Exception as e:
            return f"Error completing address: {e}"

    def standardize_address(self, address, max_new_tokens=150):
        """Standardize an address format.

        Args:
            address: informal address text; ``None`` or blank input is rejected.
            max_new_tokens: maximum number of tokens to generate.

        Returns:
            The model's reply, a prompt to enter an address, or an error string.
        """
        if self._is_blank(address):
            return "Please provide an address to standardize."

        try:
            return self._generate(
                f"Standardize this address into proper format: {address}",
                max_new_tokens,
                temperature=0.1,
            )
        except Exception as e:
            return f"Error standardizing address: {e}"
|
|
| |
print("Initializing Llama Address Completion system...")
# Build the shared model wrapper once at import time. If loading fails the app
# still starts; the interface functions check for ``llama_system is None`` and
# report that the model is unavailable.
try:
    llama_system = LlamaAddressCompletion()
except Exception as init_error:
    print(f"Failed to initialize system: {init_error}")
    llama_system = None
else:
    print("System ready!")
|
|
def extract_components_interface(address_text):
    """Run component extraction and render input and result as Markdown.

    Returns a fixed error message when the model failed to load at startup.
    """
    if llama_system is not None:
        components = llama_system.extract_address_components(address_text)
        return (
            f"**Input:** {address_text}\n\n"
            f"**Extracted Components:**\n{components}"
        )
    return "❌ Model not loaded. Please check the logs."
|
|
def complete_address_interface(partial_address):
    """Run address completion and render input, result and a caveat as Markdown.

    Returns a fixed error message when the model failed to load at startup.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    completion = llama_system.complete_partial_address(partial_address)
    sections = [
        f"**Partial Address:** {partial_address}",
        f"**Completed Address:**\n{completion}",
        "*⚠️ Note: This feature has limited training data and results may vary in quality.*",
    ]
    return "\n\n".join(sections)
|
|
def standardize_address_interface(address_text):
    """Run address standardization and render input, result and a caveat as Markdown.

    Returns a fixed error message when the model failed to load at startup.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    caveat = "*⚠️ Note: This feature has limited training data and results may vary in quality.*"
    standardized = llama_system.standardize_address(address_text)
    header = f"**Original:** {address_text}"
    body = f"**Standardized:**\n{standardized}"
    return header + "\n\n" + body + "\n\n" + caveat
|
|
| |
# Complete addresses offered as one-click examples for component extraction.
sample_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
    "Villa 141, Geown Oasis, V Kallahalli, Off Sarjapur, Bengaluru, Karnataka, 562125",
    "E401 Supertech Icon Indrapam 201301 UP",
    "Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
    "Flat 201, MG Road, Bangalore, Karnataka, 560001",
]

# Examples for address completion; each entry is a prefix of the entry at the
# same position in ``sample_addresses``.
partial_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex",
    "Villa 141, Geown Oasis, V Kallahalli",
    "E401 Supertech Icon",
    "Shop No 123, Sunshine Apartments",
    "Flat 201, MG Road, Bangalore",
]

# Lower-case, abbreviated variants used as examples for standardization.
informal_addresses = [
    "c704 gayatri shivam thakur complex kandivali e 400101",
    "villa141 geown oasis vkallahalli off sarjapur blr kar 562125",
    "e401 supertech icon indrapam up 201301",
    "shop123 sunshine apts andheri w mumbai 400058",
]
|
|
| |
def _build_task_tab(
    tab_label,
    banner,
    input_label,
    placeholder,
    button_label,
    samples_heading,
    samples,
    initial_output,
    handler,
):
    """Build one task tab: input box, run button, sample buttons and Markdown output.

    Must be called inside an active ``gr.Blocks`` context. Clicking the run
    button or pressing Enter in the textbox calls ``handler(text)`` and shows
    its return value in the output; each sample button copies its text into
    the textbox.

    Returns:
        ``(textbox, run_button, output, sample_buttons)``.
    """
    with gr.Tab(tab_label):
        gr.Markdown(banner)
        with gr.Row():
            with gr.Column(scale=1):
                textbox = gr.Textbox(
                    label=input_label,
                    placeholder=placeholder,
                    lines=3,
                )
                run_button = gr.Button(button_label, variant="primary")

                gr.Markdown(samples_heading)
                sample_buttons = []
                for sample in samples:
                    sample_button = gr.Button(sample, size="sm")
                    # Bind the sample as a default argument so each button keeps
                    # its own text (avoids the late-binding closure pitfall).
                    sample_button.click(fn=lambda text=sample: text, outputs=textbox)
                    sample_buttons.append(sample_button)

            with gr.Column(scale=1):
                output = gr.Markdown(value=initial_output)

        # The button and the textbox's Enter key trigger the same task.
        for trigger in (run_button.click, textbox.submit):
            trigger(fn=handler, inputs=textbox, outputs=output)

    return textbox, run_button, output, sample_buttons


with gr.Blocks(title="Llama Address Intelligence", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🦙 Llama 3.2-1B Address Intelligence

    Powered by a fine-tuned Llama 3.2-1B model specialized for Indian address processing.

    **⭐ Best Performance**: Entity extraction from complete addresses
    **⚠️ Limited Performance**: Address completion and standardization (limited training data)

    **Model:** [shiprocket-ai/open-llama-1b-address-completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion)
    """)

    extract_input, extract_btn, extract_output, extract_samples = _build_task_tab(
        tab_label="📋 Extract Components",
        banner="⭐ **BEST PERFORMANCE** - Extract structured components from complete addresses",
        input_label="Enter Address",
        placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
        button_label="🔍 Extract Components",
        samples_heading="### Sample Addresses:",
        samples=sample_addresses,
        initial_output="Enter an address and click 'Extract Components' to see structured breakdown.",
        handler=extract_components_interface,
    )

    complete_input, complete_btn, complete_output, complete_samples = _build_task_tab(
        tab_label="✨ Complete Address",
        banner="⚠️ **EXPERIMENTAL** - Complete partial addresses (limited training data - results may vary)",
        input_label="Enter Partial Address",
        placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex",
        button_label="🚀 Complete Address",
        samples_heading="### Sample Partial Addresses:",
        samples=partial_addresses,
        initial_output="Enter a partial address and click 'Complete Address' to see the AI completion.",
        handler=complete_address_interface,
    )

    standardize_input, standardize_btn, standardize_output, standardize_samples = _build_task_tab(
        tab_label="📐 Standardize Format",
        banner="⚠️ **EXPERIMENTAL** - Convert informal addresses to standardized format (limited training data - results may vary)",
        input_label="Enter Informal Address",
        placeholder="e.g., c704 gayatri shivam thakur complex kandivali e 400101",
        button_label="📏 Standardize Format",
        samples_heading="### Sample Informal Addresses:",
        samples=informal_addresses,
        initial_output="Enter an informal address and click 'Standardize Format' to see the cleaned version.",
        handler=standardize_address_interface,
    )

    with gr.Tab("ℹ️ Model Information"):
        gr.Markdown("""
        ## 🦙 About Llama 3.2-1B Address Completion

        ### Model Specifications
        - **Base Model**: meta-llama/Llama-3.2-1B-Instruct
        - **Parameters**: 1.24B parameters
        - **Model Size**: ~2.47GB
        - **Architecture**: Causal Language Model (Autoregressive)
        - **Max Context**: 131,072 tokens
        - **Precision**: FP16 for GPU, FP32 for CPU

        ### Key Features
        - **Lightweight**: Only 1B parameters for fast inference
        - **Specialized**: Fine-tuned specifically for Indian addresses
        - **Versatile**: Handles extraction, completion, and standardization
        - **Efficient**: Optimized for real-time applications
        - **Context-Aware**: Understands relationships between address components

        ### Supported Address Components
        - **Building Names**: Apartments, complexes, towers, malls
        - **Localities**: Areas, neighborhoods, sectors
        - **Pincodes**: 6-digit Indian postal codes
        - **Cities**: Major and minor Indian cities
        - **States**: All Indian states and union territories
        - **Sub-localities**: Sectors, phases, blocks
        - **Road Names**: Streets, lanes, main roads
        - **Landmarks**: Notable reference points

        ### Performance Notes
        - **⭐ Entity Extraction**: Excellent performance - primary use case
        - **⚠️ Address Completion**: Limited training data - experimental feature
        - **⚠️ Format Standardization**: Limited training data - experimental feature

        **Recommendation**: Use this model primarily for address component extraction.

        ### Use Cases
        - **E-commerce**: Auto-complete checkout addresses
        - **Forms**: Intelligent address suggestions
        - **Data Cleaning**: Standardize legacy address databases
        - **Mobile Apps**: On-device address processing
        - **APIs**: Real-time address validation services

        ### Performance Tips
        - Use lower temperatures (0.1-0.3) for factual outputs
        - Keep prompts under 512 tokens for optimal speed
        - Process in batches for high-throughput scenarios
        - Works best with Llama chat format prompts
        """)

    gr.Markdown("""
    ---
    **Powered by:** [Llama 3.2-1B Address Completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion) |
    **License:** Apache 2.0 |
    **Developed by:** Shiprocket AI Team

    This model demonstrates the power of lightweight LLMs for specialized address intelligence tasks.
    """)
|
|
if __name__ == "__main__":
    # Launch the Gradio server only when this file is run as a script;
    # importing the module builds ``demo`` without launching it.
    demo.launch()