import asyncio
import json
import os
from typing import Any, Dict, Optional

import gradio as gr
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
# ASGI application object; the REST routes and the Gradio UI in this file
# are all registered on it.
app = FastAPI(
    title="ZenoBot Travel API",
)
|
|
| |
# Hugging Face model identifier; the MODEL_ID environment variable overrides
# the built-in default.
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.2-3B")

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Token budget: MAX_MODEL_CONTEXT is the context size this service plans for,
# DEFAULT_MAX_TOKENS is the generation length used when a request gives none.
MAX_MODEL_CONTEXT = 8192
DEFAULT_MAX_TOKENS = 4096
|
|
| |
print(f"Loading model {MODEL_ID} on {DEVICE}...")

# Read the Hub token once so the tokenizer fallback AND the model download can
# both authenticate. An empty value is treated as "no token"; passing
# token=None lets the library use its default credential lookup.
hf_token = os.environ.get("HF_TOKENIZER_READ_TOKEN") or None

try:
    # Prefer tokenizer files shipped next to this script.
    tokenizer = AutoTokenizer.from_pretrained("./")
    print("Loaded tokenizer from local directory")
except Exception as e:
    # Deliberately broad: any failure to load locally falls back to the Hub.
    print(f"Couldn't load from local directory: {e}")
    print("Attempting to load from model hub...")
    if hf_token:
        print("Using token from environment variable")
    else:
        print("No token found in environment, attempting without authentication")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

# The weights come from the same repository as the Hub tokenizer, so they need
# the same credentials; previously the token was never forwarded here.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # Half precision on GPU, full precision on CPU.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    token=hf_token,
)
print("Model loaded successfully!")
|
|
class TravelRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Free-form travel question; inserted into the user turn of the prompt.
    query: str
    # Sampling temperature forwarded to ``model.generate``.
    temperature: Optional[float] = 0.1
    # Requested number of new tokens; the endpoint caps values that are too
    # large for the planned context size.
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
|
|
class TravelResponse(BaseModel):
    """Response envelope returned by the ``/generate`` endpoint."""

    # Either the JSON object parsed from the model output, or an error dict
    # with "error" and "raw_text" keys.
    response: Dict[str, Any]
