# app.py - uploaded by ranggafermata; commit b2c9400 ("Update app.py", verified)
import json
import os
import re
import urllib.parse
import urllib.request

import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() with token=None falls back to an interactive prompt, which
# cannot be answered when this script runs as a headless server.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "ranggafermata/Fermata-v1.2-light"

# bfloat16 is only selected when CUDA is available, so the weights must also
# be placed on the GPU; otherwise they would stay on the CPU in bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id, attn_implementation="eager")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
).to(device)

# FastAPI application that hosts the /chat JSON endpoint.
app = FastAPI()
def chat_function(message):
    """Generate text for *message* with the module-level model.

    The message is tokenized as-is (no prompt template), moved to the model's
    device, and passed to ``model.generate`` with a budget of 128 new tokens.
    The first returned sequence is decoded with special tokens stripped.
    """
    encoded = tokenizer(message, return_tensors="pt")
    encoded = encoded.to(model.device)
    generated = model.generate(**encoded, max_new_tokens=128)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
@app.post("/chat")
async def chat_api(request: Request):
    """Handle ``POST /chat``: run the model on the ``"input"`` field.

    Expects a JSON object such as ``{"input": "..."}``.

    Returns:
        200 ``{"output": <generated text>}`` on success.
        400 ``{"error": ...}`` when the body is not valid JSON, is not a JSON
        object, or has a missing/empty ``"input"`` field.
        500 ``{"error": <exception text>}`` for any other failure.
    """
    try:
        try:
            body = await request.json()
        except ValueError:
            # json.JSONDecodeError / UnicodeDecodeError: the client sent a
            # malformed body, which is a client error rather than a 500.
            return JSONResponse(content={"error": "Invalid JSON body"}, status_code=400)

        # A JSON array/string/number has no .get(); treat it as missing input
        # instead of letting AttributeError surface as a server error.
        if not isinstance(body, dict):
            return JSONResponse(content={"error": "Missing input"}, status_code=400)

        prompt = body.get("input", "")
        if not prompt:
            return JSONResponse(content={"error": "Missing input"}, status_code=400)

        output = chat_function(prompt)
        return JSONResponse(content={"output": output})
    except Exception as e:
        # Top-level boundary: report the failure to the caller as JSON.
        return JSONResponse(content={"error": str(e)}, status_code=500)
# Weather API
def get_weather(location):
    """Return a one-line current-weather summary for *location*.

    Queries the OpenWeatherMap "current weather" endpoint in metric units
    using the key from the ``OPENWEATHER_API_KEY`` environment variable.

    Args:
        location: Place name passed as the ``q`` query parameter.

    Returns:
        ``"<name>: <temp>°C, <description>"`` on success,
        ``"Missing API key for weather."`` when no key is configured, or
        ``"Failed to fetch weather."`` when the request or response parsing
        fails.
    """
    key = os.getenv("OPENWEATHER_API_KEY")
    if not key:
        return "Missing API key for weather."

    # urlencode escapes the location, so characters such as "&", "#" or
    # spaces cannot corrupt or extend the query string.
    query = urllib.parse.urlencode({"q": location, "appid": key, "units": "metric"})
    # HTTPS keeps the API key out of clear-text traffic.
    url = f"https://api.openweathermap.org/data/2.5/weather?{query}"
    try:
        # The timeout stops a stalled connection from hanging the caller.
        with urllib.request.urlopen(url, timeout=10) as resp:
            r = json.load(resp)
        return f"{r['name']}: {r['main']['temp']}°C, {r['weather'][0]['description']}"
    except (OSError, ValueError, KeyError, IndexError, TypeError):
        # OSError covers URLError/HTTPError/timeouts; ValueError covers bad
        # JSON; the rest cover a response with an unexpected shape.
        return "Failed to fetch weather."
