# Sandhya — "first commit" (707ec96), 3.6 kB
# app.py
import os
from fastapi import FastAPI
from dotenv import load_dotenv
from huggingface_hub.inference._mcp.agent import Agent
import gradio as gr
import uvicorn
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, Literal
# Populate os.environ from a local .env file (when present) before any
# configuration is read below.
load_dotenv()

# Runtime configuration: the Hugging Face token, the chat model that drives
# the agent, and the inference provider used to reach it.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL", "google/gemma-2-2b")
DEFAULT_PROVIDER: Literal['hf-inference'] = "hf-inference"

# Process-wide agent; stays None until get_agent() builds it successfully.
agent_instance: Optional[Agent] = None

app = FastAPI(title="Model Card Chatbot")
# Wide-open CORS: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
async def get_agent():
    """Return the shared MCP ``Agent``, building it on the first successful call.

    The agent uses ``HF_MODEL`` through ``DEFAULT_PROVIDER`` and loads its
    tools from ``mcp_server.py``, which runs as a stdio subprocess.

    Returns:
        The cached ``Agent``, or ``None`` when ``HF_TOKEN`` is unset or when
        creating the agent or loading its tools failed. Failures are not
        cached, so the next call tries again.
    """
    global agent_instance
    if agent_instance is not None:
        return agent_instance
    if not HF_TOKEN:
        # Previously this returned None silently; make the reason visible.
        print("⚠️ HF_TOKEN is not set; cannot create the Agent.")
        return None

    import sys

    # Resolve the MCP server next to this file and launch it with the same
    # interpreter, rather than relying on the caller's cwd and PATH.
    app_dir = os.path.dirname(os.path.abspath(__file__))

    print("🔧 Creating new Agent instance ...")
    print(f"✅ HF_TOKEN present: {bool(HF_TOKEN)}")
    print(f"🤖 Model: {HF_MODEL}")
    agent = None
    try:
        agent = Agent(
            model=HF_MODEL,
            provider=DEFAULT_PROVIDER,
            api_key=HF_TOKEN,
            servers=[{
                "type": "stdio",
                "config": {
                    "command": sys.executable,
                    "args": ["mcp_server.py"],
                    "cwd": app_dir,
                    # HF_TOKEN is guaranteed non-empty here (checked above).
                    "env": {"HF_TOKEN": HF_TOKEN},
                },
            }],
        )
        await agent.load_tools()
    except Exception as e:
        import traceback

        print(f"❌ Error creating/loading agent: {str(e)}")
        traceback.print_exc()
        if agent is not None:
            # Best effort: close any MCP session/subprocess that started before
            # the failure, so a later retry does not leak it.
            try:
                await agent.cleanup()
            except Exception as cleanup_error:
                print(f"⚠️ Error while cleaning up failed agent: {cleanup_error}")
        return None

    agent_instance = agent
    print("✅ Agent is ready")
    return agent_instance
@app.on_event("startup")
async def startup_event():
    """Build the shared agent eagerly when the server starts.

    ``get_agent()`` already stores its result in the module-level
    ``agent_instance``. The global is re-bound here to that same value.
    """
    global agent_instance
    ready_agent = await get_agent()
    agent_instance = ready_agent
async def chat_function(user_message, history, model_id):
    """Answer a question about a Hugging Face model card using the MCP agent.

    Used as the Gradio handler for the "Ask" button. The agent is told to call
    the ``read_model_card`` tool on ``model_id`` and to answer from that card.
    The function is async because ``Agent.run`` is an async generator, and the
    agent's MCP session lives on the server's event loop.

    Args:
        user_message: The question typed by the user.
        history: Chatbot history as a list of ``(user, assistant)`` pairs.
            ``None`` is treated as an empty history.
        model_id: Hugging Face repo id whose model card should be read.

    Returns:
        ``(history, "")``: the updated history with the new exchange appended,
        and an empty string that clears the question textbox.
    """
    history = list(history or [])
    if not user_message or not user_message.strip():
        # Nothing to ask; leave the conversation unchanged.
        return history, ""

    prompt = (
        "You're an assistant helping with Hugging Face model cards.\n"
        f"First, run the tool `read_model_card` on repo_id `{model_id}` to get the model card.\n"
        "Then answer this user question based on the model card:\n"
        f"User question: {user_message}\n"
    )

    # Reuses the cached agent, or retries creation if startup failed.
    agent = await get_agent()
    if agent is None:
        history.append((
            user_message,
            "⚠️ The agent is not available. Check that HF_TOKEN is set and see the server logs.",
        ))
        return history, ""

    # NOTE(review): a single Agent is shared by every user, so its internal
    # message history likely spans all conversations — confirm this is intended.
    try:
        parts = []
        # Agent.run yields streamed completion chunks (with `choices`) and
        # tool-result messages (without `choices`). Only the streamed
        # assistant text is the answer; tool results carry the raw model card.
        async for output in agent.run(prompt):
            choices = getattr(output, "choices", None)
            if not choices:
                continue
            text = choices[0].delta.content
            if text:
                parts.append(text)
        response = "".join(parts)
        if not response.strip():
            response = "⚠️ Sorry, I couldn't generate a response."
    except Exception as e:
        print(f"❌ Error while running agent: {e!r}")
        response = f"⚠️ Error: {str(e)}"

    history.append((user_message, response))
    return history, ""
def create_gradio_app(default_model_id="google/gemma-2-2b"):
    """Build the Gradio Blocks UI for the model-card chatbot.

    Args:
        default_model_id: Repo id pre-filled in the "Model ID" textbox.

    Returns:
        The ``gr.Blocks`` object, not launched. The caller mounts it.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="🤖 Model Card Chatbot") as demo:
        gr.Markdown("""
# 🤖 **Model Card Chatbot**
Ask anything about a model's card on Hugging Face.
""")
        with gr.Row():
            model_id = gr.Textbox(label="Model ID", value=default_model_id, scale=2)
            user_input = gr.Textbox(label="Your Question", placeholder="e.g., What is this model trained on?", scale=3)
            send = gr.Button("🔍 Ask", scale=1)
        chatbot = gr.Chatbot(label="Chat")
        # Clicking "Ask" and pressing Enter in the question box both send the
        # question. Previously only the button worked.
        for trigger in (send.click, user_input.submit):
            trigger(
                fn=chat_function,
                inputs=[user_input, chatbot, model_id],
                outputs=[chatbot, user_input],
            )
    return demo
# Build the UI once at import time and serve it from the site root of the
# FastAPI app. mount_gradio_app hands back the app it was given.
gradio_app = create_gradio_app()
app = gr.mount_gradio_app(app=app, blocks=gradio_app, path="/")
# NOTE: Gradio is mounted at "/" just above. Starlette matches routes in
# registration order, and Mount("/") matches every path, so a GET "/" route
# declared after the mount is never reached. If it were reachable,
# redirecting "/" to "/" would loop forever. The coroutine is kept so the
# name still exists, but it is no longer registered as a route.
async def root():
    """Return a redirect to the site root, where the Gradio UI is served.

    Not registered as an HTTP route (see the note above).
    """
    return RedirectResponse("/")
if __name__ == "__main__":
    # uvicorn needs the "module:attribute" import string, not the app object,
    # for reload mode to work.
    server_settings = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("app:app", **server_settings)