"""FastAPI + Gradio service wrapping the Fermata chat model with weather/NASA tools."""

import json
import os
import re

import requests
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import gradio as gr
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

# Authenticate before pulling the model weights (token comes from the environment).
login(token=os.getenv("HF_TOKEN"))

model_id = "ranggafermata/Fermata-v1.2-light"
tokenizer = AutoTokenizer.from_pretrained(model_id, attn_implementation="eager")
# bfloat16 on GPU keeps memory low; fall back to float32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

app = FastAPI()


def chat_function(message):
    """Generate a completion for *message* and return the full decoded text."""
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/chat")
async def chat_api(request: Request):
    """JSON endpoint: {"input": "..."} -> {"output": "..."}.

    Returns 400 when "input" is missing/empty, 500 on any generation failure.
    """
    try:
        body = await request.json()
        prompt = body.get("input", "")
        if not prompt:
            return JSONResponse(content={"error": "Missing input"}, status_code=400)
        output = chat_function(prompt)
        return JSONResponse(content={"output": output})
    except Exception as e:
        # Top-level API boundary: surface the error message, never a raw traceback.
        return JSONResponse(content={"error": str(e)}, status_code=500)


# Weather API
def get_weather(location):
    """Return a one-line current-weather summary for *location* via OpenWeather."""
    key = os.getenv("OPENWEATHER_API_KEY")
    if not key:
        return "Missing API key for weather."
    try:
        url = f"http://api.openweathermap.org/data/2.5/weather?q={location}&appid={key}&units=metric"
        # Timeout so a hung upstream cannot stall the request forever.
        r = requests.get(url, timeout=10).json()
        return f"{r['name']}: {r['main']['temp']}°C, {r['weather'][0]['description']}"
    except (requests.RequestException, KeyError, ValueError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still escape.
        return "Failed to fetch weather."


# NASA API
def get_apod():
    """Return NASA's Astronomy Picture of the Day (title, explanation, media URL)."""
    key = os.getenv("NASA_API_KEY")
    if not key:
        return "Missing API key for NASA."
    try:
        r = requests.get(
            f"https://api.nasa.gov/planetary/apod?api_key={key}", timeout=10
        ).json()
        return f"📷 {r['title']}\n\n{r['explanation']}\n\nMedia: {r['url']}"
    except (requests.RequestException, KeyError, ValueError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still escape.
        return "Failed to fetch NASA APOD."
# Parse tool call JSON inside [TOOL_CALL] {...}
def parse_tool_call(output):
    """Extract and decode the JSON object following a [TOOL_CALL] marker.

    Returns the parsed dict, or None when there is no marker, no JSON object,
    or the payload fails to parse.
    """
    if not output or "[TOOL_CALL]" not in output:
        return None
    match = re.search(r"\[TOOL_CALL\]\s*(\{.*?\})", output, re.DOTALL)
    if not match:
        return None
    json_str = match.group(1).strip()
    if not json_str or json_str in ("null", "None"):
        return None
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"❌ JSON parsing failed: {e}")
        print(f"⚠️ Bad JSON string: {json_str}")
        return None


# Chat logic
def respond(message):
    """Run one chat turn; dispatch to a tool when the model emits a [TOOL_CALL]."""
    prompt = f"### Human:\n{message}\n\n### Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs, max_new_tokens=256, temperature=0.7, do_sample=True
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant's portion of the echoed prompt+completion.
    reply = result.split("### Assistant:")[-1].strip()

    tool = parse_tool_call(reply)
    if tool:
        # .get() instead of ["tool"]: a malformed payload must fall through to
        # the "Tool not recognized" branch, not raise KeyError.
        tool_name = tool.get("tool")
        if tool_name == "get_weather":
            return get_weather(tool.get("location", "Unknown"))
        elif tool_name == "get_apod":
            return get_apod()
        else:
            return f"Tool not recognized: {tool_name}"
    return reply


# UI
gr.Interface(
    fn=respond,
    inputs=gr.Textbox(lines=2, placeholder="Ask me something..."),
    outputs="text",
    title="Fermata AI 1.2",
    description="Now powered by the official Gemma 3 model. Ask about the weather or NASA's daily space image!",
).launch()