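"""FastAPI + Gradio chatbot for asking questions about Hugging Face model cards.

The app wraps huggingface_hub's MCP Agent, which spawns a local MCP server
(mcp_server.py, expected alongside this file) exposing a `read_model_card` tool.
"""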
import os
from fastapi import FastAPI
from dotenv import load_dotenv
from huggingface_hub.inference._mcp.agent import Agent
import gradio as gr
import uvicorn
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional, Literal

# Load environment configuration (a local .env file is read if present)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL", "Qwen/Qwen1.5-0.5B-Chat")

app = FastAPI(title="MODEL-CARD-CHATBOT")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Single shared Agent instance, created lazily and reused across requests
agent_instance: Optional[Agent] = None
DEFAULT_PROVIDER: Literal["hf-inference"] = "hf-inference"
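# Example .env consumed above (values are placeholders, not real credentials):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   HF_MODEL=Qwen/Qwen1.5-0.5B-Chat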


async def get_agent():
    """Create the shared Agent on first use (wired to the local MCP server
    mcp_server.py over stdio) and load its tools. Returns None if HF_TOKEN is unset."""
    global agent_instance
    if agent_instance is None and HF_TOKEN:
        print("🔧 Creating new Agent instance ...")
        print(f"✅ HF_TOKEN present : {bool(HF_TOKEN)}")
        print(f"🤖 Model: {HF_MODEL}")
        print(f"Provider: {DEFAULT_PROVIDER}")
        try:
            agent = Agent(
                model=HF_MODEL,
                provider=DEFAULT_PROVIDER,
                api_key=HF_TOKEN,
                servers=[{
                    "type": "stdio",
                    "config": {
                        "command": "python",
                        "args": ["mcp_server.py"],
                        "cwd": ".",
                        "env": {"HF_TOKEN": HF_TOKEN} if HF_TOKEN else {}
                    }
                }]
            )
            print("🚀 Agent instance created successfully")
            print("🔁 loading tools ...")
            await agent.load_tools()
            agent_instance = agent
            print("✅ Tools loaded successfully")
        except Exception as e:
            print(f"❌ Error creating/loading agent: {str(e)}")
    return agent_instance


@app.on_event("startup")
async def startup_event():
    global agent_instance
    agent_instance = await get_agent()


async def chat_function(user_message, history, model_id):
    """Gradio callback: have the agent read the model card, then answer the question."""
    prompt = f"""You're an assistant helping with Hugging Face model cards.
First, run the tool `read_model_card` on repo_id `{model_id}` to get the model card.
Then answer this user question based on the model card:
User question: {user_message}"""
    history = history + [(user_message, None)]
    agent = await get_agent()
    if agent is None:
        history[-1] = (user_message, "⚠️ Agent unavailable. Check that HF_TOKEN is set.")
        return history, ""
    try:
        response = ""
        # Stream the agent's output and keep the last chunk that carries content
        async for output in agent.run(prompt):
            if hasattr(output, "content") and output.content:
                response = output.content
        final_response = response or "⚠️ Sorry, I couldn't generate a response."
        history[-1] = (user_message, final_response)
    except Exception as e:
        history[-1] = (user_message, f"⚠️ Error: {str(e)}")
    return history, ""




def create_gradio_app():
    """Build the Gradio UI that gets mounted onto the FastAPI app."""
    with gr.Blocks(title="Model Card Chatbot") as demo:
        gr.Markdown("## 🤖 Model Card Chatbot\nAsk questions about a Hugging Face model card")
        with gr.Row():
            model_id = gr.Textbox(label="Model ID", value="google/gemma-2-2b")
            user_input = gr.Textbox(label="Your Question", placeholder="Ask something about the model card ...")
            send = gr.Button("Ask")
        chatbot = gr.Chatbot(label="Chat")
        send.click(fn=chat_function, inputs=[user_input, chatbot, model_id], outputs=[chatbot, user_input])
    return demo


gradio_app = create_gradio_app()
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")

@app.get("/")
async def root():
    # Redirect the API root to the Gradio UI mounted above
    return RedirectResponse("/gradio")


if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
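# Quick manual check (assumes the server is running locally on port 7860):
#   curl -I http://localhost:7860/        -> 307 redirect to /gradio
#   open http://localhost:7860/gradio     -> chat UI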