File size: 3,245 Bytes
d3bf1f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
"""
Example client for interacting with the Mistral 7B AHB2APB API
"""

import json
import os

import requests

# API configuration.
# Base URL of the API server. It defaults to a local server on port 8000 and
# can be pointed elsewhere via the API_BASE_URL environment variable. Any
# trailing "/" is stripped because every call site appends "/<path>" itself.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000").rstrip("/")

def check_api_health(timeout: float = 10.0) -> bool:
    """Query ``GET /health`` and print a short status report.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` would wait indefinitely on a stalled server.

    Returns:
        True if the endpoint answered with a 2xx status and a JSON body that
        contains ``model_path``, ``device`` and ``model_loaded``; False
        otherwise. Every failure is reported on stdout instead of raised.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=timeout)
        response.raise_for_status()
        health = response.json()
        # Read every field before printing anything, so a malformed payload
        # does not print a success banner and then an error.
        model_path = health['model_path']
        device = health['device']
        model_loaded = health['model_loaded']
    except requests.exceptions.ConnectionError:
        # ConnectTimeout subclasses ConnectionError, so it is reported here.
        print("✗ Cannot connect to API. Is the server running?")
        print("  Start it with: python api_server.py")
        return False
    except requests.exceptions.Timeout:
        print(f"✗ API did not respond within {timeout} seconds")
        return False
    except Exception as e:
        # Demo-level boundary: report any other problem (HTTP error status,
        # invalid JSON, missing keys) instead of crashing.
        print(f"✗ Error checking API health: {e}")
        return False

    print("✓ API is healthy!")
    print(f"  Model: {model_path}")
    print(f"  Device: {device}")
    print(f"  Model loaded: {model_loaded}")
    return True

def generate(prompt: str, max_length: int = 512, temperature: float = 0.7,
             timeout: float = 120):
    """Generate text for a single prompt via ``POST /api/generate``.

    Args:
        prompt: Input text sent to the model.
        max_length: Forwarded to the server unchanged as ``max_length``.
        temperature: Forwarded to the server unchanged as ``temperature``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``response`` field of the JSON reply, or ``None`` on any failure
        (network error, HTTP error status, or a reply without the expected
        shape). The reason is printed to stdout.
    """
    payload = {
        "prompt": prompt,
        "max_length": max_length,
        "temperature": temperature,
    }
    try:
        response = requests.post(
            f"{API_BASE_URL}/api/generate",
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()['response']
    except requests.exceptions.RequestException as e:
        print(f"Error calling API: {e}")
        # e.response is None for connection errors and timeouts. Compare
        # against None explicitly: a Response object is falsy for 4xx/5xx.
        if e.response is not None:
            print(f"Response: {e.response.text}")
        return None
    except (ValueError, KeyError, TypeError) as e:
        # Body was not JSON (ValueError), or it was JSON that is not a dict
        # with a 'response' key (KeyError/TypeError).
        print(f"Unexpected API response: {e!r}")
        return None

def generate_batch(prompts: list, max_length: int = 512, temperature: float = 0.7,
                   timeout: float = 300):
    """Generate text for several prompts in one ``POST /api/generate/batch``.

    Args:
        prompts: Input texts. Each one becomes a request object sharing the
            same ``max_length`` and ``temperature``.
        max_length: Forwarded to the server unchanged for every prompt.
        temperature: Forwarded to the server unchanged for every prompt.
        timeout: Seconds to wait for the server. The default is longer than
            for a single generation because the whole batch runs in one call.

    Returns:
        A list with the ``response`` field of each item in the reply's
        ``results`` array (in the order the server returns them), or ``None``
        on any failure. The reason is printed to stdout.
    """
    requests_data = [
        {
            "prompt": prompt,
            "max_length": max_length,
            "temperature": temperature,
        }
        for prompt in prompts
    ]
    try:
        response = requests.post(
            f"{API_BASE_URL}/api/generate/batch",
            json=requests_data,
            timeout=timeout,
        )
        response.raise_for_status()
        result = response.json()
        return [item['response'] for item in result['results']]
    except requests.exceptions.RequestException as e:
        print(f"Error calling batch API: {e}")
        # e.response is None for connection errors and timeouts. Compare
        # against None explicitly: a Response object is falsy for 4xx/5xx.
        if e.response is not None:
            print(f"Response: {e.response.text}")
        return None
    except (ValueError, KeyError, TypeError) as e:
        # Body was not JSON, or it lacked the 'results' / per-item 'response'
        # structure this client expects.
        print(f"Unexpected batch API response: {e!r}")
        return None

def main():
    """Demo: check the server's health, then generate one AHB-to-APB answer."""
    print("=" * 70)
    print("Mistral 7B AHB2APB API Client Example")
    print("=" * 70)
    print()

    # Stop early if the server is unreachable or unhealthy. The reason has
    # already been printed by check_api_health().
    if not check_api_health():
        return

    print()
    print("=" * 70)
    print("Generating Response")
    print("=" * 70)
    print()

    # Example prompt for AHB to APB conversion
    prompt = "Convert this AHB burst to APB"

    print(f"Prompt: {prompt}")
    print()
    print("Response:")
    print("-" * 70)

    response = generate(prompt, max_length=512, temperature=0.7)

    # generate() signals failure with None. An empty string is a valid,
    # if unhelpful, generation and must not be reported as a failure.
    if response is not None:
        print(response)
        print("-" * 70)
    else:
        print("Failed to generate response")

    print()


if __name__ == "__main__":
    main()