# NASA API
def get_apod():
    """Return NASA's Astronomy Picture of the Day as display text.

    Uses the key from the ``NASA_API_KEY`` environment variable.

    Returns:
        A string with the title, explanation and media URL on success,
        ``"Missing API key for NASA."`` when no key is configured, or
        ``"Failed to fetch NASA APOD."`` when the request or response parsing
        fails.
    """
    key = os.getenv("NASA_API_KEY")
    if not key:
        return "Missing API key for NASA."

    query = urllib.parse.urlencode({"api_key": key})
    url = f"https://api.nasa.gov/planetary/apod?{query}"
    try:
        # The timeout stops a stalled connection from hanging the caller.
        with urllib.request.urlopen(url, timeout=10) as resp:
            r = json.load(resp)
        return f"📷 {r['title']}\n\n{r['explanation']}\n\nMedia: {r['url']}"
    except (OSError, ValueError, KeyError, TypeError):
        # OSError covers URLError/HTTPError/timeouts; ValueError covers bad
        # JSON; KeyError/TypeError cover a response with an unexpected shape.
        return "Failed to fetch NASA APOD."
# Parse tool call JSON inside [TOOL_CALL] {...}
def parse_tool_call(output):
    """Extract the JSON object that follows a ``[TOOL_CALL]`` marker.

    Args:
        output: Model text that may contain ``[TOOL_CALL] {...}``.

    Returns:
        The decoded JSON object (a ``dict``), or ``None`` when *output* is
        empty, has no marker followed by ``{``, or the JSON is malformed.
    """
    if not output or "[TOOL_CALL]" not in output:
        return None

    # Locate the first marker that is followed (after optional whitespace)
    # by an opening brace; the lookahead leaves match.end() on that brace.
    match = re.search(r"\[TOOL_CALL\]\s*(?=\{)", output)
    if not match:
        return None

    start = match.end()
    try:
        # raw_decode parses exactly one JSON value starting at `start` and
        # ignores trailing text, so nested objects and "}" inside string
        # values are handled, unlike a non-greedy `\{.*?\}` regex.
        payload, _ = json.JSONDecoder().raw_decode(output, start)
    except json.JSONDecodeError as e:
        print(f"❌ JSON parsing failed: {e}")
        print(f"⚠️ Bad JSON string: {output[start:]}")
        return None
    return payload
# Chat logic
def respond(message):
    """Answer *message* with the model, executing a tool call if one is emitted.

    The message is wrapped in a ``### Human:`` / ``### Assistant:`` prompt and
    sampled (temperature 0.7, up to 256 new tokens). The text after the last
    ``### Assistant:`` marker is the reply. If the reply contains a
    ``[TOOL_CALL]`` JSON object, the named tool is run and its result is
    returned instead of the raw reply.

    Args:
        message: The user's text.

    Returns:
        The tool result, a "Tool not recognized" notice, or the model reply.
    """
    prompt = f"### Human:\n{message}\n\n### Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.7, do_sample=True)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    reply = result.split("### Assistant:")[-1].strip()

    tool = parse_tool_call(reply)
    if not tool or not isinstance(tool, dict):
        return reply

    # The model is free to emit any JSON object, so "tool" may be absent;
    # .get() avoids a KeyError crashing the request in that case.
    tool_name = tool.get("tool")
    if tool_name == "get_weather":
        return get_weather(tool.get("location", "Unknown"))
    if tool_name == "get_apod":
        return get_apod()
    return f"Tool not recognized: {tool_name}"
# UI
# Build the Gradio front end around respond() and start serving it.
# NOTE(review): launch() runs at import time, and nothing visible here serves
# the FastAPI `app` - confirm how the /chat endpoint is meant to be exposed.
question_box = gr.Textbox(lines=2, placeholder="Ask me something...")
demo = gr.Interface(
    fn=respond,
    inputs=question_box,
    outputs="text",
    title="Fermata AI 1.2",
    description="Now powered by the official Gemma 3 model. Ask about the weather or NASA's daily space image!",
)
demo.launch()