|
|
| |
def load_system_prompt():
    """Return the system prompt text used for every generation request.

    Sources are tried in order:

    1. ``modelfile`` in the working directory: the text between the
       ``SYSTEM`` marker's opening triple double-quote and the first closing
       triple double-quote after it.
    2. ``system_prompt.txt`` in the working directory (whole file).
    3. A built-in one-sentence default.

    A source that is missing, unreadable, or (for ``modelfile``) lacks a
    well-formed SYSTEM block is reported on stdout and skipped.

    Returns:
        The prompt with surrounding whitespace stripped.
    """
    marker = 'SYSTEM """'
    closing = '"""'

    try:
        with open("modelfile", "r", encoding="utf-8") as f:
            content = f.read()
        marker_pos = content.find(marker)
        if marker_pos == -1:
            raise ValueError("no SYSTEM block found")
        # Skip the whole marker; a hard-coded offset of 9 used to leave one
        # stray quote character at the start of the prompt.
        start = marker_pos + len(marker)
        # Use the first closing delimiter after the opening one so that later
        # triple-quoted sections of the file are not swallowed into the prompt.
        end = content.find(closing, start)
        if end == -1:
            raise ValueError("SYSTEM block is not terminated")
        return content[start:end].strip()
    except (OSError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError.
        print(f"Error loading from modelfile: {e}")

    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            return f.read().strip()
    except (OSError, ValueError) as e2:
        print(f"Error loading from system_prompt.txt: {e2}")

    return """You are Zeno-Bot, a travel assistant specializing in creating detailed travel itineraries strictly within one state in a country."""


# Loaded once at import time and reused for every request.
SYSTEM_PROMPT = load_system_prompt()
|
|
def _extract_json_payload(response_text):
    """Parse the outermost ``{...}`` span of *response_text*.

    Returns the parsed JSON object, or an error dict with ``error`` and
    ``raw_text`` keys when no brace-delimited span exists or it is not valid
    JSON.
    """
    json_start = response_text.find("{")
    json_end = response_text.rfind("}")
    # Compare the raw indices: adding 1 before the check (as the code used to)
    # turns a "not found" -1 into 0 and makes the check always pass.
    if json_start == -1 or json_end < json_start:
        return {"error": "No valid JSON found", "raw_text": response_text}
    try:
        return json.loads(response_text[json_start:json_end + 1])
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format", "raw_text": response_text}


@app.post("/generate", response_model=TravelResponse)
async def generate_travel_plan(request: TravelRequest):
    """Generate a travel plan for ``request.query`` and return it as JSON.

    The system prompt and the user query are combined into one prompt, the
    model generates a continuation, and the outermost JSON object in that
    continuation is returned under the ``response`` key. If no parseable JSON
    is present, ``response`` holds an error dict with the raw text instead.

    Args:
        request: Query text plus optional ``temperature`` and ``max_tokens``.
            A null ``max_tokens`` falls back to ``DEFAULT_MAX_TOKENS``; a null
            ``temperature`` leaves the choice to the model's generation
            config; a non-positive ``temperature`` selects greedy decoding.

    Returns:
        ``{"response": <dict>}``.

    Raises:
        HTTPException: status 500 when tokenization or generation fails.
    """
    # NOTE(review): model.generate is a blocking call inside an async handler,
    # so other requests wait while it runs. Left as-is because offloading it
    # would allow concurrent generations on one model; confirm what is wanted.
    try:
        token_cap = MAX_MODEL_CONTEXT - 1000
        # max_tokens is Optional on the request model; comparing None with an
        # int would raise TypeError.
        max_tokens = request.max_tokens
        if max_tokens is None:
            max_tokens = DEFAULT_MAX_TOKENS
        if max_tokens > token_cap:
            print(f"Warning: Requested {max_tokens} tokens exceeds safe limit. Capping at {token_cap}.")
            max_tokens = token_cap

        # Sampling with a non-positive temperature is rejected by generate();
        # treat it as a request for deterministic (greedy) decoding instead.
        temperature = request.temperature
        if temperature is not None and temperature <= 0:
            sampling_kwargs = {"do_sample": False}
        elif temperature is None:
            sampling_kwargs = {"do_sample": True}
        else:
            sampling_kwargs = {"do_sample": True, "temperature": temperature}

        prompt = (
            f"<|system|>\n{SYSTEM_PROMPT}\n"
            f"<|user|>\n{request.query}\n"
            "<|assistant|>"
        )

        # The model is placed by device_map="auto", so follow the model's own
        # device rather than assuming it matches DEVICE.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_ids = inputs["input_ids"]
        output = model.generate(
            input_ids,
            # Pass the mask explicitly when the tokenizer provides one.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_tokens,
            **sampling_kwargs,
        )

        # Drop the echoed prompt tokens and keep only the newly generated ones.
        prompt_length = input_ids.shape[1]
        response_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

        return {"response": _extract_json_payload(response_text)}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}") from e
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always answers with a fixed "healthy" status."""
    status_payload = {"status": "healthy"}
    return status_payload
|
|
def generate_itinerary(query, temperature=0.1, max_tokens=DEFAULT_MAX_TOKENS):
    """Gradio callback: build a request, run the generator, return JSON text.

    Args:
        query: Free-form travel question from the text box.
        temperature: Sampling temperature from the slider.
        max_tokens: Generation length from the slider.

    Returns:
        A JSON-formatted string: the generated itinerary on success, or an
        object with a single ``error`` key on failure. Failures are returned
        rather than raised so they can be shown in the output widget.
    """
    try:
        request = TravelRequest(query=query, temperature=temperature, max_tokens=max_tokens)
        result = generate_travel_plan(request)
        # generate_travel_plan is declared async, so calling it only creates a
        # coroutine; run it to completion before reading the result.
        # asyncio.run needs a thread with no running event loop — presumably
        # the case for a sync Gradio callback; confirm.
        if asyncio.iscoroutine(result):
            result = asyncio.run(result)
        # The endpoint returns a plain dict, not an object with attributes.
        return json.dumps(result["response"], indent=2)
    except Exception as e:
        # The interface's output component is gr.JSON, so the error must be
        # valid JSON as well.
        return json.dumps({"error": str(e)}, indent=2)
|
|
| |
# Input widgets, listed in the positional order of generate_itinerary's
# parameters (query, temperature, max_tokens).
_query_box = gr.Textbox(
    label="Travel Query",
    lines=3,
    placeholder="Plan a 3-day trip to California starting on 15/04/2024",
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0.1,
    maximum=1.0,
    step=0.1,
    value=0.1,
)
_max_tokens_slider = gr.Slider(
    label="Max Tokens",
    minimum=512,
    maximum=6144,
    step=512,
    value=DEFAULT_MAX_TOKENS,
)

# Browser UI that wraps generate_itinerary.
demo = gr.Interface(
    fn=generate_itinerary,
    inputs=[_query_box, _temperature_slider, _max_tokens_slider],
    outputs=gr.JSON(label="Generated Itinerary"),
    title="ZenoBot Travel Assistant",
    description="Generate detailed travel itineraries within a single state. Example: 'Plan a 3-day trip to California starting on 15/04/2024'",
)

# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, demo, path="/")
|
|
if __name__ == "__main__":
    import uvicorn

    # Listen on every interface; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|